# Source code for stouputils.data_science.models.keras_utils.callbacks.model_checkpoint_v2
# pyright: reportMissingTypeStubs=false
# pyright: reportUnknownMemberType=false
# Imports
from typing import Any
from keras.callbacks import ModelCheckpoint
class ModelCheckpointV2(ModelCheckpoint):
	""" Model checkpoint callback but only starts after a given number of epochs.

	Epochs are 0-indexed (as passed by Keras), so with ``epochs_before_start=3``
	epochs 0, 1 and 2 are skipped and checkpointing begins at epoch 3.
	The skip applies to both epoch-level and batch-level (``save_freq=<int>``) saving.

	Args:
		epochs_before_start (int): Number of epochs before starting the checkpointing
		*args (Any): Forwarded to ``keras.callbacks.ModelCheckpoint``
		**kwargs (Any): Forwarded to ``keras.callbacks.ModelCheckpoint``

	Note:
		``epochs_before_start`` is the first positional parameter, so ``filepath`` and the
		other ModelCheckpoint arguments should be passed by keyword
		(e.g. ``ModelCheckpointV2(filepath="best.keras", epochs_before_start=5)``).
	"""
	def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
		super().__init__(*args, **kwargs)
		self.epochs_before_start: int = epochs_before_start
		# Index of the epoch currently running.
		# It is refreshed in on_epoch_begin so that batch hooks compare against the right epoch.
		self.current_epoch: int = 0

	def _checkpointing_started(self, epoch: int) -> bool:
		""" Return True once ``epoch`` has reached ``epochs_before_start``. """
		return epoch >= self.epochs_before_start

	def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		# Record the epoch at its start (not only at its end).
		# Otherwise batches of epoch N would be checked against N-1.
		self.current_epoch = epoch
		super().on_epoch_begin(epoch, logs)

	def on_train_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Keras' ModelCheckpoint performs its batch-frequency saving in on_train_batch_end
		# (it does not delegate to on_batch_end), so this is the hook that must be gated.
		if self._checkpointing_started(self.current_epoch):
			super().on_train_batch_end(batch, logs)

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Backward-compatible alias hook: kept gated in the same way as on_train_batch_end.
		if self._checkpointing_started(self.current_epoch):
			super().on_batch_end(batch, logs)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		self.current_epoch = epoch
		# Epoch-level saving (and best-metric tracking) only begins once the warm-up is over.
		if self._checkpointing_started(epoch):
			super().on_epoch_end(epoch, logs)