Source code for stouputils.data_science.models.keras_utils.callbacks.warmup_scheduler
# pyright: reportMissingTypeStubs=false
# Imports
from typing import Any
import tensorflow as tf
from keras.callbacks import Callback
from keras.models import Model
class WarmupScheduler(Callback):
    """ Keras Callback for learning rate warmup.

    Sources:

    - Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
    - Attention Is All You Need: https://arxiv.org/abs/1706.03762

    This callback implements a learning rate warmup strategy: the learning rate
    increases linearly from an initial value to a target value over a specified
    number of warmup epochs, then remains at the target value. Warming up helps
    stabilize training in its early stages.
    """

    def __init__(self, warmup_epochs: int, initial_lr: float, target_lr: float) -> None:
        """ Initialize the warmup scheduler.

        Args:
            warmup_epochs (int): Number of epochs for warmup.
            initial_lr (float): Starting learning rate for warmup.
            target_lr (float): Target learning rate after warmup.
        """
        super().__init__()
        self.warmup_epochs: int = warmup_epochs
        """ Number of epochs for warmup. """
        self.initial_lr: float = initial_lr
        """ Starting learning rate for warmup. """
        self.target_lr: float = target_lr
        """ Target learning rate after warmup. """
        self.model: Model
        """ Model to apply the warmup scheduler to. """

        # Pre-compute learning rates for each epoch to avoid calculations during training
        self.epoch_learning_rates: list[float] = []
        for epoch in range(warmup_epochs + 1):
            if epoch < warmup_epochs:
                lr = initial_lr + (target_lr - initial_lr) * (epoch + 1) / warmup_epochs
            else:
                lr = target_lr
            self.epoch_learning_rates.append(lr)

    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Adjust learning rate at the beginning of each epoch during warmup.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        if self.warmup_epochs <= 0 or epoch > self.warmup_epochs:
            return

        # Use pre-computed learning rate to avoid calculations during training
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.epoch_learning_rates[epoch])  # type: ignore
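
# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of plugging WarmupScheduler into a Keras training run.
# The tiny model and the random data below are hypothetical placeholders used
# only for demonstration; WarmupScheduler itself is the class defined above.
if __name__ == "__main__":
    import numpy as np
    from keras import layers, models

    # Toy regression model, purely for demonstration
    model = models.Sequential([layers.Input(shape=(4,)), layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    x = np.random.rand(64, 4).astype("float32")
    y = np.random.rand(64, 1).astype("float32")

    # Warm up linearly from 1e-5 to 1e-3 over the first 5 epochs, then stay at 1e-3
    # (with these values, epochs 0-4 use ~2.08e-4, 4.06e-4, 6.04e-4, 8.02e-4, 1e-3)
    warmup = WarmupScheduler(warmup_epochs=5, initial_lr=1e-5, target_lr=1e-3)
    model.fit(x, y, epochs=10, callbacks=[warmup], verbose=0)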