tf.contrib.checkpoint.List

Class List

An append-only sequence type which is trackable.

Maintains checkpoint dependencies on its contents (which must also be trackable), and forwards any Layer metadata such as updates and losses.
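Checkpoint dependencies are what let a saver find every variable reachable from a model. The pure-Python sketch below is a toy model of that idea, not TensorFlow's implementation. It assumes that a container names each element by its index and that paths are joined with "/". `ToyNode`, `ToyVariable`, `ToyList` and `checkpoint_keys` are all hypothetical names.

```python
# Toy model of checkpoint dependencies. Every class and the "/"-joined key
# format are hypothetical; they only mimic the idea that a container names
# each element (here by its index) so a saver can reach it.


class ToyVariable:
    def __init__(self, value):
        self.value = value


class ToyNode:
    """A trackable object whose named children are recorded in _deps."""

    def __init__(self):
        self._deps = {}


class ToyList(ToyNode):
    def __init__(self, items=()):
        super().__init__()
        self._storage = []
        for item in items:
            self.append(item)

    def append(self, value):
        # Register the element under its index: "0", "1", ...
        self._deps[str(len(self._storage))] = value
        self._storage.append(value)


def checkpoint_keys(obj, prefix=""):
    """Walk the dependency graph and return a {path: value} mapping."""
    if isinstance(obj, ToyVariable):
        return {prefix: obj.value}
    keys = {}
    for name, child in obj._deps.items():
        path = f"{prefix}/{name}" if prefix else name
        keys.update(checkpoint_keys(child, path))
    return keys


def make_layer(value):
    # A hypothetical "layer" with a single kernel variable.
    layer = ToyNode()
    layer._deps["kernel"] = ToyVariable(value)
    return layer


model = ToyNode()
model._deps["layer_list"] = ToyList([make_layer(1.0), make_layer(2.0)])
print(checkpoint_keys(model))
# {'layer_list/0/kernel': 1.0, 'layer_list/1/kernel': 2.0}
```

Because the container registers its elements, the saver reaches both kernels through `layer_list` and never needs to inspect the list's Python internals.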

Note that List is purely a container. It lets a tf.keras.Model or other trackable object know about its contents, but does not call any Layer instances which are added to it. To indicate a sequence of Layer instances which should be called sequentially, use tf.keras.Sequential.
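The difference between holding layers and calling them can be sketched in pure Python. `ToyContainer` and `ToySequential` are hypothetical stand-ins, not `tf.contrib.checkpoint.List` or `tf.keras.Sequential`, and plain functions play the role of Layer instances.

```python
# Hypothetical contrast between a container that only holds layers and a
# Sequential-style wrapper that calls them in order.


class ToyContainer:
    """Holds items and never calls them (the role described for List)."""

    def __init__(self, items):
        self._storage = list(items)

    def __iter__(self):
        return iter(self._storage)


class ToySequential:
    """Calls each layer on the previous layer's output."""

    def __init__(self, items):
        self._layers = list(items)

    def __call__(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


container = ToyContainer([add_one, double])
pipeline = ToySequential([add_one, double])

# The container is not callable; the owning model decides how to use items.
manual = 3
for fn in container:
    manual = fn(manual)

print(pipeline(3))          # 8
print(manual)               # 8
print(callable(container))  # False
```

With a container, the calling logic lives in the owner's `call` method, so it can also do extra work between layers, such as summing intermediate outputs.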

Example usage:

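import tensorflow as tf
from tensorflow.keras import layers


class HasList(tf.keras.Model):

  def __init__(self):
    super(HasList, self).__init__()
    # Wrap the layers so the model tracks them and their weights.
    self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
    self.layer_list.append(layers.Dense(4))

  def call(self, x):
    aggregation = 0.
    # List does not call its contents; call each layer explicitly.
    for layer in self.layer_list:
      x = layer(x)
      aggregation += tf.reduce_sum(x)
    return aggregation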

This kind of wrapping is necessary because Trackable objects do not (yet) deeply inspect regular Python data structures. For example, assigning a regular list (self.layer_list = [layers.Dense(3)]) neither creates a checkpoint dependency nor adds the Layer instance's weights to its parent Model.
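The shallow inspection described above can be modeled with a toy `__setattr__` hook. This is a minimal sketch that assumes a parent records an attribute as a dependency only when the assigned value is itself trackable. `ToyTrackable`, `ToyLayer`, `ToyTrackableList` and `dependency_names` are hypothetical, not TensorFlow's tracking code.

```python
# Hypothetical sketch of shallow attribute tracking: the parent records an
# attribute as a dependency only when the value itself is trackable, so a
# plain Python list holding a trackable layer is silently skipped.


class ToyTrackable:
    def __setattr__(self, name, value):
        if not name.startswith("_") and isinstance(value, ToyTrackable):
            self.__dict__.setdefault("_deps", {})[name] = value
        super().__setattr__(name, value)

    @property
    def dependency_names(self):
        return sorted(self.__dict__.get("_deps", {}))


class ToyLayer(ToyTrackable):
    pass


class ToyTrackableList(ToyTrackable, list):
    """A list subclass that is itself trackable (the role List plays)."""


class Parent(ToyTrackable):
    def __init__(self):
        self.plain = [ToyLayer()]                      # not inspected
        self.wrapped = ToyTrackableList([ToyLayer()])  # tracked


print(Parent().dependency_names)  # ['wrapped']
```

Only the wrapped list shows up as a dependency. The layer inside `plain` is unreachable from the parent's point of view, even though it is trackable.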

__init__

__init__(
    *args,
    **kwargs
)

Construct a new sequence. Arguments are passed to list().
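Forwarding the constructor's arguments to `list()` means a List can be built from any iterable, just like the built-in. The toy class below sketches that pattern. `ToyList` and its `_storage` attribute are hypothetical names.

```python
# Hypothetical sketch of a constructor that forwards its arguments to list(),
# so any iterable (or nothing) is accepted, as with the built-in list.


class ToyList:
    def __init__(self, *args, **kwargs):
        self._storage = list(*args, **kwargs)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, index):
        return self._storage[index]


print(len(ToyList()))                          # 0
print(ToyList(range(3))[2])                    # 2
print(list(ToyList(x * x for x in range(4))))  # [0, 1, 4, 9]
```

Because the arguments go straight to `list()`, invalid calls such as `ToyList(1, 2)` fail with the same `TypeError` that `list(1, 2)` raises.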

Properties

layers

losses

Aggregate losses from any Layer instances.

non_trainable_variables

non_trainable_weights

trainable

trainable_variables

trainable_weights

updates

Aggregate updates from any Layer instances.

variables

weights
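The losses and updates properties are described as aggregating over the Layer instances in the container. The sketch below shows one way such aggregation can work, assuming elements that lack the attribute (non-layers) contribute nothing. `ToyLayer`, `ToyList` and `_gather` are hypothetical, and the string values stand in for real loss and update tensors.

```python
# Hypothetical sketch of property aggregation: the container concatenates the
# matching attribute of every element that exposes it.


class ToyLayer:
    def __init__(self, name):
        self.losses = [f"{name}_loss"]
        self.updates = [f"{name}_update"]


class ToyList:
    def __init__(self, items=()):
        self._storage = list(items)

    def _gather(self, attr):
        # Elements without the attribute contribute an empty list.
        return [v for item in self._storage for v in getattr(item, attr, [])]

    @property
    def losses(self):
        return self._gather("losses")

    @property
    def updates(self):
        return self._gather("updates")


container = ToyList([ToyLayer("a"), ToyLayer("b"), "not a layer"])
print(container.losses)   # ['a_loss', 'b_loss']
print(container.updates)  # ['a_update', 'b_update']
```

A parent model can then read these properties from the container just as it would from a single Layer.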

Methods

__add__

__add__(other)

__contains__

__contains__(value)

__eq__

__eq__(other)

Return self == other.

__getitem__

__getitem__(key)

__iter__

__iter__()

__len__

__len__()

__mul__

__mul__(n)

__radd__

__radd__(other)

__rmul__

__rmul__(n)
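The operator methods above are listed only by signature. The sketch below shows one common way to implement them: delegate to the underlying Python list. It assumes the results are plain lists, but this page does not state whether the real operators return a plain list or a new tracked container. `ToyList` is hypothetical.

```python
# Hypothetical sketch: the arithmetic operators delegate to the underlying
# Python list. Returning plain lists is an assumption of this toy.


class ToyList:
    def __init__(self, items=()):
        self._storage = list(items)

    def __add__(self, other):
        # Unwrap another ToyList so ToyList + ToyList also works.
        return self._storage + list(getattr(other, "_storage", other))

    def __radd__(self, other):
        # Called for `other + self` when `other` is e.g. a plain list.
        return list(other) + self._storage

    def __mul__(self, n):
        return self._storage * n

    def __rmul__(self, n):
        # Called for `n * self`; repetition is commutative.
        return self.__mul__(n)


a = ToyList([1, 2])
print(a + [3])           # [1, 2, 3]
print([0] + a)           # [0, 1, 2]
print(a + ToyList([9]))  # [1, 2, 9]
print(2 * a)             # [1, 2, 1, 2]
```

The reflected methods (`__radd__`, `__rmul__`) are what make `[0] + a` and `2 * a` work when the container is the right-hand operand.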

append

append(value)

Add a new trackable value.
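Since the class description requires contents to be trackable, `append` can be pictured as a type check followed by storage. In this minimal sketch, the `ToyTrackable` marker class and the choice of `TypeError` are assumptions. The page does not say which error, if any, the real method raises.

```python
# Hypothetical sketch of append(): only trackable values are accepted.


class ToyTrackable:
    pass


class ToyList:
    def __init__(self):
        self._storage = []

    def append(self, value):
        if not isinstance(value, ToyTrackable):
            raise TypeError(
                f"only trackable values may be stored, got {type(value).__name__}"
            )
        self._storage.append(value)

    def __len__(self):
        return len(self._storage)


items = ToyList()
items.append(ToyTrackable())
try:
    items.append(42)
except TypeError as err:
    print(err)     # only trackable values may be stored, got int
print(len(items))  # 1
```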

copy

copy()

count

count(value)

Return the number of occurrences of value, as an integer.

extend

extend(values)

Add a sequence of trackable values.
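A straightforward way to give every element of a sequence the same treatment as a single value is to route `extend` through `append`. The sketch below assumes that design. `ToyList` and its `tracked` attribute (a stand-in for recorded dependencies) are hypothetical, and strings stand in for layers.

```python
# Hypothetical sketch: extend() routes every element through append(), so each
# value is tracked exactly as a single append would track it.


class ToyList:
    def __init__(self):
        self._storage = []
        self.tracked = []  # stand-in for recorded checkpoint dependencies

    def append(self, value):
        self.tracked.append(value)
        self._storage.append(value)

    def extend(self, values):
        # Any iterable works, including generators.
        for value in values:
            self.append(value)


items = ToyList()
items.extend(["dense_3", "dense_4"])
items.extend(name for name in ["dense_5"])
print(items.tracked)  # ['dense_3', 'dense_4', 'dense_5']
```

Keeping the tracking logic in one method means `extend` cannot accidentally bypass it.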

index

index(
    value,
    start=0,
    stop=None
)

Return the first index of value, searching from start up to (but not including) stop. Raises ValueError if the value is not present.
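`__contains__`, `__iter__`, `count` and `index` behave like the methods of the same names on Python's `collections.abc.Sequence`. One plausible (unconfirmed) reading is that List inherits them from that mixin. The sketch shows how the mixin supplies all four once a class defines `__getitem__` and `__len__`. `ToySequence` is a hypothetical stand-in, not TensorFlow's List.

```python
# Sketch: collections.abc.Sequence provides __contains__, __iter__, count and
# index for any class that defines __getitem__ and __len__.
from collections.abc import Sequence


class ToySequence(Sequence):
    def __init__(self, items=()):
        self._storage = list(items)

    def __getitem__(self, index):
        return self._storage[index]

    def __len__(self):
        return len(self._storage)


seq = ToySequence(["a", "b", "a", "c"])
print(seq.count("a"))     # 2
print(seq.index("a", 1))  # 2 (search starts at index 1)
print("c" in seq)         # True
try:
    seq.index("z")
except ValueError:
    print("not present")  # not present
```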