Class MultiHead
Creates a Head for multi-objective learning.
Inherits From: Head
This class merges the output of multiple Head objects. Specifically:
- For training, sums losses of each head, calls train_op_fn with this final loss.
- For eval, merges metrics by adding head.name suffix to the keys in eval metrics, such as precision/head1.name, precision/head2.name.
- For prediction, merges predictions and updates keys in prediction dict to a 2-tuple, (head.name, prediction_key). Merges export_outputs such that by default the first head is served.
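The eval and prediction merging rules above can be sketched in plain Python. This is a minimal illustration and not the library's implementation: the helper names merge_metrics and merge_predictions are hypothetical, and the metric and prediction values are placeholder numbers rather than real tensors.

```python
def merge_metrics(metrics_by_head):
    """Suffix each metric key with the name of the head that produced it."""
    merged = {}
    for head_name, metrics in metrics_by_head.items():
        for key, value in metrics.items():
            merged[f"{key}/{head_name}"] = value
    return merged


def merge_predictions(predictions_by_head):
    """Re-key each prediction as a 2-tuple (head_name, prediction_key)."""
    merged = {}
    for head_name, predictions in predictions_by_head.items():
        for key, value in predictions.items():
            merged[(head_name, key)] = value
    return merged


# Placeholder values standing in for per-head metric and prediction tensors.
metrics = merge_metrics({
    "head1": {"precision": 0.9},
    "head2": {"precision": 0.7},
})
predictions = merge_predictions({
    "head1": {"probabilities": [0.2, 0.5, 0.3]},
    "head2": {"probabilities": [0.4, 0.6]},
})
print(metrics)
print(predictions)
```

Keying predictions by a tuple keeps identical prediction keys from different heads, such as probabilities, from overwriting each other.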
Usage:
import tensorflow as tf

# In `input_fn`, specify labels as a dict keyed by head name:
def input_fn():
  features = ...
  labels1 = ...
  labels2 = ...
  # Keys must match the `name` given to each head in `model_fn`.
  return features, {'head1': labels1, 'head2': labels2}

# In `model_fn`, specify logits as a dict keyed by head name:
def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create MultiHead from two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for each head, and combine them into a dict.
  logits1, logits2 = logit_fn()
  logits = {head1.name: logits1, head2.name: logits2}
  # Return the merged EstimatorSpec
  return head.create_estimator_spec(..., logits=logits, ...)

# Create an estimator with this model_fn.
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)
Also supports logits as a Tensor of shape [D0, D1, ... DN, logits_dimension]. It will split the Tensor along the last dimension and distribute it appropriately among the heads. E.g.:
Input logits.
logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],
                  dtype=np.float32)
Suppose head1.logits_dimension = 2 and head2.logits_dimension = 3. After splitting, the result is:
logits_dict = {'head1_name': [[-1., 1.], [-1.5, 1.]],
               'head2_name': [[2., -2., 2.], [-3., 2., -2.]]}
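The split described above can be reproduced with a short sketch. It uses plain nested Python lists instead of tensors so that it runs without TensorFlow. The helper split_logits is hypothetical and only illustrates the splitting rule; it is not how the library implements it.

```python
def split_logits(logits, dims_by_head):
    """Split the last dimension of nested-list logits into per-head slices.

    dims_by_head maps head name -> logits_dimension, in head order.
    """
    if logits and isinstance(logits[0], list):
        # Not yet at the last dimension: recurse, then regroup by head name.
        parts = [split_logits(row, dims_by_head) for row in logits]
        return {name: [part[name] for part in parts] for name in dims_by_head}
    # Last dimension: hand out consecutive slices, one per head.
    result, start = {}, 0
    for name, dim in dims_by_head.items():
        result[name] = logits[start:start + dim]
        start += dim
    if start != len(logits):
        raise ValueError("Last dimension does not equal the sum of head dimensions.")
    return result


logits = [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]]
logits_dict = split_logits(logits, {'head1_name': 2, 'head2_name': 3})
print(logits_dict)
```

The last dimension of the combined logits has to equal the sum of the per-head logits dimensions (here 2 + 3 = 5), which is why the sketch raises an error otherwise.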
Usage:
def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create multi-head from two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for the multihead. The result of logits is a `Tensor`.
  logits = logit_fn(logits_dimension=head.logits_dimension)
  # Return the merged EstimatorSpec
  return head.create_estimator_spec(..., logits=logits, ...)
Args:
- heads: List or tuple of Head instances. All heads must have name specified. The first head in the list is the default used at serving time.
- head_weights: Optional list of weights, same length as heads. Used when merging losses to calculate the weighted sum of losses from each head. If None, all losses are weighted equally.
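As a rough illustration of how head_weights affects the merged training loss, the sketch below combines per-head scalar losses in plain Python. The function merge_losses is hypothetical, the losses are plain floats rather than tensors, and raising ValueError on a length mismatch is an assumption made for this example.

```python
def merge_losses(losses, head_weights=None):
    """Combine per-head scalar losses into a single training loss."""
    if head_weights is None:
        # No weights given: every head contributes equally (plain sum).
        return sum(losses)
    if len(head_weights) != len(losses):
        # Assumed behavior for this sketch: reject mismatched lengths.
        raise ValueError("head_weights must have the same length as heads.")
    return sum(w * loss for w, loss in zip(head_weights, losses))


unweighted = merge_losses([0.5, 2.0])
weighted = merge_losses([0.5, 2.0], head_weights=[1.0, 0.25])
print(unweighted, weighted)
```

Down-weighting a head this way reduces its influence on the gradient of the merged loss without changing that head's own metrics or predictions.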
__init__
__init__(
    heads,
    head_weights=None
)
Initialize self. See help(type(self)) for accurate signature.
Properties
logits_dimension
See base_head.Head for details.
loss_reduction
See base_head.Head for details.
name
See base_head.Head for details.
Methods
tf.estimator.MultiHead.create_estimator_spec
create_estimator_spec(
    features,
    mode,
    logits,
    labels=None,
    optimizer=None,
    trainable_variables=None,
    train_op_fn=None,
    update_ops=None,
    regularization_losses=None
)
Returns a model_fn.EstimatorSpec.
Args:
- features: Input dict of Tensor or SparseTensor objects.
- mode: Estimator's ModeKeys.
- logits: Input dict keyed by head name, or logits Tensor with shape [D0, D1, ... DN, logits_dimension]. For many applications, the Tensor shape is [batch_size, logits_dimension]. If logits is a Tensor, it is split along the last dimension and distributed appropriately among the heads. See the MultiHead class description for examples.
- labels: Input dict keyed by head name. For each head, the label value can be an integer or string Tensor with shape matching its corresponding logits. labels is a required argument when mode equals TRAIN or EVAL.
- optimizer: A tf.keras.optimizers.Optimizer instance to optimize the loss in TRAIN mode. Namely, sets train_op = optimizer.get_updates(loss, trainable_variables), which updates variables to minimize loss.
- trainable_variables: A list or tuple of Variable objects to update to minimize loss. In TensorFlow 1.x, by default these are the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES. Because TensorFlow 2.x doesn't have collections and GraphKeys, trainable_variables must be passed explicitly here.
- train_op_fn: Function that takes a scalar loss Tensor and returns train_op. Used if optimizer is None.
- update_ops: A list or tuple of update ops to be run at training time. For example, layers such as BatchNormalization create mean and variance update ops that need to be run at training time. In TensorFlow 1.x, these are added to an UPDATE_OPS collection. Because TensorFlow 2.x doesn't have collections, update_ops must be passed explicitly here.
- regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results each head should use the default loss_reduction=SUM_OVER_BATCH_SIZE to avoid scaling errors. Unlike the regularization losses of an individual head, this loss regularizes the merged loss of all heads and is added to the overall training loss of the multi-head.
Returns:
A model_fn.EstimatorSpec instance.
Raises:
- ValueError: If both train_op_fn and optimizer are None in TRAIN mode, or if both are set. If mode is not in Estimator's ModeKeys.
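The conditions listed under Raises can be summarized as a small validation sketch. The function check_train_args is hypothetical and is not part of the library. The sketch assumes that the "both are set" check applies only in TRAIN mode, and it uses the string values of tf.estimator.ModeKeys directly so that it runs without TensorFlow.

```python
# String values of tf.estimator.ModeKeys.TRAIN, EVAL and PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def check_train_args(mode, optimizer=None, train_op_fn=None):
    """Raise ValueError for the argument combinations described above."""
    if mode not in (TRAIN, EVAL, PREDICT):
        raise ValueError(f"Unsupported mode: {mode!r}")
    # In TRAIN mode exactly one of optimizer and train_op_fn must be given.
    if mode == TRAIN and (optimizer is None) == (train_op_fn is None):
        raise ValueError(
            "In TRAIN mode, set exactly one of optimizer and train_op_fn.")


def raises_value_error(**kwargs):
    """Return True if check_train_args rejects the given arguments."""
    try:
        check_train_args(**kwargs)
    except ValueError:
        return True
    return False


print(raises_value_error(mode=TRAIN))
print(raises_value_error(mode=TRAIN, train_op_fn=lambda loss: loss))
```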
tf.estimator.MultiHead.loss
loss(
    labels,
    logits,
    features=None,
    mode=None,
    regularization_losses=None
)
Returns regularized training loss. See base_head.Head for details.
tf.estimator.MultiHead.metrics
metrics(regularization_losses=None)
Creates metrics. See base_head.Head for details.
tf.estimator.MultiHead.predictions
predictions(
    logits,
    keys=None
)
Create predictions. See base_head.Head for details.
tf.estimator.MultiHead.update_metrics
update_metrics(
    eval_metrics,
    features,
    logits,
    labels,
    regularization_losses=None
)
Updates eval metrics. See base_head.Head for details.