Class List
An append-only sequence type which is trackable.
Maintains checkpoint dependencies on its contents (which must also be trackable), and forwards any Layer metadata such as updates and losses.
Note that List is purely a container. It lets a tf.keras.Model or other trackable object know about its contents, but does not call any Layer instances which are added to it. To indicate a sequence of Layer instances which should be called sequentially, use tf.keras.Sequential.
Example usage:
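import tensorflow as tf

# Assumed alias: the original snippet uses `layers` without importing it.
layers = tf.keras.layers

class HasList(tf.keras.Model):

  def __init__(self):
    super(HasList, self).__init__()
    # Wrapping the layers in List makes them checkpoint dependencies.
    self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
    self.layer_list.append(layers.Dense(4))

  def call(self, x):
    aggregation = 0.
    # List does not call its layers, so call() applies them explicitly.
    for l in self.layer_list:
      x = l(x)
      aggregation += tf.reduce_sum(x)
    return aggregation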
This kind of wrapping is necessary because Trackable objects do not (yet) deeply inspect regular Python data structures, so for example assigning a regular list (self.layer_list = [layers.Dense(3)]) does not create a checkpoint dependency and does not add the Layer instance's weights to its parent Model.
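The difference between a plain list and a trackable wrapper can be sketched without TensorFlow. The classes below (ToyTrackable, ToyList, ToyLayer) are hypothetical stand-ins written for illustration. They are not TensorFlow's implementation and only mimic the idea that attribute assignment records direct trackable values but does not look inside ordinary Python containers.

```python
# Toy illustration only: hypothetical stand-ins, not TensorFlow classes.

class ToyTrackable:
    """Records attributes that are themselves ToyTrackable as dependencies."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the bookkeeping dict.
        object.__setattr__(self, "dependencies", {})

    def __setattr__(self, name, value):
        # Only direct ToyTrackable values are recorded; plain Python
        # containers are stored but never inspected.
        if isinstance(value, ToyTrackable):
            self.dependencies[name] = value
        object.__setattr__(self, name, value)


class ToyList(ToyTrackable):
    """A trackable wrapper that exposes its elements as dependencies."""

    def __init__(self, values=()):
        super().__init__()
        object.__setattr__(self, "_items", [])
        for value in values:
            self.append(value)

    def append(self, value):
        # Each element is registered under its position in the list.
        self.dependencies[str(len(self._items))] = value
        self._items.append(value)


class ToyLayer(ToyTrackable):
    pass


parent = ToyTrackable()
parent.plain = [ToyLayer()]             # regular list: not recorded
parent.wrapped = ToyList([ToyLayer()])  # wrapper: recorded with its element

tracked_names = sorted(parent.dependencies)
wrapped_children = sorted(parent.wrapped.dependencies)
```

In this sketch only the wrapped attribute ends up in the parent's dependencies, which mirrors why the real List wrapper is needed for layers stored in a sequence.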
__init__
__init__(
*args,
**kwargs
)
Construct a new sequence. Arguments are passed to list().
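Since the arguments are passed to list(), the accepted call shapes are presumably those of Python's built-in list constructor: no arguments, or a single iterable. The sketch below uses the built-in list directly, with placeholder strings standing in for Layer instances. It assumes List forwards its arguments unchanged, which the description states but this example does not verify against TensorFlow.

```python
# Placeholder strings stand in for trackable Layer objects.
empty = list()                                 # no arguments: empty sequence
from_iterable = list(("dense_a", "dense_b"))   # one iterable argument

# Passing elements as separate positional arguments is rejected by list().
try:
    list("dense_a", "dense_b")
except TypeError:
    two_positional_rejected = True
else:
    two_positional_rejected = False
```

Under that assumption, initial contents would be supplied as one iterable, as in the example usage above (List([layers.Dense(3)])), rather than as separate arguments.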
Properties
layers
losses
Aggregate losses from any Layer instances.
non_trainable_variables
non_trainable_weights
trainable
trainable_variables
trainable_weights
updates
Aggregate updates from any Layer instances.
variables
weights
Methods
tf.contrib.checkpoint.List.__add__
__add__(other)
__contains__
__contains__(value)
tf.contrib.checkpoint.List.__eq__
__eq__(other)
Return self == other.
tf.contrib.checkpoint.List.__getitem__
__getitem__(key)
__iter__
__iter__()
tf.contrib.checkpoint.List.__len__
__len__()
tf.contrib.checkpoint.List.__mul__
__mul__(n)
tf.contrib.checkpoint.List.__radd__
__radd__(other)
tf.contrib.checkpoint.List.__rmul__
__rmul__(n)
tf.contrib.checkpoint.List.append
append(value)
Add a new trackable value.
tf.contrib.checkpoint.List.copy
copy()
count
count(value)
S.count(value) -> integer -- return number of occurrences of value
tf.contrib.checkpoint.List.extend
extend(values)
Add a sequence of trackable values.
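As a rough picture of what "append-only" means for append and extend, here is a minimal sketch built on collections.abc.Sequence. ToyAppendOnly is a hypothetical class written for this illustration, with placeholder strings instead of trackable objects. It is not the TensorFlow class and performs none of its checkpoint bookkeeping.

```python
from collections.abc import Sequence


class ToyAppendOnly(Sequence):
    """Hypothetical append-only sequence; not the TensorFlow class."""

    def __init__(self, values=()):
        self._items = []
        self.extend(values)

    def __getitem__(self, key):
        return self._items[key]

    def __len__(self):
        return len(self._items)

    def append(self, value):
        # Add a single value at the end.
        self._items.append(value)

    def extend(self, values):
        # Add every value from an iterable, preserving order.
        for value in values:
            self.append(value)


seq = ToyAppendOnly(["dense_3"])
seq.append("dense_4")
seq.extend(["dense_5", "dense_6"])
contents = list(seq)

# No __setitem__ is defined, so Python rejects item assignment.
try:
    seq[0] = "replacement"
except TypeError:
    assignment_rejected = True
else:
    assignment_rejected = False
```

In the sketch, elements can be read and added but not replaced, because the class defines no mutating methods other than append and extend.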
index
index(
value,
start=0,
stop=None
)
S.index(value, [start, [stop]]) -> integer -- return first index of value. Raises ValueError if the value is not present.
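The count and index descriptions read like the mixin methods of Python's collections.abc.Sequence, so their semantics can be sketched with a small Sequence subclass. ToySequence is a hypothetical class used only for this illustration, and whether List inherits these methods from that base class is an assumption here.

```python
from collections.abc import Sequence


class ToySequence(Sequence):
    """Hypothetical minimal sequence; index() and count() come from the mixin."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, key):
        return self._items[key]

    def __len__(self):
        return len(self._items)


seq = ToySequence(["a", "b", "a", "c"])

first_a = seq.index("a")      # search from the beginning
second_a = seq.index("a", 1)  # search from position 1 onward
count_a = seq.count("a")      # number of occurrences

# Restricting the search to positions [1, 2) leaves only "b",
# so looking for "a" there raises ValueError.
try:
    seq.index("a", 1, 2)
except ValueError:
    missing_in_range = True
else:
    missing_in_range = False
```

The start and stop arguments bound the search the same way slice indices do: start is inclusive and stop is exclusive.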