stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing module#

class ProgressiveUnfreezing(base_model: Model, steps_per_epoch: int, epochs: int, reset_weights: bool = False, reset_optimizer_function: Callable[[], Optimizer] | None = None, update_per_epoch: bool = True, update_interval: int = 5, progressive_freeze: bool = False)[source]#

Bases: Callback

Callback inspired by the Learning Rate Finder to progressively unfreeze model layers during training.

Warning: This callback is not compatible with model.fit() because it modifies the trainable state of the model during training. Prefer writing your own training loop instead.

This callback can operate in two modes:

1. Start with all layers frozen and incrementally unfreeze them from 0% to 100% (progressive_freeze=False)

2. Start with all layers unfrozen and incrementally freeze them from 100% to 0% (progressive_freeze=True)
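
A minimal usage sketch is shown below (assumptions: a recent TensorFlow/Keras, a toy Sequential model standing in for base_model, and placeholder data and hyper-parameters). The callback hooks are driven by hand instead of through model.fit(), and a GradientTape loop is used so that changes to the trainable flags take effect at every step:

    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    from stouputils.data_science.models.keras_utils.callbacks.progressive_unfreezing import ProgressiveUnfreezing

    # Toy data and model, for illustration only
    x = np.random.rand(64, 8).astype("float32")
    y = np.random.rand(64, 1).astype("float32")
    model = keras.Sequential([
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.build(input_shape=(None, 8))
    loss_fn = keras.losses.MeanSquaredError()
    optimizer = keras.optimizers.Adam()
    # Build optimizer state for every weight up front, because the set of
    # trainable variables changes while the callback sweeps
    optimizer.build(model.trainable_weights)

    batch_size, epochs = 16, 8
    steps_per_epoch = len(x) // batch_size

    # The toy model serves as both base_model and the trained model here
    callback = ProgressiveUnfreezing(model, steps_per_epoch=steps_per_epoch, epochs=epochs)
    callback.set_model(model)

    # Manual training loop driving the callback hooks explicitly
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for step in range(steps_per_epoch):
            xb = x[step * batch_size:(step + 1) * batch_size]
            yb = y[step * batch_size:(step + 1) * batch_size]
            with tf.GradientTape() as tape:
                loss = loss_fn(yb, model(xb, training=True))
            # trainable_variables reflects the flags set by the callback
            grads = tape.gradient(loss, model.trainable_variables)
            if grads:  # every layer may be frozen early in the sweep
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
        callback.on_epoch_end(epoch, logs={"loss": float(loss)})
    callback.on_train_end()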

base_model: Model#

Base model to unfreeze.

model: Model#

Model to apply the progressive unfreezing to.

steps_per_epoch: int#

Number of steps per epoch.

epochs: int#

Total number of epochs.

reset_weights: bool#

If True, reset the model weights after each update to avoid biasing the results.

reset_optimizer_function: Callable[[], Optimizer] | None#

If reset_weights is True and this is not None, this function is called to obtain a fresh optimizer after each weight reset.
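
For instance, a factory returning a fresh Adam optimizer can be passed so that stale momentum statistics do not survive a weight reset (a sketch reusing the names from the example above; the learning rate is arbitrary):

    callback = ProgressiveUnfreezing(
        model,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        reset_weights=True,
        reset_optimizer_function=lambda: keras.optimizers.Adam(learning_rate=1e-3),
    )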

update_per_epoch: bool#

If True, update the layer trainable states once per epoch; otherwise, once per batch.

update_interval: int#

Number of steps between each update, to allow the model to stabilize.

progressive_freeze: bool#

If True, start with all layers unfrozen and progressively freeze them.
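
Combining the three options above, a per-batch sweep in the freezing direction could be configured as follows (a sketch reusing the names from the first example; the interval value is arbitrary):

    callback = ProgressiveUnfreezing(
        model,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        update_per_epoch=False,   # update on batches rather than epochs
        update_interval=5,        # wait 5 batches between updates
        progressive_freeze=True,  # start at 100% unfrozen, freeze down to 0%
    )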

total_steps: int#

Total number of update steps (considering update_interval).

fraction_unfrozen: list[float]#

Recorded fractions of layers unfrozen.

losses: list[float]#

Loss values tracked during training.

_all_layers: list[Any]#

All layers of the model.

_initial_trainable: list[bool]#

Initial trainable states.

_initial_weights: list[Any] | None#

Initial weights of the model.

_last_update_step: int#

Last step at which the layer trainable states were updated.

on_train_begin(logs: dict[str, Any] | None = None) None[source]#

Set initial layer trainable states at the start of training and store initial states and weights.

Parameters:

logs (dict | None) – Training logs.

_update_layers(step: int) None[source]#

Update layer trainable states based on the current step and mode. If reset_weights is True, reset the weights after each update to prevent bias in the results.

Parameters:

step (int) – Current training step.

_track_loss(logs: dict[str, Any] | None = None) None[source]#

Track the current loss.

Parameters:

logs (dict | None) – Training logs containing loss information.

on_batch_begin(batch: int, logs: dict[str, Any] | None = None) None[source]#

Update layer trainable states at the start of each batch if not updating per epoch.

Parameters:
  • batch (int) – Current batch index.

  • logs (dict | None) – Training logs.

on_batch_end(batch: int, logs: dict[str, Any] | None = None) None[source]#

Track loss at the end of each batch if not updating per epoch.

Parameters:
  • batch (int) – Current batch index.

  • logs (dict | None) – Training logs.
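
When update_per_epoch is False, these batch hooks are the ones to drive from the custom loop (a sketch continuing the first example, reusing its model, data, loss_fn, and optimizer):

    callback = ProgressiveUnfreezing(model, steps_per_epoch=steps_per_epoch,
                                     epochs=epochs, update_per_epoch=False)
    callback.set_model(model)
    callback.on_train_begin()
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            callback.on_batch_begin(step)
            xb = x[step * batch_size:(step + 1) * batch_size]
            yb = y[step * batch_size:(step + 1) * batch_size]
            with tf.GradientTape() as tape:
                loss = loss_fn(yb, model(xb, training=True))
            grads = tape.gradient(loss, model.trainable_variables)
            if grads:  # every layer may be frozen at some point in the sweep
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
            callback.on_batch_end(step, logs={"loss": float(loss)})
    callback.on_train_end()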

on_epoch_begin(epoch: int, logs: dict[str, Any] | None = None) None[source]#

Update layer trainable states at the start of each epoch if updating per epoch.

Parameters:
  • epoch (int) – Current epoch index.

  • logs (dict | None) – Training logs.

on_epoch_end(epoch: int, logs: dict[str, Any] | None = None) None[source]#

Track loss at the end of each epoch if updating per epoch.

Parameters:
  • epoch (int) – Current epoch index.

  • logs (dict | None) – Training logs.

on_train_end(logs: dict[str, Any] | None = None) None[source]#

Restore original trainable states at the end of training.

Parameters:

logs (dict | None) – Training logs.

get_results(multiply_by_100: bool = True) tuple[list[float], list[float]][source]#

Get the results of the progressive unfreezing, ordered from 0% to 100% unfrozen even if progressive_freeze is True.

Parameters:

multiply_by_100 (bool) – If True, multiply the fractions by 100 to get percentages.

Returns:

Fractions of layers unfrozen, and the corresponding losses.

Return type:

tuple[list[float], list[float]]
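
After training, the recorded sweep can be plotted, e.g. with matplotlib (a sketch; the axis label assumes the default percentage output):

    import matplotlib.pyplot as plt

    fractions, losses = callback.get_results()  # percentages by default
    plt.plot(fractions, losses)
    plt.xlabel("Layers unfrozen (%)")
    plt.ylabel("Loss")
    plt.show()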