Source code for stouputils.data_science.models.keras_utils.callbacks.learning_rate_finder
# pyright: reportMissingTypeStubs=false
# Imports
from typing import Any
import tensorflow as tf
from keras.callbacks import Callback
from keras.models import Model
class LearningRateFinder(Callback):
""" Callback to find optimal learning rate by increasing LR during training.
Sources:
- Inspired by: https://github.com/WittmannF/LRFinder
- Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)
This callback gradually increases the learning rate from a minimum to a maximum value
during training, allowing you to identify the optimal learning rate range for your model.
It works by:
1. Starting with a very small learning rate
2. Exponentially increasing it after each batch or epoch
3. Recording the loss at each learning rate
4. Restoring the model's initial weights after training
The optimal learning rate is typically found where the loss is decreasing most rapidly
before it starts to diverge.
.. image:: https://blog.dataiku.com/hubfs/training%20loss.png
:alt: Learning rate finder curve example
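
Example (a minimal usage sketch; ``model``, ``x_train`` and ``y_train`` are assumed to be an already-compiled Keras model and its training data):

.. code-block:: python

    finder = LearningRateFinder(min_lr=1e-7, max_lr=1.0, steps_per_epoch=100, epochs=3)
    model.fit(x_train, y_train, epochs=3, callbacks=[finder])

    # Plot finder.losses against finder.learning_rates (log-scaled x-axis) and pick a
    # learning rate slightly below the point where the loss starts to diverge;
    # finder.best_lr holds the learning rate at which the lowest loss was recorded.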
"""
def __init__(
self,
min_lr: float,
max_lr: float,
steps_per_epoch: int,
epochs: int,
update_per_epoch: bool = False,
update_interval: int = 5
) -> None:
""" Initialize the learning rate finder.
Args:
min_lr (float): Minimum learning rate
max_lr (float): Maximum learning rate
steps_per_epoch (int): Steps per epoch
epochs (int): Number of epochs
update_per_epoch (bool): If True, update LR once per epoch instead of every batch.
update_interval (int): Number of steps between each learning rate increase; a larger value yields a smoother loss curve.
"""
super().__init__()
self.min_lr: float = min_lr
""" Minimum learning rate. """
self.max_lr: float = max_lr
""" Maximum learning rate. """
self.update_per_epoch: bool = update_per_epoch
""" Whether to update learning rate per epoch instead of per batch. """
self.update_interval: int = max(1, int(update_interval))
""" Number of steps between each learning rate increase; a larger value yields a smoother loss curve. """
self.total_updates: int = max(1, (epochs if update_per_epoch else steps_per_epoch * epochs) // self.update_interval)
""" Total number of update steps (accounting for update_interval); clamped to at least 1 to avoid a zero exponent below. """
self.lr_mult: float = (max_lr / min_lr) ** (1 / self.total_updates)
""" Multiplicative step applied to the learning rate at each update, chosen so that min_lr * lr_mult ** total_updates == max_lr. """
self.learning_rates: list[float] = []
""" List of learning rates. """
self.losses: list[float] = []
""" List of losses. """
self.best_lr: float = min_lr
""" Learning rate at which the lowest loss was observed. """
self.best_loss: float = float("inf")
""" Lowest loss observed so far. """
self.model: Model
""" Model to apply the learning rate finder to. """
self.initial_weights: list[Any] | None = None
""" Stores the initial weights of the model. """
def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
""" Set initial learning rate and save initial model weights at the start of training.
Args:
logs (dict | None): Training logs.
"""
self.initial_weights = self.model.get_weights()
tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.min_lr) # type: ignore
def _update_lr_and_track_metrics(self, logs: dict[str, Any] | None = None) -> None:
""" Update learning rate and track metrics.
Args:
logs (dict | None): Logs from training
"""
if logs is None:
return
# Get current learning rate and loss
current_lr: float = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) # type: ignore
current_loss: float = float(logs["loss"])
# Record values
self.learning_rates.append(current_lr)
self.losses.append(current_loss)
# Track best values
if current_loss < self.best_loss:
self.best_loss = current_loss
self.best_lr = current_lr
# Update learning rate
new_lr: float = current_lr * self.lr_mult
tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr) # type: ignore
def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
""" Record loss and increase learning rate after each batch if not updating per epoch.
Args:
batch (int): Current batch index.
logs (dict | None): Training logs.
"""
if self.update_per_epoch:
return
if batch % self.update_interval == 0:
self._update_lr_and_track_metrics(logs)
def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
""" Record loss and increase learning rate after each epoch if updating per epoch.
Args:
epoch (int): Current epoch index.
logs (dict | None): Training logs.
"""
if not self.update_per_epoch:
return
if epoch % self.update_interval == 0:
self._update_lr_and_track_metrics(logs)
def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
""" Restore initial model weights at the end of training.
Args:
logs (dict | None): Training logs.
"""
if self.initial_weights is not None:
self.model.set_weights(self.initial_weights) # pyright: ignore [reportUnknownMemberType]