Source code for stouputils.data_science.models.base_keras

""" Keras-specific model implementation with TensorFlow integration.
Provides concrete implementations for Keras model operations.

Features:

- Transfer learning layer freezing/unfreezing
- Keras-specific callbacks (early stopping, LR reduction)
- Model checkpointing/weight management
- GPU-optimized prediction pipelines
- Keras metric/loss configuration
- Model serialization/deserialization

Implements ModelInterface for Keras-based models.
"""
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportArgumentType=false
# pyright: reportCallIssue=false
# pyright: reportMissingTypeStubs=false
# pyright: reportOptionalMemberAccess=false
# pyright: reportOptionalCall=false

# Imports
import gc
import multiprocessing
import multiprocessing.queues
import os
from collections.abc import Iterable
from tempfile import TemporaryDirectory
from typing import Any

import mlflow
import mlflow.keras
import numpy as np
import tensorflow as tf
from keras.backend import clear_session
from keras.callbacks import Callback, CallbackList, EarlyStopping, History, ReduceLROnPlateau, TensorBoard
from keras.layers import Dense, GlobalAveragePooling2D
from keras.losses import CategoricalCrossentropy, CategoricalFocalCrossentropy, Loss
from keras.metrics import AUC, CategoricalAccuracy, F1Score, Metric
from keras.models import Model, Sequential
from keras.optimizers import Adam, AdamW, Lion, Optimizer
from keras.utils import set_random_seed
from numpy.typing import NDArray

from ...decorators import measure_time
from ...print import info, warning, debug, progress
from ...parallel import colored_for_loop
from ...ctx import Muffle
from .. import mlflow_utils
from ..config.get import DataScienceConfig
from ..dataset import Dataset, GroupingStrategy
from ..utils import Utils
from .keras_utils.callbacks import ColoredProgressBar, LearningRateFinder, ModelCheckpointV2, ProgressiveUnfreezing, WarmupScheduler
from .keras_utils.losses import NextGenerationLoss
from .keras_utils.visualizations import all_visualizations_for_image
from .model_interface import ModelInterface



class BaseKeras(ModelInterface):
    """ Base class for Keras models with common functionality. """

    def class_load(self) -> None:
        """ Clear the session and collect garbage, reset random seeds and call the parent class method. """
        super().class_load()
        clear_session()
        gc.collect()
        set_random_seed(DataScienceConfig.SEED)
        self.final_model: Model

    def _fit(
        self,
        model: Model,
        x: Any,
        y: Any | None = None,
        validation_data: tuple[Any, Any] | None = None,
        shuffle: bool = True,
        batch_size: int | None = None,
        epochs: int = 1,
        callbacks: list[Callback] | None = None,
        class_weight: dict[int, float] | None = None,
        verbose: int = 0,
        *args: Any,
        **kwargs: Any
    ) -> History:
        """ Manually fit the model with a custom training loop instead of using model.fit().

        This method implements a custom training loop for more control over the training process.
        It's useful for implementing custom training behaviors that aren't easily done with model.fit()
        such as unfreezing layers during training, resetting the optimizer, etc.

        Args:
            model            (Model): The model to train
            x                (Any): Training data inputs
            y                (Any | None): Training data targets
            validation_data  (tuple[Any, Any] | None): Validation data as a tuple of (inputs, targets)
            shuffle          (bool): Whether to shuffle the training data every epoch
            batch_size       (int | None): Number of samples per gradient update.
            epochs           (int): Number of epochs to train the model.
            callbacks        (list[Callback] | None): List of callbacks to apply during training.
            class_weight     (dict[int, float] | None): Optional dictionary mapping class indices to weights.
            verbose          (int): Verbosity mode.

        Returns:
            History: Training history
        """
        # Set TensorFlow to use the XLA compiler
        tf.config.optimizer.set_jit(True)

        # Build training dataset
        if y is None and isinstance(x, tf.data.Dataset):
            train_dataset: tf.data.Dataset = x
        else:
            train_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices((x, y) if y is not None else x)

        # Optimize dataset pipeline
        if shuffle:
            buffer_size: int = len(x) if hasattr(x, '__len__') else 10000
            buffer_size = min(buffer_size, 50000)
            train_dataset = train_dataset.shuffle(buffer_size=buffer_size, reshuffle_each_iteration=True)
        if batch_size is not None:
            train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        # Handle validation data
        val_dataset: tf.data.Dataset | None = None
        if validation_data is not None:
            x_val, y_val = validation_data
            val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
            if batch_size is not None:
                val_dataset = val_dataset.batch(batch_size)
            val_dataset = val_dataset.cache().prefetch(tf.data.AUTOTUNE)

        # Handle callbacks
        callback_list: CallbackList = CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=model,
            verbose=verbose,
            epochs=epochs,
            steps=tf.data.experimental.cardinality(train_dataset).numpy(),
        )

        # Precompute class weights tensor outside the training loop
        class_weight_tensor: NDArray[Any] | None = None
        if class_weight:
            class_weight_values: list[float] = [float(class_weight.get(i, 1.0)) for i in range(self.num_classes)]
            class_weight_tensor = tf.constant(class_weight_values, dtype=tf.float32)

        # Precompute the gather weights function outside the training loop
        @tf.function(jit_compile=True, experimental_relax_shapes=True)
        def gather_weights(label_indices: tf.Tensor) -> tf.Tensor | None:
            if class_weight_tensor is not None:
                return tf.gather(class_weight_tensor, label_indices)
            return None

        # Get optimizer (will use loss scaling automatically under mixed-precision)
        is_ls: bool = isinstance(model.optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

        # Training step with proper loss scaling
        @tf.function(jit_compile=True, experimental_relax_shapes=True)
        def train_step(xb: tf.Tensor, yb: tf.Tensor, training: bool = True) -> dict[str, Any]:
            """ Execute a single training step with gradient calculation and optimization.

            Args:
                xb (tf.Tensor): Input batch data
                yb (tf.Tensor): Target batch data
            Returns:
                dict[str, Any]: The metrics for the training step
            """
            labels = tf.cast(tf.argmax(yb, axis=1), tf.int32)
            sw = gather_weights(labels)
            with tf.GradientTape(watch_accessed_variables=training) as tape:
                preds = model(xb, training=training)
                loss = model.compiled_loss(yb, preds, sample_weight=sw)
                loss = tf.reduce_mean(loss)

                # Scale loss if using LossScaleOptimizer
                if is_ls:
                    loss = model.optimizer.get_scaled_loss(loss)

            # Backpropagate the loss
            if training:
                model.optimizer.minimize(loss, model.trainable_weights, tape=tape)

            # Update the metrics
            model.compiled_metrics.update_state(yb, preds, sample_weight=sw)
            return model.get_metrics_result()

        # Start callbacks
        logs: dict[str, Any] = {"loss": 0.0}
        callback_list.on_train_begin()

        # Custom training loop
        for epoch in range(epochs):

            # Callbacks and reset metrics
            callback_list.on_epoch_begin(epoch)
            model.compiled_metrics.reset_state()
            model.compiled_loss.reset_state()

            # Train on all batches
            for step, (x_batch, y_batch) in enumerate(train_dataset):
                callback_list.on_batch_begin(step)
                logs.update(train_step(x_batch, y_batch, training=True))
                callback_list.on_batch_end(step, logs)

            # Compute metrics for validation
            if val_dataset is not None:
                model.compiled_metrics.reset_state()
                model.compiled_loss.reset_state()

                # Run through all validation data
                for x_val, y_val in val_dataset:
                    train_step(x_val, y_val, training=False)

                # Prefix "val_" to the metrics
                for key, value in model.get_metrics_result().items():
                    logs[f"val_{key}"] = value

            callback_list.on_epoch_end(epoch, logs)
        callback_list.on_train_end(logs)

        # Return history
        return model.history  # pyright: ignore [reportReturnType]

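    # Illustrative usage sketch (not part of the original module): a typical call chain for the
    # custom training loop above, mirroring what `_train_subprocess` does further below.
    # The variables X_train, y_train, X_val and y_val are hypothetical placeholders.
    #
    #   model, _ = self._get_architectures(self._get_optimizer(), self._get_loss(), self._get_metrics())
    #   history = self._fit(
    #       model, X_train, y_train,
    #       validation_data=(X_val, y_val),
    #       batch_size=self.batch_size,
    #       epochs=self.epochs,
    #       callbacks=self._get_callbacks(),
    #       class_weight=self.class_weight,
    #   )
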
    def _get_architectures(
        self, optimizer: Any = None, loss: Any = None, metrics: list[Any] | None = None
    ) -> tuple[Model, Model]:
        """ Get the model architecture and compile it if enough information is provided.

        This method builds and returns the model architecture.
        If optimizer, loss, and (optionally) metrics are provided, the model will be compiled.

        Args:
            optimizer  (Any): The optimizer to use for training
            loss       (Any): The loss function to use for training
            metrics    (list[Any] | None): The metrics to use for evaluation

        Returns:
            tuple[Model, Model]: The final model and the base model
        """
        # Get the base model (use imagenet anyway)
        base_model: Model = self._get_base_model()

        # Add a top layer since the base model doesn't have one
        output_layer: Model = Sequential([
            GlobalAveragePooling2D(),
            Dense(self.num_classes, activation="softmax")
        ])(base_model.output)
        final_model: Model = Model(inputs=base_model.input, outputs=output_layer)

        # If no optimizer is provided, return the uncompiled models
        if optimizer is None:
            return final_model, base_model

        # Load transfer learning weights if provided
        if os.path.exists(self.transfer_learning):
            try:
                final_model.load_weights(self.transfer_learning)
                info(f"Transfer learning weights loaded from '{self.transfer_learning}'")
            except Exception as e:
                warning(f"Error loading transfer learning weights from '{self.transfer_learning}': {e}")

        # Freeze the base model except for the last layers (if unfreeze percentage is less than 100%)
        if self.unfreeze_percentage < 100.0:
            base_model.trainable = False
            last_layers: list[Model] = base_model.layers[-self.fine_tune_last_layers:]
            for layer in last_layers:
                layer.trainable = True
            info(
                f"Fine-tune from layer {max(0, len(base_model.layers) - self.fine_tune_last_layers)} "
                f"to {len(base_model.layers)} ({self.fine_tune_last_layers} layers)"
            )

        # Add XLA specific optimizations for compilation
        compile_options = {}
        if hasattr(tf.config.optimizer, "get_jit") and tf.config.optimizer.get_jit():
            compile_options["steps_per_execution"] = 10  # Batch multiple steps for XLA

        # Compile the model and return it
        final_model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics if metrics is not None else [],
            jit_compile=True,
            **compile_options
        )
        return final_model, base_model

    # Protected methods for training
    def _get_callbacks(self) -> list[Callback]:
        """ Get the callbacks for training. """
        callbacks: list[Callback] = []

        # Add warmup scheduler if enabled
        if self.warmup_epochs > 0:
            warmup_scheduler: WarmupScheduler = WarmupScheduler(
                warmup_epochs=self.warmup_epochs,
                initial_lr=self.initial_warmup_lr,
                target_lr=self.learning_rate
            )
            callbacks.append(warmup_scheduler)

        # Add ReduceLROnPlateau
        callbacks.append(ReduceLROnPlateau(
            monitor="val_loss",
            mode="min",
            factor=self.factor,
            patience=self.reduce_lr_patience,
            min_delta=self.min_delta,
            min_lr=self.min_lr
        ))

        # Add TensorBoard for profiling
        log_dir: str = f"{DataScienceConfig.TENSORBOARD_FOLDER}/{self.run_name}"
        os.makedirs(log_dir, exist_ok=True)
        callbacks.append(TensorBoard(
            log_dir=log_dir,
            histogram_freq=1,       # Log histogram visualizations every epoch
            profile_batch=(10, 20)  # Profile batches 10-20
        ))

        # Add EarlyStopping to prevent overfitting
        callbacks.append(EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=self.early_stop_patience,
            verbose=0
        ))

        return callbacks

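    # Note (added for clarity): the TensorBoard callback above writes its logs and profiling
    # traces under f"{DataScienceConfig.TENSORBOARD_FOLDER}/{self.run_name}", so they can be
    # inspected with the standard TensorBoard CLI, for example:
    #
    #   tensorboard --logdir <DataScienceConfig.TENSORBOARD_FOLDER>
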
    def _get_metrics(self) -> list[Metric]:
        """ Get the metrics for training.

        Returns:
            list: List of metrics to track during training including accuracy, AUC, etc.
        """
        # Fix the F1Score dtype if mixed precision is enabled
        f1score_dtype: tf.DType = tf.float16 if DataScienceConfig.MIXED_PRECISION_POLICY == "mixed_float16" else tf.float32
        f1score: F1Score = F1Score(name="f1_score", average="macro", dtype=f1score_dtype)
        f1score.beta = tf.constant(1.0, dtype=f1score_dtype)  # pyright: ignore [reportAttributeAccessIssue]

        return [
            CategoricalAccuracy(name="categorical_accuracy"),
            AUC(name="auc"),
            f1score,
        ]

    def _get_optimizer(self, learning_rate: float = 0.0, mode: int = 1) -> Optimizer:
        """ Get the optimizer for training.

        Args:
            learning_rate  (float): Learning rate
            mode           (int): Mode to use
        Returns:
            Optimizer: Optimizer
        """
        lr: float = self.learning_rate if learning_rate == 0.0 else learning_rate
        if mode == 0:
            return Adam(lr, self.beta_1, self.beta_2)
        elif mode == 1:
            return AdamW(lr, self.beta_1, self.beta_2)
        else:
            return Lion(lr)

    def _get_loss(self, mode: int = 0) -> Loss:
        """ Get the loss function for training depending on the mode.

        - 0: CategoricalCrossentropy (default)
        - 1: CategoricalFocalCrossentropy
        - 2: Next Generation Loss (with alpha = 2.4092)

        Args:
            mode (int): Mode to use
        Returns:
            Loss: Loss function
        """
        if mode == 0:
            return CategoricalCrossentropy(name="categorical_crossentropy")
        elif mode == 1:
            return CategoricalFocalCrossentropy(name="categorical_focal_crossentropy")
        elif mode == 2:
            return NextGenerationLoss(name="ngl_loss")
        else:
            raise ValueError(f"Invalid mode: {mode}")

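    # Illustrative sketch (not part of the original module): combining the selectors above,
    # e.g. AdamW (optimizer mode 1, the default) with the Next Generation Loss (loss mode 2):
    #
    #   optimizer: Optimizer = self._get_optimizer(mode=1)  # AdamW(lr, beta_1, beta_2)
    #   loss: Loss = self._get_loss(mode=2)                 # NextGenerationLoss("ngl_loss")
    #   final_model, base_model = self._get_architectures(optimizer, loss, self._get_metrics())
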
    def _find_best_learning_rate_subprocess(
        self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0  # type: ignore
    ) -> dict[str, Any] | None:
        """ Helper to run learning rate finder, potentially in a subprocess.

        Args:
            dataset  (Dataset): Dataset to use for training.
            queue    (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
            verbose  (int): Verbosity level.
        Returns:
            dict[str, Any] | None: Return values
        """
        X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()

        # Set random seeds for reproducibility within the process/subprocess
        set_random_seed(DataScienceConfig.SEED)

        # Create LR finder callback
        lr_finder: LearningRateFinder = LearningRateFinder(
            min_lr=self.lr_finder_min_lr,
            max_lr=self.lr_finder_max_lr,
            steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
            epochs=self.lr_finder_epochs,
            update_per_epoch=self.lr_finder_update_per_epoch,
            update_interval=self.lr_finder_update_interval
        )

        # Get compiled model with the optimizer and loss
        final_model, _ = self._get_architectures(self._get_optimizer(), self._get_loss())

        # Create callbacks
        callbacks: list[Callback] = [lr_finder]
        if verbose > 0:
            callbacks.append(ColoredProgressBar("LR Finder", show_lr=True))

        # Run a mini training to find the best learning rate
        self._fit(
            final_model,
            X_train,
            y_train,
            batch_size=self.batch_size,
            epochs=self.lr_finder_epochs,
            callbacks=callbacks,
            class_weight=self.class_weight,
            verbose=0
        )

        # Prepare results
        results: dict[str, Any] = {
            "learning_rates": lr_finder.learning_rates,
            "losses": lr_finder.losses
        }

        # Return values if no queue, otherwise put them in the queue
        if queue is None:
            return results
        else:
            return queue.put(results)

    def _find_best_unfreeze_percentage_subprocess(
        self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0  # type: ignore
    ) -> dict[str, Any] | None:
        """ Helper to run unfreeze percentage finder, potentially in a subprocess.

        Args:
            dataset  (Dataset): Dataset to use for training.
            queue    (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
            verbose  (int): Verbosity level.
        Returns:
            dict[str, Any] | None: Return values
        """
        X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()

        # Set random seeds for reproducibility within the process/subprocess
        set_random_seed(DataScienceConfig.SEED)

        # Get compiled model with the optimizer and loss
        lr: float = self.learning_rate
        optimizer = self._get_optimizer(lr)
        loss_fn = self._get_loss()
        final_model, base_model = self._get_architectures(optimizer, loss_fn)

        # Function to get compiled optimizer
        def get_compiled_optimizer() -> Optimizer:
            optimizer: Optimizer = self._get_optimizer(lr)
            return final_model._get_optimizer(optimizer)  # pyright: ignore [reportPrivateUsage]

        # Create unfreeze finder callback
        unfreeze_finder: ProgressiveUnfreezing = ProgressiveUnfreezing(
            base_model=base_model,
            steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
            epochs=self.unfreeze_finder_epochs,
            reset_weights=True,
            reset_optimizer_function=get_compiled_optimizer,
            update_per_epoch=self.unfreeze_finder_update_per_epoch,
            update_interval=self.unfreeze_finder_update_interval,
            progressive_freeze=True  # Start from 100% unfrozen to 0% unfrozen to prevent biases
        )

        # Create callbacks
        callbacks: list[Callback] = [unfreeze_finder]
        if verbose > 0:
            callbacks.append(ColoredProgressBar("Unfreeze Finder"))

        self._fit(
            final_model,
            X_train,
            y_train,
            batch_size=self.batch_size,
            epochs=self.unfreeze_finder_epochs,
            callbacks=callbacks,
            class_weight=self.class_weight,
            verbose=0
        )

        # Prepare results
        unfreeze_percentages, losses = unfreeze_finder.get_results()
        results: dict[str, Any] = {
            "unfreeze_percentages": unfreeze_percentages,
            "losses": losses
        }

        # Return values if no queue, otherwise put them in the queue
        if queue is None:
            return results
        else:
            return queue.put(results)

    def _train_subprocess(
        self,
        dataset: Dataset,
        checkpoint_path: str,
        temp_dir: TemporaryDirectory[str] | None = None,
        queue: multiprocessing.queues.Queue | None = None,  # type: ignore
        verbose: int = 0
    ) -> dict[str, Any] | None:
        """ Train the model in a subprocess.

        When training too many models in the same process, the OS may kill the process
        because it accumulates too many resources over time. Training each model in a
        separate process avoids this issue.

        Args:
            dataset          (Dataset): Dataset to train on
            checkpoint_path  (str): Path to save the best model checkpoint
            temp_dir         (TemporaryDirectory[str] | None): Temporary directory to save the visualizations
            queue            (multiprocessing.Queue | None): Queue to put the history in
            verbose          (int): Verbosity level
        Returns:
            dict[str, Any] | None: Return values
        """
        to_return: dict[str, Any] = {}
        set_random_seed(DataScienceConfig.SEED)

        # Extract the training and validation data
        X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
        X_val, y_val, _ = dataset.val_data.ungrouped_array()
        X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
        true_classes: NDArray[Any] = Utils.convert_to_class_indices(y_val)

        # Create the checkpoint callback
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        model_checkpoint: ModelCheckpointV2 = ModelCheckpointV2(
            epochs_before_start=self.model_checkpoint_delay,
            filepath=checkpoint_path,
            monitor="val_loss",
            mode="min",
            save_best_only=True,
            save_weights_only=True,
            verbose=0
        )

        # Get the compiled model
        model, _ = self._get_architectures(self._get_optimizer(), self._get_loss(), self._get_metrics())

        # Create the callbacks, add the progress bar if verbose is 1
        callbacks = [model_checkpoint, *self._get_callbacks()]
        if verbose > 0:
            callbacks.append(ColoredProgressBar("Training", show_lr=True))

        # Train the model
        history: History = self._fit(
            model,
            X_train,
            y_train,
            validation_data=(X_val, y_val),
            batch_size=self.batch_size,
            epochs=self.epochs,
            callbacks=callbacks,
            class_weight=self.class_weight,
            verbose=0
        )

        # Load the best model from the checkpoint file and remove it
        debug(f"Loading best model from '{checkpoint_path}'")
        model.load_weights(checkpoint_path)
        debug(f"Best model loaded from '{checkpoint_path}', deleting it...")
        os.remove(checkpoint_path)

        # Evaluate the model
        to_return["history"] = history.history
        to_return["eval_results"] = model.evaluate(X_test, y_test, return_dict=True, verbose=0)
        to_return["predictions"] = model.predict(X_test, verbose=0)
        to_return["true_classes"] = true_classes
        to_return["training_predictions"] = model.predict(X_train, verbose=0)
        to_return["training_true_classes"] = Utils.convert_to_class_indices(y_train)

        # --- Visualization Generation ---
        if temp_dir is not None:
            test_images: list[NDArray[Any]] = list(X_test)

            # Prepare the arguments for the visualizations
            viz_args_list: list[tuple[NDArray[Any], Any, tuple[str, ...], str]] = [
                (test_images[i], true_classes[i], test_filepaths[i], "test_folds")
                for i in range(len(test_images))
            ]

            # Generate visualizations in the provided temporary directory
            for img_viz, label_idx, files, data_type in viz_args_list:

                # Extract the base name of the file/group
                if dataset.grouping_strategy == GroupingStrategy.NONE:
                    base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
                else:
                    base_name: str = os.path.basename(os.path.dirname(files[0]))

                # Generate all visualizations for the image
                all_visualizations_for_image(
                    model=model,  # Use the trained model from this subprocess
                    folder_path=temp_dir.name,
                    img=img_viz,
                    base_name=base_name,
                    class_idx=label_idx,
                    class_name=dataset.labels[label_idx],
                    files=files,
                    data_type=data_type,
                )

        # Return values if no queue, otherwise put them in the queue
        if queue is None:
            to_return["model"] = model  # Add the trained model to the return values if not in a subprocess
            return to_return
        else:
            return queue.put(to_return)

    # Predict method
    def class_predict(self, X_test: Iterable[NDArray[Any]] | tf.data.Dataset) -> Iterable[NDArray[Any]]:
        """ Predict the class for the given input data.

        Args:
            X_test (Iterable[NDArray[Any]] | tf.data.Dataset): List of inputs to predict (e.g. a batch of images)
        Returns:
            Iterable[NDArray[Any]]: A batch of predictions (model.predict())
        """
        # Create a tf.data.Dataset to avoid retracing
        if isinstance(X_test, tf.data.Dataset):
            dataset: tf.data.Dataset = X_test
            was_dataset: bool = True
        else:
            dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(32).prefetch(tf.data.AUTOTUNE)
            was_dataset: bool = False

        # Create an optimized prediction function
        @tf.function(jit_compile=True)
        def optimized_predict(x_batch: tf.Tensor) -> tf.Tensor:
            return self.final_model(x_batch, training=False)

        # For each batch, predict the class
        model_preds: list[NDArray[Any]] = []
        for batch in dataset:
            pred: tf.Tensor = optimized_predict(batch)
            model_preds.append(pred.numpy())

        # Clear RAM
        if not was_dataset:
            del dataset
            gc.collect()

        # Return the predictions
        return np.concatenate(model_preds) if model_preds else np.array([])

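    # Illustrative note (not part of the original module): the returned array contains softmax
    # probabilities, so predicted class indices can be recovered with a simple argmax, e.g.:
    #
    #   probabilities = self.class_predict(X_test)
    #   predicted_classes = np.argmax(probabilities, axis=1)
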
    # Protected methods for evaluation
    @measure_time(progress)
    def _log_final_model(self) -> None:
        """ Log the best model (and its weights). """
        with Muffle(mute_stderr=True):
            mlflow.keras.log_model(self.final_model, "best_model")  # pyright: ignore [reportPrivateImportUsage]
            mlflow.set_tag(key="has_saved_model", value="True")

        # Get the weights path and create the directory if it doesn't exist
        weights_path: str = mlflow_utils.get_weights_path()
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)

        # Save the best model's weights
        self.final_model.save_weights(weights_path)

    def class_evaluate(
        self, dataset: Dataset, metrics_names: tuple[str, ...] = (), save_model: bool = False, verbose: int = 0
    ) -> bool:
        """ Evaluate the model using the given predictions and labels.

        Args:
            dataset        (Dataset): Dataset containing the training and testing data
            metrics_names  (tuple[str, ...]): Metrics to plot (defaults to all metrics)
            save_model     (bool): Whether to save the best model
            verbose        (int): Level of verbosity
        Returns:
            bool: True if evaluation was successful
        """
        # First perform standard evaluation from parent class
        result: bool = super().class_evaluate(dataset, metrics_names, save_model, verbose)
        if not DataScienceConfig.DO_SALIENCY_AND_GRADCAM:
            return result

        # Get test and train data
        X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
        test_images: list[NDArray[Any]] = list(X_test)
        test_labels: list[int] = Utils.convert_to_class_indices(y_test).tolist()
        X_train, y_train, train_filepaths = dataset.training_data.remove_augmented_files().ungrouped_array()
        train_images: list[NDArray[Any]] = list(X_train)
        train_labels: list[int] = Utils.convert_to_class_indices(y_train).tolist()

        # Process test images (up to 20)
        test_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
            (test_images[i], test_labels[i], test_filepaths[i], "test")
            for i in range(min(20, len(test_images)))
        ]

        # Process train images (up to 10)
        train_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
            (train_images[i], train_labels[i], train_filepaths[i], "train")
            for i in range(min(10, len(train_images)))
        ]

        # Combine both lists
        all_args_list = test_args_list + train_args_list

        # Create the description
        desc: str = ""
        if verbose > 0:
            desc = f"Generating visualizations for {len(test_args_list)} test and {len(train_args_list)} train images"

        # For each image, generate all visualizations, then log them to MLFlow
        with TemporaryDirectory() as temp_dir:
            for img, label, files, data_type in colored_for_loop(all_args_list, desc=desc):

                # Extract the base name of the file
                if dataset.grouping_strategy == GroupingStrategy.NONE:
                    base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
                else:
                    base_name: str = os.path.basename(os.path.dirname(files[0]))

                # Generate all visualizations for the image
                all_visualizations_for_image(
                    model=self.final_model,
                    folder_path=temp_dir,
                    img=img,
                    base_name=base_name,
                    class_idx=label,
                    class_name=dataset.labels[label],
                    files=files,
                    data_type=data_type,
                )

            # Log the visualizations
            mlflow.log_artifacts(temp_dir)

        return result