stouputils.data_science.models.keras_utils.callbacks.learning_rate_finder module#

class LearningRateFinder(min_lr: float, max_lr: float, steps_per_epoch: int, epochs: int, update_per_epoch: bool = False, update_interval: int = 5)[source]#

Bases: Callback

Callback to find an optimal learning rate by progressively increasing the LR during training.

Sources:

  • Inspired by: WittmannF/LRFinder

  • Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)

This callback gradually increases the learning rate from a minimum to a maximum value during training, allowing you to identify the optimal learning rate range for your model.

It works by:

  1. Starting with a very small learning rate

  2. Exponentially increasing it after each batch or epoch

  3. Recording the loss at each learning rate

  4. Restoring the model’s initial weights after training

The optimal learning rate is typically found where the loss is decreasing most rapidly before it starts to diverge.

[Figure: Learning rate finder curve example]
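
A minimal usage sketch (the dataset, model, and hyperparameter values below are illustrative placeholders, not taken from the library's documentation):

    import numpy as np
    from keras import layers, models

    from stouputils.data_science.models.keras_utils.callbacks.learning_rate_finder import LearningRateFinder

    # Placeholder data and model, only to make the example runnable
    x = np.random.rand(1024, 32).astype("float32")
    y = np.random.randint(0, 2, size=(1024, 1)).astype("float32")
    model = models.Sequential([
        layers.Input(shape=(32,)),
        layers.Dense(64, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")

    batch_size = 32
    epochs = 3
    steps_per_epoch = len(x) // batch_size

    # Sweep the learning rate from 1e-7 up to 1.0 over the whole run
    finder = LearningRateFinder(min_lr=1e-7, max_lr=1.0, steps_per_epoch=steps_per_epoch, epochs=epochs)
    model.fit(x, y, batch_size=batch_size, epochs=epochs, callbacks=[finder])

    # Plot finder.losses against finder.learning_rates (log-x scale) and pick a
    # learning rate slightly below the point where the loss starts to diverge.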
min_lr: float#

Minimum learning rate.

max_lr: float#

Maximum learning rate.

total_updates: int#

Total number of learning rate update steps, taking update_interval into account.

update_per_epoch: bool#

Whether to update learning rate per epoch instead of per batch.

update_interval: int#

Number of steps between consecutive learning rate increases; a larger value produces a smoother, more stable loss curve.

lr_mult: float#

Multiplicative factor applied to the learning rate at each update step.
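
A plausible way this multiplier could be derived (an assumption for illustration, not the library's verbatim code) is a geometric progression from min_lr to max_lr over total_updates steps:

    # Illustrative values; lr_mult is chosen so that repeated multiplication
    # carries the learning rate from min_lr to max_lr in total_updates steps.
    min_lr, max_lr = 1e-7, 1.0
    total_updates = 300
    lr_mult = (max_lr / min_lr) ** (1.0 / total_updates)

    # Sanity check: after total_updates multiplications we land on max_lr
    assert abs(min_lr * lr_mult ** total_updates - max_lr) < 1e-9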

learning_rates: list[float]#

Learning rates recorded at each update step.

losses: list[float]#

Losses recorded at each update step.

best_lr: float#

Learning rate at which the lowest loss was observed.

best_loss: float#

Lowest loss recorded during the search.

model: Model#

Model to apply the learning rate finder to.

initial_weights: list[Any] | None#

Stores the initial weights of the model.

on_train_begin(logs: dict[str, Any] | None = None) None[source]#

Set initial learning rate and save initial model weights at the start of training.

Parameters:

logs (dict | None) – Training logs.
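
A minimal sketch of what this hook plausibly does (assumed behavior, not the library's verbatim code; assumes an optimizer whose learning_rate variable supports assign):

    import keras

    class BeginSketch(keras.callbacks.Callback):
        def __init__(self, min_lr: float) -> None:
            super().__init__()
            self.min_lr = min_lr
            self.initial_weights = None

        def on_train_begin(self, logs=None):
            # Save the weights so the sweep can be undone in on_train_end
            self.initial_weights = self.model.get_weights()
            # Start the sweep at the minimum learning rate
            self.model.optimizer.learning_rate.assign(self.min_lr)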

_update_lr_and_track_metrics(logs: dict[str, Any] | None = None) None[source]#

Update learning rate and track metrics.

Parameters:

logs (dict | None) – Training logs.
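
A minimal sketch of this shared update step (assumed behavior, not the library's verbatim code): record the current learning rate and loss, track the best pair so far, then multiply the learning rate by lr_mult:

    import math
    import keras

    class UpdateSketch(keras.callbacks.Callback):
        def __init__(self, lr_mult: float) -> None:
            super().__init__()
            self.lr_mult = lr_mult
            self.learning_rates: list[float] = []
            self.losses: list[float] = []
            self.best_lr: float = 0.0
            self.best_loss: float = math.inf

        def _update_lr_and_track_metrics(self, logs=None):
            logs = logs or {}
            lr = float(self.model.optimizer.learning_rate.numpy())
            loss = float(logs.get("loss", math.inf))
            self.learning_rates.append(lr)
            self.losses.append(loss)
            # Keep the learning rate that produced the lowest loss so far
            if loss < self.best_loss:
                self.best_loss = loss
                self.best_lr = lr
            # Exponential schedule: next LR = current LR * lr_mult
            self.model.optimizer.learning_rate.assign(lr * self.lr_mult)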

on_batch_end(batch: int, logs: dict[str, Any] | None = None) None[source]#

Record loss and increase learning rate after each batch if not updating per epoch.

Parameters:
  • batch (int) – Current batch index.

  • logs (dict | None) – Training logs.

on_epoch_end(epoch: int, logs: dict[str, Any] | None = None) None[source]#

Record loss and increase learning rate after each epoch if updating per epoch.

Parameters:
  • epoch (int) – Current epoch index.

  • logs (dict | None) – Training logs.
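
How the two hooks plausibly dispatch to the shared update (the update_interval gating shown here illustrates the documented behavior and is not verbatim code):

    import keras

    class HookSketch(keras.callbacks.Callback):
        def __init__(self, update_per_epoch: bool = False, update_interval: int = 5) -> None:
            super().__init__()
            self.update_per_epoch = update_per_epoch
            self.update_interval = update_interval

        def _update_lr_and_track_metrics(self, logs=None):
            ...  # record the loss and increase the LR (see the sketch above)

        def on_batch_end(self, batch, logs=None):
            # Per-batch mode: only act every update_interval batches to smooth the loss
            if not self.update_per_epoch and (batch + 1) % self.update_interval == 0:
                self._update_lr_and_track_metrics(logs)

        def on_epoch_end(self, epoch, logs=None):
            # Per-epoch mode: one update at the end of every epoch
            if self.update_per_epoch:
                self._update_lr_and_track_metrics(logs)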

on_train_end(logs: dict[str, Any] | None = None) None[source]#

Restore initial model weights at the end of training.

Parameters:

logs (dict | None) – Training logs.
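
A minimal sketch of the restoration step (assumed behavior, not the library's verbatim code): the weights saved in on_train_begin are written back so the sweep leaves the model untouched:

    import keras

    class EndSketch(keras.callbacks.Callback):
        def __init__(self) -> None:
            super().__init__()
            self.initial_weights = None

        def on_train_begin(self, logs=None):
            self.initial_weights = self.model.get_weights()

        def on_train_end(self, logs=None):
            # Undo all updates made during the learning rate sweep
            if self.initial_weights is not None:
                self.model.set_weights(self.initial_weights)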