""" Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """

# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportMissingTypeStubs=false

# Imports
import os
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.models import Model
from matplotlib.colors import Colormap
from numpy.typing import NDArray
from PIL import Image

from ....decorators import handle_error
from ....print import error, warning
from ...config.get import DataScienceConfig


@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def make_gradcam_heatmap(
    model: Model,
    img: NDArray[Any],
    class_idx: int = 0,
    last_conv_layer_name: str = "",
    one_per_channel: bool = False
) -> list[NDArray[Any]]:
    """ Generate a Grad-CAM heatmap for a given image and model.

    Args:
        model                (Model):        The pre-trained TensorFlow model
        img                  (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
        class_idx            (int):          The class index to use for the Grad-CAM heatmap
        last_conv_layer_name (str):          Name of the last convolutional layer in the model
            (optional, will try to find it automatically)
        one_per_channel      (bool):         If True, return one heatmap per channel

    Returns:
        list[NDArray[Any]]: The Grad-CAM heatmap(s)

    Examples:
        .. code-block:: python

            > model: Model = ...
            > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
            > last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
            > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
            > Image.fromarray(heatmap).save("heatmap.png")
    """
    # Assertions
    assert isinstance(model, Model), "Model must be a valid Keras model"
    assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"

    # If img is not a batch of 1, convert it to a batch of 1
    if img.ndim == 3:
        img = np.expand_dims(img, axis=0)
    assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"

    # If last_conv_layer_name is not provided, find it automatically
    if not last_conv_layer_name:
        last_conv_layer_name = find_last_conv_layer(model)
        if not last_conv_layer_name:
            error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")

    # Get the last convolutional layer
    last_layer: Model | list[Model] = model.get_layer(last_conv_layer_name)
    if isinstance(last_layer, list):
        last_layer = last_layer[0]

    # Create a model that outputs both the last conv layer's activations and predictions
    grad_model: Model = Model([model.inputs], [last_layer.output, model.output])

    # Record operations for automatic differentiation using GradientTape
    with tf.GradientTape() as tape:

        # Forward pass: get the activations of the last conv layer and the predictions
        conv_outputs: tf.Tensor
        predictions: tf.Tensor
        conv_outputs, predictions = grad_model(img)
        # print("conv_outputs shape:", conv_outputs.shape)
        # print("predictions shape:", predictions.shape)

        # Compute the loss with respect to the positive class
        loss: tf.Tensor = predictions[:, class_idx]

    # Compute the gradient of the loss with respect to the activations of the last conv layer
    grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
    # print("grads shape:", grads.shape)

    # Initialize heatmaps list
    heatmaps: list[NDArray[Any]] = []

    if one_per_channel:
        # Return one heatmap per channel
        # Get the number of filters in the last conv layer
        num_filters: int = int(tf.shape(conv_outputs)[-1])
        for i in range(num_filters):

            # Compute the mean intensity of the gradients for this filter
            pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])

            # Get the activation map for this filter
            activation_map: tf.Tensor = conv_outputs[:, :, :, i]

            # Weight the activation map by the gradient
            weighted_map: tf.Tensor = activation_map * pooled_grad

            # Ensure that the heatmap has non-negative values and normalize it
            heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)

            # Remove possible dims of size 1 + convert the heatmap to a numpy array
            heatmaps.append(tf.squeeze(heatmap).numpy())
    else:
        # Compute the mean intensity of the gradients for each filter in the last conv layer
        pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))

        # Multiply each activation map in the last conv layer by the corresponding
        # gradient average, which emphasizes the parts of the image that are important
        heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
        # print("heatmap shape (before squeeze):", heatmap.shape)

        # Ensure that the heatmap has non-negative values and normalize it
        heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)

        # Remove possible dims of size 1 + convert the heatmap to a numpy array
        heatmaps.append(tf.squeeze(heatmap).numpy())

    return heatmaps
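
# A minimal usage sketch for make_gradcam_heatmap. The model and image paths
# below are hypothetical placeholders, and the float32 [0, 1] scaling is an
# assumption that must match whatever preprocessing the model was trained with.
def _example_gradcam_usage() -> None:
    from keras.models import load_model
    model: Model = load_model("path/to/model.keras")  # hypothetical path
    img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"), dtype=np.float32) / 255.0

    # Defaults: class_idx=0 and automatic last-conv-layer lookup
    heatmap: NDArray[Any] = make_gradcam_heatmap(model, img)[0]

    # Blend the heatmap over the original image and save the result
    overlay: NDArray[Any] = create_visualization_overlay(img, heatmap)
    Image.fromarray(overlay).save("gradcam_overlay.png")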

@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def make_saliency_map(
    model: Model,
    img: NDArray[Any],
    class_idx: int = 0,
    one_per_channel: bool = False
) -> list[NDArray[Any]]:
    """ Generate a saliency map for a given image and model.

    A saliency map shows which pixels in the input image have the greatest
    influence on the model's prediction.

    Args:
        model           (Model):        The pre-trained TensorFlow model
        img             (NDArray[Any]): The preprocessed image array (batch of 1)
        class_idx       (int):          The class index to use for the saliency map
        one_per_channel (bool):         If True, return one saliency map per channel

    Returns:
        list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]

    Examples:
        .. code-block:: python

            > model: Model = ...
            > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
            > saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
    """
    # Assertions
    assert isinstance(model, Model), "Model must be a valid Keras model"
    assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"

    # If img is not a batch of 1, convert it to a batch of 1
    if img.ndim == 3:
        img = np.expand_dims(img, axis=0)
    assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"

    # Convert the numpy array to a tensor
    img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)

    # Record operations for automatic differentiation
    with tf.GradientTape() as tape:
        tape.watch(img_tensor)
        predictions = model(img_tensor)

        # Use class index for positive class prediction
        loss = predictions[:, class_idx]

    # Compute gradients of loss with respect to input image
    try:
        grads = tape.gradient(loss, img_tensor)

        # Cast to tensor to satisfy type checker
        grads = tf.convert_to_tensor(grads)
    except Exception as e:
        warning(f"Error computing gradients: {e}")
        return []

    # Initialize saliency maps list
    saliency_maps: list[NDArray[Any]] = []

    if one_per_channel:
        # Return one saliency map per channel
        # Get the last dimension size as an integer
        num_channels: int = int(tf.shape(grads)[-1])
        for i in range(num_channels):

            # Extract gradients for this channel
            channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])

            # Apply smoothing to make the map more visually coherent
            channel_grads = tf.nn.avg_pool2d(
                tf.expand_dims(channel_grads, -1), ksize=3, strides=1, padding='SAME'
            )
            channel_grads = tf.squeeze(channel_grads)

            # Normalize the saliency map for this channel
            channel_max = tf.math.reduce_max(channel_grads)
            if channel_max > 0:
                channel_saliency: tf.Tensor = channel_grads / channel_max
            else:
                channel_saliency: tf.Tensor = channel_grads
            saliency_maps.append(channel_saliency.numpy().squeeze())
    else:
        # Take absolute value of gradients
        abs_grads = tf.abs(grads)

        # Sum across color channels for RGB images
        saliency = tf.reduce_sum(abs_grads, axis=-1)

        # Apply smoothing to make the map more visually coherent
        saliency = tf.nn.avg_pool2d(
            tf.expand_dims(saliency, -1), ksize=3, strides=1, padding='SAME'
        )
        saliency = tf.squeeze(saliency)

        # Normalize saliency map
        saliency_max = tf.math.reduce_max(saliency)
        if saliency_max > 0:
            saliency = saliency / saliency_max
        saliency_maps.append(saliency.numpy().squeeze())

    return saliency_maps
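
# A hedged sketch of the two saliency modes. The random input below is a
# hypothetical stand-in for a real preprocessed image and assumes a model that
# accepts 224x224 RGB inputs.
def _example_saliency_modes(model: Model) -> None:
    img: NDArray[Any] = np.random.rand(224, 224, 3).astype(np.float32)

    # Combined mode: gradients summed across channels -> a single map
    combined: list[NDArray[Any]] = make_saliency_map(model, img)

    # Per-channel mode: one map per input channel (3 for an RGB input)
    per_channel: list[NDArray[Any]] = make_saliency_map(model, img, one_per_channel=True)
    print(len(combined), len(per_channel))  # 1 and 3 for RGB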

@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def find_last_conv_layer(model: Model) -> str:
    """ Find the name of the last convolutional layer in a model.

    Args:
        model (Model): The TensorFlow model to analyze

    Returns:
        str: Name of the last convolutional layer if found, otherwise an empty string

    Examples:
        .. code-block:: python

            > model: Model = ...
            > last_conv_layer: str = Utils.find_last_conv_layer(model)
            > print(last_conv_layer)
            'conv5_block3_out'
    """
    # Assertions
    assert isinstance(model, Model), "Model must be a valid Keras model"

    # Find the last convolutional layer by iterating through the layers in reverse
    last_conv_layer_name: str = ""
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            last_conv_layer_name = layer.name
            break

    # Return the name of the last convolutional layer
    return last_conv_layer_name
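
# Sketch of the fallback pattern the module itself relies on: automatic
# detection first, then an explicit name. "conv5_block3_out" is a ResNet50
# layer name used here purely as an illustrative fallback; it is not
# guaranteed to exist in an arbitrary model.
def _example_layer_lookup(model: Model) -> str:
    return find_last_conv_layer(model) or "conv5_block3_out"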

@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def create_visualization_overlay(
    original_img: NDArray[Any] | Image.Image,
    heatmap: NDArray[Any],
    alpha: float = 0.4,
    colormap: str = "jet"
) -> NDArray[Any]:
    """ Create an overlay of the original image with a heatmap visualization.

    Args:
        original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
        heatmap      (NDArray[Any]):               The heatmap to overlay (normalized to 0-1)
        alpha        (float):                      Transparency level of the overlay (0-1)
        colormap     (str):                        Matplotlib colormap to use for the heatmap

    Returns:
        NDArray[Any]: The overlaid image

    Examples:
        .. code-block:: python

            > original: NDArray[Any] | Image.Image = ...
            > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
            > overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
            > Image.fromarray(overlay).save("overlay.png")
    """
    # Ensure heatmap is normalized to 0-1
    if heatmap.max() > 0:
        heatmap = heatmap / heatmap.max()

    ## Apply colormap to heatmap
    # Get the colormap (plt.get_cmap works across matplotlib versions, unlike plt.cm.get_cmap which was removed in 3.9)
    cmap: Colormap = plt.get_cmap(colormap)

    # Ensure heatmap is at least 2D to prevent a tuple return from cmap
    if np.isscalar(heatmap) or heatmap.ndim == 0:
        heatmap = np.array([[heatmap]])

    # Apply colormap and ensure it's a numpy array
    heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))

    # Remove alpha channel if present
    if heatmap_rgb.ndim >= 3:
        heatmap_rgb = heatmap_rgb[:, :, :3]
    heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)

    # Convert original image to PIL Image if it's a numpy array
    if isinstance(original_img, np.ndarray):

        # Scale to 0-255 range if the image is normalized (0-1)
        if original_img.max() <= 1:
            img: NDArray[Any] = (original_img * 255).astype(np.uint8)
        else:
            img: NDArray[Any] = original_img.astype(np.uint8)
        original_pil: Image.Image = Image.fromarray(img)
    else:
        original_pil: Image.Image = original_img

    # Convert heatmap to PIL Image
    heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)

    # Resize heatmap to match original image size if needed
    if heatmap_pil.size != original_pil.size:
        heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)

    # Ensure both images have the same mode before blending
    if original_pil.mode != heatmap_pil.mode:
        if original_pil.mode == "RGB" and heatmap_pil.mode == "RGBA":
            heatmap_pil = heatmap_pil.convert("RGB")
        else:
            original_pil = original_pil.convert(heatmap_pil.mode)

    # Blend images
    overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)
    return np.array(overlaid_img)
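
# Small sketch showing how alpha and colormap change the blend; the output
# file names are illustrative only.
def _example_overlay_variants(original: NDArray[Any], heatmap: NDArray[Any]) -> None:
    for alpha in (0.3, 0.6):
        for cmap_name in ("jet", "viridis"):
            overlay: NDArray[Any] = create_visualization_overlay(original, heatmap, alpha=alpha, colormap=cmap_name)
            Image.fromarray(overlay).save(f"overlay_{cmap_name}_a{alpha}.png")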

@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def all_visualizations_for_image(
    model: Model,
    folder_path: str,
    img: NDArray[Any],
    base_name: str,
    class_idx: int,
    class_name: str,
    files: tuple[str, ...],
    data_type: str
) -> None:
    """ Process a single image to generate visualizations and determine prediction correctness.

    Args:
        model       (Model):           The pre-trained TensorFlow model
        folder_path (str):             The path to the folder where the visualizations will be saved
        img         (NDArray[Any]):    The preprocessed image array (batch of 1)
        base_name   (str):             The base name of the image
        class_idx   (int):             The **true** class index for the image
        class_name  (str):             The **true** class name for organizing folders
        files       (tuple[str, ...]): List of original file paths for the subject
        data_type   (str):             Type of data ("test" or "train")
    """
    # Set directories based on data type
    saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
    grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"

    # Load all original images for the subject
    original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]

    # Pick the largest image size (lexicographic max of the (width, height) tuples)
    max_shape: tuple[int, int] = max(i.size for i in original_imgs)

    # Resize all images to the chosen dimensions
    resized_imgs: list[NDArray[Any]] = []
    for original_img in original_imgs:
        if original_img.size != max_shape:
            original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
        resized_imgs.append(np.array(original_img))

    # Take the mean of the resized images as the background image
    subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)

    # Perform prediction to determine correctness
    # Ensure img is in batch format (add a dimension if needed)
    img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
    predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0]))  # pyright: ignore [reportArgumentType]
    prediction_correct: bool = (predicted_class_idx == class_idx)
    status_suffix: str = "_correct" if prediction_correct else "_missed"

    # Create the Grad-CAM visualization
    heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
    overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)

    # Save the overlay visualization with the status suffix
    grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
    os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
    Image.fromarray(overlay).save(grad_cam_path)

    # Create the saliency map visualization
    saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
    overlay = create_visualization_overlay(subject_image, saliency_map)

    # Save the overlay visualization with the status suffix
    saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
    os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
    Image.fromarray(overlay).save(saliency_path)
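
# A hedged end-to-end sketch of calling all_visualizations_for_image: every
# path, class name, and file below is a hypothetical placeholder, and the
# 224x224 input size is an assumption about the model. The listed files must
# exist on disk, since they are reopened to build the background image.
def _example_all_visualizations(model: Model) -> None:
    img: NDArray[Any] = np.random.rand(224, 224, 3).astype(np.float32)
    all_visualizations_for_image(
        model=model,
        folder_path="output/visualizations",        # hypothetical output root
        img=img,
        base_name="subject_001",
        class_idx=1,
        class_name="positive",                      # hypothetical class name
        files=("data/subject_001/slice_0.png",),    # hypothetical source files
        data_type="test",
    )
    # Writes e.g. output/visualizations/test/grad_cams/positive/subject_001_correct_gradcam.png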