Builds an input layer for sequence input.
tf.contrib.feature_column.sequence_input_layer(
    features,
    feature_columns,
    weight_collections=None,
    trainable=True
)
All feature_columns must be dense sequence columns with the same sequence_length. The output of this method can be fed into sequence networks, such as RNNs.
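The "same sequence_length" requirement means every column must report identical per-example lengths. The check below is a pure-Python sketch of that rule, not TensorFlow's implementation; `check_equal_sequence_lengths` is a hypothetical helper, and lengths are plain lists of ints rather than tensors.

```python
from typing import Dict, List


def check_equal_sequence_lengths(lengths_by_column: Dict[str, List[int]]) -> List[int]:
    """Return the shared per-example lengths, or raise if any column disagrees.

    `lengths_by_column` maps a (hypothetical) column name to the length of
    each example's sequence for that column.
    """
    items = list(lengths_by_column.items())
    if not items:
        raise ValueError("at least one column is required")
    first_name, first_lengths = items[0]
    for name, lengths in items[1:]:
        if lengths != first_lengths:
            raise ValueError(
                f"column {name!r} has sequence lengths {lengths}, "
                f"but {first_name!r} has {first_lengths}")
    return first_lengths


shared = check_equal_sequence_lengths({"rating": [3, 1], "watches": [3, 1]})
```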
The output of this method is a 3D Tensor of shape [batch_size, T, D]. T is the maximum sequence length for this batch, which could differ from batch to batch.
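To see why T can change between batches, the sketch below pads each example to the longest sequence in its own batch and also returns the true lengths, loosely mirroring the (input_layer, sequence_length) pair this method returns. It is an illustrative pure-Python helper: `pad_batch` is hypothetical, and zero padding is an assumption made for the sketch.

```python
from typing import List, Tuple

Sequence = List[List[float]]  # one example: T_i time steps, each with D floats


def pad_batch(batch: List[Sequence], depth: int,
              pad_value: float = 0.0) -> Tuple[List[Sequence], List[int]]:
    """Pad every example to T = the longest sequence in this batch.

    Returns a [batch_size, T, depth] nested list plus each example's true length.
    """
    lengths = [len(example) for example in batch]
    max_len = max(lengths, default=0)  # T is computed per batch
    padded = []
    for example in batch:
        steps = [list(step) for step in example]
        steps += [[pad_value] * depth for _ in range(max_len - len(example))]
        padded.append(steps)
    return padded, lengths


# Two batches with different longest sequences produce different T.
padded_a, lengths_a = pad_batch([[[1.0], [2.0], [3.0]], [[4.0]]], depth=1)  # T = 3
padded_b, lengths_b = pad_batch([[[5.0]], [[6.0], [7.0]]], depth=1)         # T = 2
```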
If multiple feature_columns are given, with Di num_elements each, their outputs are concatenated along the last axis, so the final Tensor has shape [batch_size, T, D0 + D1 + ... + Dn].
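Concatenation along the last axis keeps batch_size and T fixed and sums the per-column depths. The pure-Python sketch below shows that shape arithmetic on nested lists; `concat_last_axis` is a hypothetical helper, and the input values are made up for illustration.

```python
from typing import List

Tensor3D = List[List[List[float]]]  # nested list shaped [batch_size, T, D]


def concat_last_axis(tensors: List[Tensor3D]) -> Tensor3D:
    """Concatenate [batch_size, T, Di] inputs into [batch_size, T, sum(Di)]."""
    if not tensors:
        raise ValueError("need at least one input")
    batch_size = len(tensors[0])
    time_steps = len(tensors[0][0]) if batch_size else 0
    for x in tensors:
        if len(x) != batch_size or any(len(ex) != time_steps for ex in x):
            raise ValueError("all inputs must share batch_size and T")
    # For each (b, t), append the Di values of every input in order.
    return [
        [[value for x in tensors for value in x[b][t]] for t in range(time_steps)]
        for b in range(batch_size)
    ]


# D0 = 1 (one numeric value per step), D1 = 2 (a tiny embedding per step).
numeric = [[[1.0], [2.0]], [[3.0], [0.0]]]
embedded = [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.0, 0.0]]]
combined = concat_last_axis([numeric, embedded])  # shape [2, 2, 3]
```

Applied to the example below, and assuming `rating` keeps one value per step, D would be 1 + 10 = 11.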
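Example:
import tensorflow as tf  # TF 1.x only: tf.contrib was removed in TF 2.x
# The unqualified column helpers below live in tf.contrib.feature_column
# and tf.feature_column.
rating = sequence_numeric_column('rating')
watches = sequence_categorical_column_with_identity(
    'watches', num_buckets=1000)
# A categorical sequence column is not dense; wrap it in an embedding.
watches_embedding = embedding_column(watches, dimension=10)
columns = [rating, watches_embedding]
features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
input_layer, sequence_length = sequence_input_layer(features, columns)
hidden_size = 64  # hypothetical RNN state size, chosen for illustration
rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
outputs, state = tf.compat.v1.nn.dynamic_rnn(
    rnn_cell, inputs=input_layer, sequence_length=sequence_length)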
Args:
features: A dict mapping keys to tensors.
feature_columns: An iterable of dense sequence columns. Valid columns are an embedding_column that wraps a sequence_categorical_column_with_*, and a sequence_numeric_column.
weight_collections: A list of collection names to which the Variable will be added. Note that variables will also be added to the collections tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.MODEL_VARIABLES.
trainable: If True, also add the variable to the graph collection tf.GraphKeys.TRAINABLE_VARIABLES.
Returns:
An (input_layer, sequence_length) tuple where:
- input_layer: A float Tensor of shape [batch_size, T, D]. T is the maximum sequence length for this batch, which could differ from batch to batch. D is the sum of num_elements for all feature_columns.
- sequence_length: An int Tensor of shape [batch_size]. The sequence length for each example.
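The sequence_length output records how many of the T steps in each row are real. The pure-Python sketch below shows two common uses: a boolean mask (the same idea as `tf.sequence_mask`) and selecting each example's last real step, for instance from RNN outputs. Both helpers are hypothetical, and the values are made up for illustration.

```python
from typing import List


def sequence_mask(sequence_length: List[int], max_len: int) -> List[List[bool]]:
    """mask[b][t] is True where t < sequence_length[b], i.e. a real step."""
    return [[t < length for t in range(max_len)] for length in sequence_length]


def last_valid_step(outputs: List[List[List[float]]],
                    sequence_length: List[int]) -> List[List[float]]:
    """Pick outputs[b][sequence_length[b] - 1] for each example b."""
    result = []
    for example, length in zip(outputs, sequence_length):
        if length < 1:
            raise ValueError("every example needs at least one step")
        result.append(example[length - 1])
    return result


lengths = [3, 1]
mask = sequence_mask(lengths, max_len=3)
outputs = [[[0.1], [0.2], [0.3]], [[0.4], [0.0], [0.0]]]  # padded past length
last = last_valid_step(outputs, lengths)
```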
Raises:
ValueError: If any of the feature_columns is of the wrong type.
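The type rule behind this ValueError can be sketched in plain Python. The classes below are hypothetical stand-ins, not TensorFlow's feature-column classes; they only model the rule stated under Args: a numeric sequence column, or an embedding that wraps a categorical sequence column, is accepted, and anything else (such as a bare categorical column) is rejected.

```python
from dataclasses import dataclass
from typing import Iterable, List, Union


@dataclass
class SequenceNumericColumn:      # hypothetical stand-in, dense
    key: str


@dataclass
class SequenceCategoricalColumn:  # hypothetical stand-in, sparse (not dense)
    key: str
    num_buckets: int


@dataclass
class EmbeddingColumn:            # hypothetical stand-in wrapping a categorical column
    categorical_column: SequenceCategoricalColumn
    dimension: int


DenseSequenceColumn = Union[SequenceNumericColumn, EmbeddingColumn]


def validate_columns(feature_columns: Iterable[object]) -> List[DenseSequenceColumn]:
    """Return the columns as a list, raising ValueError on a non-dense column."""
    columns = list(feature_columns)
    for column in columns:
        if isinstance(column, SequenceNumericColumn):
            continue
        if isinstance(column, EmbeddingColumn) and isinstance(
                column.categorical_column, SequenceCategoricalColumn):
            continue
        raise ValueError(
            f"All feature_columns must be dense sequence columns; got {column!r}")
    return columns


rating = SequenceNumericColumn("rating")
watches = SequenceCategoricalColumn("watches", num_buckets=1000)
watches_embedding = EmbeddingColumn(watches, dimension=10)
accepted = validate_columns([rating, watches_embedding])
```

Passing the bare `watches` column instead of `watches_embedding` makes this sketch raise ValueError.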