Class CosineDecayRestarts
A LearningRateSchedule that uses a cosine decay schedule with restarts.
Inherits From: LearningRateSchedule
Aliases:
- Class tf.compat.v1.keras.experimental.CosineDecayRestarts
- Class tf.compat.v2.keras.experimental.CosineDecayRestarts
__init__
__init__(
    initial_learning_rate,
    first_decay_steps,
    t_mul=2.0,
    m_mul=1.0,
    alpha=0.0,
    name=None
)
Applies cosine decay with restarts to the learning rate.
See Loshchilov & Hutter, "SGDR: Stochastic Gradient Descent with Warm Restarts" (ICLR 2017): https://arxiv.org/abs/1608.03983
When training a model, it is often recommended to lower the learning rate as
the training progresses. This schedule applies a cosine decay function with
restarts to an optimizer step, given a provided initial learning rate.
It requires a step value to compute the decayed learning rate. You can just pass a TensorFlow variable that you increment at each training step.
The schedule is a 1-arg callable that produces a decayed learning rate when passed the current optimizer step. This can be useful for changing the learning rate value across different invocations of optimizer functions.
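To make the "1-arg callable" idea concrete, here is a minimal pure-Python sketch. The class name CosineDecayRestartsSketch is hypothetical: it reproduces the schedule's documented behaviour using the formula from the TF 2.x source as we understand it, not TensorFlow itself, and a plain integer counter stands in for the TensorFlow step variable.

```python
import math


class CosineDecayRestartsSketch:
    """Hypothetical pure-Python stand-in for CosineDecayRestarts.

    Illustration only: mirrors the documented schedule, not the TF code.
    """

    def __init__(self, initial_learning_rate, first_decay_steps,
                 t_mul=2.0, m_mul=1.0, alpha=0.0):
        self.initial_learning_rate = initial_learning_rate
        self.first_decay_steps = first_decay_steps
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.alpha = alpha

    def __call__(self, step):
        # Progress measured in units of the first period.
        x = step / self.first_decay_steps
        if self.t_mul == 1.0:
            # All periods have the same length.
            i = math.floor(x)
            frac = x - i
        else:
            # Period lengths grow geometrically; invert the geometric sum
            # to find which period (zero-based) `step` falls in.
            i = math.floor(math.log(1.0 - x * (1.0 - self.t_mul))
                           / math.log(self.t_mul))
            start = (1.0 - self.t_mul ** i) / (1.0 - self.t_mul)
            frac = (x - start) / self.t_mul ** i
        # Cosine from 1 down to 0 over the period, scaled by m_mul**i.
        cosine = 0.5 * self.m_mul ** i * (1.0 + math.cos(math.pi * frac))
        decayed = (1.0 - self.alpha) * cosine + self.alpha
        return self.initial_learning_rate * decayed


schedule = CosineDecayRestartsSketch(0.1, first_decay_steps=1000)
# The step argument plays the role of the variable incremented each step.
print([round(schedule(s), 4) for s in (0, 500, 999, 1000, 2000)])
# -> [0.1, 0.05, 0.0, 0.1, 0.05]
```

With the default t_mul=2.0, step 1000 is the first restart (back to the initial rate) and step 2000 is the midpoint of the 2000-step second period.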
The learning rate multiplier first decays from 1 to alpha over first_decay_steps steps. Then a warm restart is performed. Each new warm restart runs for t_mul times as many steps as the previous period and starts from an initial learning rate scaled by a further factor of m_mul (values below 1 make each restart's peak lower).
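Written out, the multiplier described above takes the following form, as far as we can tell from the TF 2.x source (treat the exact expression as an assumption). Let $T_0$ be first_decay_steps and $x = \mathrm{step} / T_0$:

```latex
i = \begin{cases}
\lfloor x \rfloor, & t_{mul} = 1 \\[4pt]
\left\lfloor \dfrac{\log\left(1 - x\,(1 - t_{mul})\right)}{\log t_{mul}} \right\rfloor, & t_{mul} \neq 1
\end{cases}
\qquad
f = \begin{cases}
x - i, & t_{mul} = 1 \\[4pt]
\dfrac{x - \frac{1 - t_{mul}^{\,i}}{1 - t_{mul}}}{t_{mul}^{\,i}}, & t_{mul} \neq 1
\end{cases}

\mathrm{lr}(\mathrm{step}) = \mathrm{initial\_learning\_rate} \cdot
\left[ (1 - \alpha)\,\tfrac{1}{2}\, m_{mul}^{\,i} \left(1 + \cos(\pi f)\right) + \alpha \right]
```

Here $i$ is the zero-based index of the current period and $f \in [0, 1)$ is the fraction of that period already completed. The term $(1 - t_{mul}^{i})/(1 - t_{mul})$ is the combined length of the first $i$ periods in units of $T_0$, which is why inverting it yields $i$.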
Example usage:
import tensorflow as tf

initial_learning_rate = 0.1  # example value; choose one for your model
first_decay_steps = 1000
lr_decayed_fn = (
    tf.keras.experimental.CosineDecayRestarts(
        initial_learning_rate,
        first_decay_steps))
You can pass this schedule directly into a tf.keras.optimizers.Optimizer as the learning rate. The learning rate schedule is also serializable and deserializable using tf.keras.optimizers.schedules.serialize and tf.keras.optimizers.schedules.deserialize.
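A TensorFlow optimizer accepts the schedule object as its learning rate (for example tf.keras.optimizers.SGD(learning_rate=lr_decayed_fn)) and evaluates it at its own iteration count before each update. The pure-Python sketch below imitates only that mechanism. TinySGD and cosine_restarts_t1 are hypothetical names, and the stand-in schedule covers just the equal-period case (t_mul=1, m_mul=1, alpha=0).

```python
import math


def cosine_restarts_t1(step, initial_lr=0.1, period=4):
    """Hypothetical stand-in schedule: cosine decay restarting every
    `period` steps (the t_mul=1, m_mul=1, alpha=0 special case)."""
    frac = (step % period) / period
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * frac))


class TinySGD:
    """Hypothetical minimal optimizer (not a TensorFlow class).

    Accepts either a fixed number or a 1-arg schedule as its learning
    rate and evaluates the schedule at its own iteration counter.
    """

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate
        self.iterations = 0  # incremented once per applied update

    def current_lr(self):
        lr = self.learning_rate
        # A schedule is called with the current step; a number is used as-is.
        return lr(self.iterations) if callable(lr) else lr

    def apply(self, params, grads):
        lr = self.current_lr()
        new_params = [p - lr * g for p, g in zip(params, grads)]
        self.iterations += 1
        return new_params


opt = TinySGD(learning_rate=cosine_restarts_t1)
params = [1.0]
for _ in range(5):
    params = opt.apply(params, grads=[1.0])  # constant gradient of 1.0
```

Because the optimizer owns the step counter, the caller never has to compute the learning rate by hand; the schedule is simply re-evaluated on every update.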
Args:
initial_learning_rate: A scalar float32 or float64 Tensor or a Python number. The initial learning rate.
first_decay_steps: A scalar int32 or int64 Tensor or a Python number. Number of steps to decay over.
t_mul: A scalar float32 or float64 Tensor or a Python number. Used to derive the number of iterations in the i-th period: that period lasts first_decay_steps * t_mul**i steps (counting from i = 0).
m_mul: A scalar float32 or float64 Tensor or a Python number. Used to derive the initial learning rate of the i-th period: the cosine part of that period is scaled by m_mul**i.
alpha: A scalar float32 or float64 Tensor or a Python number. Minimum learning rate value as a fraction of the initial_learning_rate.
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
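To see how t_mul, m_mul and alpha interact, the sketch below lists where each period starts, how long it lasts, and its peak learning rate. restart_periods is a hypothetical helper; the peak formula assumes, as in the TF 2.x source as we read it, that the cosine term is scaled by m_mul**i while alpha is added afterwards, so every period still bottoms out at alpha * initial_learning_rate.

```python
def restart_periods(initial_learning_rate, first_decay_steps, t_mul=2.0,
                    m_mul=1.0, alpha=0.0, num_periods=4):
    """Hypothetical helper: (start_step, length, peak_lr) per period."""
    periods = []
    start = 0.0
    for i in range(num_periods):
        # t_mul stretches each successive period.
        length = first_decay_steps * t_mul ** i
        # m_mul shrinks (or grows) each restart's peak; alpha sets the floor.
        peak = initial_learning_rate * ((1.0 - alpha) * m_mul ** i + alpha)
        periods.append((start, length, peak))
        start += length
    return periods


for start, length, peak in restart_periods(0.1, 1000, t_mul=2.0, m_mul=0.5):
    print(f"start={start:.0f} length={length:.0f} peak_lr={peak:.4f}")
```

With t_mul=2.0 and m_mul=0.5, periods of 1000, 2000, 4000 and 8000 steps start at steps 0, 1000, 3000 and 7000, with peaks of 0.1, 0.05, 0.025 and 0.0125.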
Returns:
A 1-arg callable learning rate schedule that takes the current optimizer step and outputs the decayed learning rate, a scalar Tensor of the same type as initial_learning_rate.
Methods
tf.keras.experimental.CosineDecayRestarts.__call__
__call__(step)
Call self as a function: returns the decayed learning rate for the given optimizer step.
tf.keras.experimental.CosineDecayRestarts.from_config
from_config(
    cls,
    config
)
Instantiates a LearningRateSchedule from its config.
Args:
config: Output of get_config().
Returns:
A LearningRateSchedule instance.
tf.keras.experimental.CosineDecayRestarts.get_config
get_config()
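get_config() and from_config() are designed to round-trip: the config is a plain dictionary that can be passed back to the constructor. The sketch below illustrates that contract with a hypothetical class, ScheduleConfigSketch. Its keys assume (based on the TF 2.x source) that CosineDecayRestarts.get_config() returns one entry per constructor argument and that the base LearningRateSchedule.from_config() simply calls cls(**config).

```python
import json


class ScheduleConfigSketch:
    """Hypothetical class mirroring the get_config()/from_config() contract.

    Stores the same constructor arguments as CosineDecayRestarts; it is an
    illustration, not the TensorFlow class.
    """

    def __init__(self, initial_learning_rate, first_decay_steps,
                 t_mul=2.0, m_mul=1.0, alpha=0.0, name=None):
        self.initial_learning_rate = initial_learning_rate
        self.first_decay_steps = first_decay_steps
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.alpha = alpha
        self.name = name

    def get_config(self):
        # One entry per constructor argument, so the dict can be fed back in.
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "first_decay_steps": self.first_decay_steps,
            "t_mul": self.t_mul,
            "m_mul": self.m_mul,
            "alpha": self.alpha,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the schedule by passing the config back to the constructor.
        return cls(**config)


original = ScheduleConfigSketch(0.1, 1000, m_mul=0.5)
# The config is plain data, so it survives a JSON round trip.
saved = json.dumps(original.get_config())
restored = ScheduleConfigSketch.from_config(json.loads(saved))
```

Keeping the config limited to plain numbers, strings and None is what makes it safe to store as JSON and restore later.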