Class Bidirectional
Bidirectional wrapper for RNNs.
Inherits From: Wrapper
Arguments:
layer: Recurrent instance.
merge_mode: Mode by which outputs of the forward and backward RNNs will be combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the outputs will not be combined; they will be returned as a list.
backward_layer: Optional Recurrent instance to be used to handle backwards input processing. If backward_layer is not provided, the layer instance passed as the layer argument will be used to generate the backward layer automatically. Note that the provided backward_layer should have properties matching those of the layer argument; in particular, it should have the same values for stateful, return_state, return_sequences, etc. In addition, backward_layer and layer should have different go_backwards argument values. A ValueError will be raised if these requirements are not met.
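The merge_mode options can be illustrated with a small NumPy sketch. merge_outputs is a hypothetical helper, not part of the Keras API. It assumes the backward outputs have already been flipped back into forward time order (which the wrapper does when return_sequences=True), so both arrays line up timestep by timestep. Raising ValueError for an unknown mode loosely mirrors the documented check, although Keras performs that check when the layer is constructed.

```python
import numpy as np


def merge_outputs(y_fw, y_bw, merge_mode="concat"):
    """Combine forward and backward outputs according to merge_mode.

    Assumes y_bw has already been flipped back into forward time order, so
    y_fw[:, t] and y_bw[:, t] describe the same timestep t.
    """
    if merge_mode == "concat":
        return np.concatenate([y_fw, y_bw], axis=-1)  # feature size doubles
    if merge_mode == "sum":
        return y_fw + y_bw
    if merge_mode == "mul":
        return y_fw * y_bw
    if merge_mode == "ave":
        return (y_fw + y_bw) / 2
    if merge_mode is None:
        return [y_fw, y_bw]  # not combined: returned as a list
    raise ValueError(f"Invalid merge_mode: {merge_mode!r}")


# A batch of 2 sequences, 5 timesteps, 10 units per direction.
y_fw = np.ones((2, 5, 10))
y_bw = np.full((2, 5, 10), 3.0)
for mode in ("concat", "sum", "mul", "ave"):
    # concat -> (2, 5, 20); the other modes keep (2, 5, 10)
    print(mode, merge_outputs(y_fw, y_bw, mode).shape)
```

Only 'concat' changes the feature dimension, which is why a Bidirectional(LSTM(10)) layer with the default merge mode produces 20 features per step.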
Call arguments:
The call arguments for this layer are the same as those of the wrapped RNN layer.
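Conceptually, one call feeds the same input sequence to both wrapped layers: the backward layer reads it in reverse, and its per-step outputs are flipped back before merging. The NumPy sketch below models this with a toy tanh RNN. simple_rnn and bidirectional are hypothetical helpers rather than Keras internals, and the sketch assumes return_sequences=True and merge_mode='concat'.

```python
import numpy as np


def simple_rnn(x, W, U, b, go_backwards=False):
    """Toy tanh RNN over x of shape (timesteps, features).

    Returns one hidden state per step, in the order the steps were processed.
    """
    steps = x[::-1] if go_backwards else x  # the backward layer reads the sequence in reverse
    h = np.zeros(U.shape[0])
    outputs = []
    for x_t in steps:
        h = np.tanh(x_t @ W + h @ U + b)
        outputs.append(h)
    return np.stack(outputs)  # shape (timesteps, units)


def bidirectional(x, fw_params, bw_params):
    """Feed the same input to both directions and concatenate per timestep."""
    y_fw = simple_rnn(x, *fw_params)
    y_bw = simple_rnn(x, *bw_params, go_backwards=True)
    y_bw = y_bw[::-1]  # flip back so row t of both outputs refers to timestep t
    return np.concatenate([y_fw, y_bw], axis=-1)  # merge_mode='concat'


rng = np.random.default_rng(0)
timesteps, features, units = 5, 3, 4


def make_params():
    """Random (W, U, b) weights for one direction."""
    return (rng.normal(size=(features, units)),
            rng.normal(size=(units, units)),
            np.zeros(units))


fw_params, bw_params = make_params(), make_params()
x = rng.normal(size=(timesteps, features))
out = bidirectional(x, fw_params, bw_params)
print(out.shape)  # (5, 8)
```

Each direction keeps its own weights, so row t of the output combines a summary of steps 0..t (forward half) with a summary of steps t..end (backward half).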
Raises:
ValueError:
1. If layer or backward_layer is not a Layer instance.
2. In case of invalid merge_mode argument.
3. If backward_layer has mismatched properties compared to layer.
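The property checks on backward_layer can be sketched in plain Python. RNNConfig and check_backward_layer are hypothetical: they model only the attributes named on this page (stateful, return_sequences, return_state, go_backwards), not every attribute Keras compares, and they are not the library's implementation.

```python
from dataclasses import dataclass


@dataclass
class RNNConfig:
    """Hypothetical stand-in for the relevant attributes of a Keras RNN layer."""
    stateful: bool = False
    return_sequences: bool = False
    return_state: bool = False
    go_backwards: bool = False


def check_backward_layer(layer, backward_layer):
    """Raise ValueError if the pair violates the documented requirements."""
    # These properties must match between the two layers.
    for attr in ("stateful", "return_sequences", "return_state"):
        if getattr(layer, attr) != getattr(backward_layer, attr):
            raise ValueError(
                f"backward_layer.{attr}={getattr(backward_layer, attr)} does not "
                f"match layer.{attr}={getattr(layer, attr)}")
    # go_backwards, by contrast, must differ.
    if layer.go_backwards == backward_layer.go_backwards:
        raise ValueError(
            "layer and backward_layer must have different go_backwards values")


forward = RNNConfig(return_sequences=True)
check_backward_layer(forward, RNNConfig(return_sequences=True, go_backwards=True))  # passes
```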
Examples:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Bidirectional, Dense, LSTM

model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# With custom backward layer
model = Sequential()
forward_layer = LSTM(10, return_sequences=True)
backward_layer = LSTM(10, activation='relu', return_sequences=True,
                      go_backwards=True)
model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                        input_shape=(5, 10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
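Tracing the output shapes through the first example helps explain the layer sizes. bidirectional_output_shape is a hypothetical helper that only models batch-major 3D inputs with return_state=False; it is not a Keras function.

```python
def bidirectional_output_shape(input_shape, units, return_sequences,
                               merge_mode="concat"):
    """Shape produced by Bidirectional(LSTM(units, ...)) for a 3D input.

    input_shape is (batch, timesteps, features); batch may be None.
    Models only batch-major inputs and return_state=False.
    """
    batch, timesteps, _ = input_shape
    per_direction = (batch, timesteps, units) if return_sequences else (batch, units)
    if merge_mode == "concat":
        return per_direction[:-1] + (2 * units,)  # both directions side by side
    if merge_mode in ("sum", "mul", "ave"):
        return per_direction  # element-wise merge keeps the size
    if merge_mode is None:
        return [per_direction, per_direction]  # two separate outputs
    raise ValueError(f"Invalid merge_mode: {merge_mode!r}")


# First example: input_shape=(5, 10) leaves the batch dimension unspecified.
first = bidirectional_output_shape((None, 5, 10), 10, return_sequences=True)
second = bidirectional_output_shape(first, 10, return_sequences=False)
print(first, second)  # (None, 5, 20) (None, 20)
```

Dense(5) then maps the 20 concatenated features to 5 classes. In the second example the wrapped layers keep return_sequences=True, so Dense(5) is applied at every timestep of a (None, 5, 20) tensor.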
__init__
__init__(
layer,
merge_mode='concat',
weights=None,
backward_layer=None,
**kwargs
)
Properties
constraints
Methods
tf.keras.layers.Bidirectional.reset_states
reset_states()
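reset_states() is relevant when the wrapped layers were created with stateful=True. The Keras source delegates it to both the forward and the backward layer. The sketch below models that behavior with hypothetical toy classes (ToyStatefulRNN, ToyBidirectional); the state update is a placeholder, not a real RNN cell.

```python
class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful RNN layer that keeps a hidden state."""

    def __init__(self, units):
        self.units = units
        self.reset_states()

    def reset_states(self):
        self.state = [0.0] * self.units

    def step(self, value):
        # Toy update that just accumulates inputs; not a real RNN cell.
        self.state = [s + value for s in self.state]


class ToyBidirectional:
    """Hypothetical wrapper whose reset_states clears both directions."""

    def __init__(self, forward_layer, backward_layer):
        self.forward_layer = forward_layer
        self.backward_layer = backward_layer

    def reset_states(self):
        self.forward_layer.reset_states()
        self.backward_layer.reset_states()


bi = ToyBidirectional(ToyStatefulRNN(2), ToyStatefulRNN(2))
bi.forward_layer.step(1.0)
bi.backward_layer.step(2.0)
bi.reset_states()
print(bi.forward_layer.state, bi.backward_layer.state)  # [0.0, 0.0] [0.0, 0.0]
```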