Source code for stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing


# pyright: reportMissingTypeStubs=false

# Imports
from collections.abc import Callable
from typing import Any

from keras.callbacks import Callback
from keras.models import Model
from keras.optimizers import Optimizer


class ProgressiveUnfreezing(Callback):
    """ Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.

    Warning: This callback is not compatible with model.fit() as it modifies the trainable state of the model.
    Prefer writing your own training loop instead.

    This callback can operate in two modes:

    1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)
    2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
    """
    def __init__(
        self,
        base_model: Model,
        steps_per_epoch: int,
        epochs: int,
        reset_weights: bool = False,
        reset_optimizer_function: Callable[[], Optimizer] | None = None,
        update_per_epoch: bool = True,
        update_interval: int = 5,
        progressive_freeze: bool = False,
    ) -> None:
        """ Initialize the progressive unfreezing callback.

        Args:
            base_model (Model): Base model to unfreeze.
            steps_per_epoch (int): Number of steps per epoch.
            epochs (int): Total number of epochs.
            reset_weights (bool): If True, reset weights after each unfreeze.
            reset_optimizer_function (Callable | None): If set, use this function to reset the optimizer
                every update_interval. The function should return a compiled optimizer,
                e.g. `lambda: model._get_optimizer(AdamW(...))`.
            update_per_epoch (bool): If True, unfreeze per epoch, else per batch.
            update_interval (int): Number of steps between each unfreeze, to let the model stabilize.
            progressive_freeze (bool): If True, start with all layers unfrozen and progressively freeze them.
        """
        super().__init__()
        self.base_model: Model = base_model
        """ Base model to unfreeze. """
        self.model: Model
        """ Model to apply the progressive unfreezing to. """
        self.steps_per_epoch: int = int(steps_per_epoch)
        """ Number of steps per epoch. """
        self.epochs: int = int(epochs)
        """ Total number of epochs. """
        self.reset_weights: bool = bool(reset_weights)
        """ If True, reset weights after each unfreeze. """
        self.reset_optimizer_function: Callable[[], Optimizer] | None = reset_optimizer_function
        """ If reset_weights is True and this is not None, use this function to get a new optimizer. """
        self.update_per_epoch: bool = bool(update_per_epoch)
        """ If True, unfreeze per epoch, else per batch. """
        self.update_interval: int = max(1, int(update_interval))
        """ Number of steps between each unfreeze, to let the model stabilize. """
        self.progressive_freeze: bool = bool(progressive_freeze)
        """ If True, start with all layers unfrozen and progressively freeze them. """

        # If updating per epoch, subtract the update interval from self.epochs
        # so that the last step trains with 100% of the layers unfrozen
        if self.update_per_epoch:
            self.epochs -= self.update_interval

        # Calculate total steps considering the update interval
        total_steps_raw: int = self.epochs if self.update_per_epoch else self.steps_per_epoch * self.epochs
        self.total_steps: int = total_steps_raw // self.update_interval
        """ Total number of update steps (considering update_interval). """

        self.fraction_unfrozen: list[float] = []
        """ Fraction of layers unfrozen at each update step. """
        self.losses: list[float] = []
        """ Tracked losses. """
        self._all_layers: list[Any] = []
        """ All layers of the base model. """
        self._initial_trainable: list[bool] = []
        """ Initial trainable states. """
        self._initial_weights: list[Any] | None = None
        """ Initial weights of the model. """
        self._last_update_step: int = -1
        """ Last step when layers were unfrozen. """
        self._current_epoch: int = 0
        """ Current epoch index, tracked for batch-level updates (the Keras params dict does not expose it). """
        self.params: dict[str, Any]
        """ Training parameters set by Keras (or manually via set_params). """
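
    # Note: as the docstring above hints, reset_optimizer_function is typically a zero-argument
    # lambda capturing the model, e.g. (the AdamW hyperparameters here are hypothetical):
    #   reset_optimizer_function=lambda: model._get_optimizer(AdamW(learning_rate=1e-4))
    # so that a freshly built optimizer replaces the old one at every update interval.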

    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
        """ Set initial layer trainable states at the start of training, and store the initial states and weights.

        Args:
            logs (dict | None): Training logs.
        """
        # Collect all layers from the model and preserve their original trainable states for potential restoration
        self._all_layers = self.base_model.layers
        self._initial_trainable = [bool(layer.trainable) for layer in self._all_layers]

        # Store initial weights to reset after each unfreeze
        if self.reset_weights:
            self._initial_weights = self.model.get_weights()

        # Set initial trainable state based on mode
        for layer in self._all_layers:
            layer.trainable = self.progressive_freeze  # If progressive_freeze, start with all layers unfrozen

    def _update_layers(self, step: int) -> None:
        """ Update layer trainable states based on the current step and mode.
        Reset weights after each update to prevent bias in the results.

        Args:
            step (int): Current training step.
        """
        # Calculate the effective step considering the update interval
        effective_step: int = step // self.update_interval

        # Skip if we haven't reached the next update interval
        if effective_step <= self._last_update_step:
            return
        self._last_update_step = effective_step

        # Calculate the number of layers to unfreeze based on the current effective step
        n_layers: int = len(self._all_layers)
        if self.progressive_freeze:
            # For progressive freezing, start at 1.0 (all unfrozen) and decrease to 0.0
            fraction: float = max(0.0, 1.0 - (effective_step + 1) / self.total_steps)
        else:
            # For progressive unfreezing, start at 0.0 (all frozen) and increase to 1.0
            fraction = min(1.0, (effective_step + 1) / self.total_steps)
        n_unfreeze: int = int(n_layers * fraction)  # Number of layers to keep unfrozen
        self.fraction_unfrozen.append(fraction)

        # Set trainable state for each layer based on position:
        # for both modes, we unfreeze from the top (output layers) to the bottom (input layers)
        for i, layer in enumerate(self._all_layers):
            layer.trainable = i >= (n_layers - n_unfreeze)

        # Reset weights to their initial state to prevent bias, and reset the optimizer
        if self._initial_weights is not None:
            self.model.set_weights(self._initial_weights)  # pyright: ignore [reportUnknownMemberType]
        if self.reset_optimizer_function is not None:
            self.model.optimizer = self.reset_optimizer_function()
            self.model.optimizer.build(self.model.trainable_variables)  # pyright: ignore [reportUnknownMemberType]
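
    # Worked example of the schedule above: with total_steps=10 in unfreezing mode, the appended
    # fractions over successive effective steps are 0.1, 0.2, ..., 1.0; in freezing mode they are
    # 0.9, 0.8, ..., 0.0. With n_layers=20 and fraction=0.1, n_unfreeze=2, so only the two
    # topmost (output-side) layers are left trainable.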

    def _track_loss(self, logs: dict[str, Any] | None = None) -> None:
        """ Track the current loss.

        Args:
            logs (dict | None): Training logs containing loss information.
        """
        if logs and "loss" in logs:
            self.losses.append(logs["loss"])

    def on_batch_begin(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update layer trainable states at the start of each batch if not updating per epoch.

        Args:
            batch (int): Current batch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per epoch instead of per batch
        if self.update_per_epoch:
            return

        # Calculate the current step across all epochs and update layers
        # (the current epoch is tracked in on_epoch_begin, since the Keras params dict does not expose it)
        step: int = self.params.get("steps", self.steps_per_epoch) * self._current_epoch + batch
        self._update_layers(step)

    def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
        """ Track loss at the end of each batch if not updating per epoch.

        Args:
            batch (int): Current batch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per epoch instead of per batch
        if self.update_per_epoch:
            return

        # Record the loss if the update interval is reached
        if batch % self.update_interval == 0:
            self._track_loss(logs)

    def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Update layer trainable states at the start of each epoch if updating per epoch.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        # Track the current epoch so that batch-level updates can compute the global step
        self._current_epoch = epoch

        # Skip if we're updating per batch instead of per epoch
        if not self.update_per_epoch:
            return

        # Update layers based on the current epoch
        self._update_layers(epoch)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """ Track loss at the end of each epoch if updating per epoch.

        Args:
            epoch (int): Current epoch index.
            logs (dict | None): Training logs.
        """
        # Skip if we're updating per batch instead of per epoch
        if not self.update_per_epoch:
            return

        # Record the loss if the update interval is reached
        if epoch % self.update_interval == 0:
            self._track_loss(logs)

    def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
        """ Restore original trainable states at the end of training.

        Args:
            logs (dict | None): Training logs.
        """
        # Restore each layer's original trainable state
        for layer, trainable in zip(self._all_layers, self._initial_trainable, strict=False):
            layer.trainable = trainable

    def get_results(self, multiply_by_100: bool = True) -> tuple[list[float], list[float]]:
        """ Get the results of the progressive unfreezing from 0% to 100%, even if progressive_freeze is True.

        Args:
            multiply_by_100 (bool): If True, multiply the fractions by 100 to get percentages.
        Returns:
            tuple[list[float], list[float]]: Fractions of layers unfrozen, and losses.
        """
        fractions: list[float] = self.fraction_unfrozen

        # Reverse the order if progressive_freeze is True
        if self.progressive_freeze:
            fractions = fractions[::-1]

        # Multiply by 100 if requested
        if multiply_by_100:
            fractions = [x * 100 for x in fractions]

        # Return the results
        return fractions, self.losses
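
# Minimal usage sketch (illustrative, not part of the module): since the class docstring advises
# against model.fit(), this drives the callback hooks from a custom training loop. The toy model,
# dummy data, and hyperparameter values below are assumptions for demonstration only.
if __name__ == "__main__":
    import numpy as np
    from keras import layers, models

    # Hypothetical base model whose layers will be progressively unfrozen
    base_model = models.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(32, activation="relu"),
    ])
    model = models.Sequential([base_model, layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    # Dummy data standing in for a real dataset
    x = np.random.rand(256, 32).astype("float32")
    y = np.random.rand(256, 1).astype("float32")

    epochs: int = 12
    batch_size: int = 32
    steps_per_epoch: int = len(x) // batch_size

    callback = ProgressiveUnfreezing(
        base_model=base_model,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        update_per_epoch=True,
        update_interval=1,
    )
    callback.set_model(model)  # No model.fit() here, so attach the model manually
    callback.set_params({"epochs": epochs, "steps": steps_per_epoch})

    # Custom training loop driving the callback hooks explicitly
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        loss: float = 0.0
        for step in range(steps_per_epoch):
            sl = slice(step * batch_size, (step + 1) * batch_size)
            # Depending on the Keras version/backend, trainable changes may only take effect
            # after recompiling the model or resetting its cached train function
            loss = float(model.train_on_batch(x[sl], y[sl]))
        callback.on_epoch_end(epoch, logs={"loss": loss})
    callback.on_train_end()

    # Inspect which unfreeze percentage gave the lowest loss
    percentages, losses = callback.get_results()
    for pct, tracked in zip(percentages, losses, strict=False):
        print(f"{pct:6.1f}% unfrozen -> loss {tracked:.4f}")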