stouputils.data_science.models.keras_utils.visualizations module

Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability.

make_gradcam_heatmap(model: Model, img: ndarray[Any, dtype[Any]], class_idx: int = 0, last_conv_layer_name: str = '', one_per_channel: bool = False) → list[ndarray[Any, dtype[Any]]]

Generate a Grad-CAM heatmap for a given image and model.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • img (NDArray[Any]) – The preprocessed image array, either a single image (ndim=3) or a batch containing one image (ndim=4, shape=(1, ?, ?, ?))

  • class_idx (int) – The class index to use for the Grad-CAM heatmap

  • last_conv_layer_name (str) – Name of the last convolutional layer in the model (optional; if left empty, the layer is located automatically)

  • one_per_channel (bool) – If True, return one heatmap per channel

Returns:

The Grad-CAM heatmap(s)

Return type:

list[NDArray[Any]]

Examples

> model: Model = ...
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
> last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer_name=last_conv_layer)[0]
> Image.fromarray((heatmap * 255).astype(np.uint8)).save("heatmap.png")
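The core Grad-CAM arithmetic can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the stouputils implementation: `gradcam_from_activations` is a hypothetical helper, and it assumes the activations of the last convolutional layer (shape (H, W, K), batch axis removed) and the gradients of the target class score with respect to those activations have already been computed. In Keras these are commonly obtained with `tf.GradientTape` on a model that outputs both the convolutional feature maps and the predictions. Only the single-heatmap case is shown; `one_per_channel` is not modelled.

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray


def gradcam_from_activations(conv_out: NDArray[Any], grads: NDArray[Any]) -> NDArray[Any]:
    """Hypothetical helper: combine conv activations and gradients into a Grad-CAM map in [0, 1].

    conv_out: activations A of the last conv layer, shape (H, W, K), batch axis removed.
    grads:    d(class score) / dA, same shape as conv_out (assumed precomputed).
    """
    if conv_out.ndim != 3 or conv_out.shape != grads.shape:
        raise ValueError("expected conv_out and grads with the same (H, W, K) shape")
    # One importance weight per feature map: average the gradients over H and W.
    weights = grads.mean(axis=(0, 1))  # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep only positive evidence.
    cam = np.maximum(conv_out @ weights, 0.0)  # shape (H, W)
    # Scale to [0, 1]; an all-zero map is returned unchanged.
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

The ReLU keeps only regions that increase the class score, which is the usual Grad-CAM formulation. The resulting map has the spatial resolution of the convolutional layer, so it generally needs to be upscaled before being overlaid on the input image.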

make_saliency_map(model: Model, img: ndarray[Any, dtype[Any]], class_idx: int = 0, one_per_channel: bool = False) → list[ndarray[Any, dtype[Any]]]

Generate a saliency map for a given image and model.

A saliency map shows which pixels in the input image have the greatest influence on the model’s prediction.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • img (NDArray[Any]) – The preprocessed image array (batch of 1)

  • class_idx (int) – The class index to use for the saliency map

  • one_per_channel (bool) – If True, return one saliency map per channel

Returns:

The saliency map(s), normalized to the range [0, 1]

Return type:

list[NDArray[Any]]

Examples

> model: Model = ...
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
> saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
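A vanilla gradient saliency map can be derived from the gradient of the class score with respect to the input pixels. The sketch below covers only the post-processing step and reflects an assumption about the approach, not the library's code: `saliency_from_gradients` is a hypothetical helper that expects an (H, W, C) gradient array with the batch axis removed, takes absolute values, reduces over channels with a maximum (or keeps each channel separately when `one_per_channel=True`), and min-max scales each map to [0, 1].

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray


def saliency_from_gradients(grads: NDArray[Any], one_per_channel: bool = False) -> list[NDArray[Any]]:
    """Hypothetical helper: turn input gradients of shape (H, W, C) into saliency map(s) in [0, 1]."""
    # Influence of each pixel value on the class score, regardless of sign.
    magnitude = np.abs(grads)
    if one_per_channel:
        maps = [magnitude[..., c] for c in range(magnitude.shape[-1])]
    else:
        maps = [magnitude.max(axis=-1)]  # strongest channel at each pixel
    normalized: list[NDArray[Any]] = []
    for m in maps:
        lo, hi = m.min(), m.max()
        # Min-max scaling; a constant map becomes all zeros.
        normalized.append((m - lo) / (hi - lo) if hi > lo else np.zeros_like(m, dtype=float))
    return normalized
```

In Keras the gradient would typically come from a `tf.GradientTape` that watches the input tensor while computing `predictions[:, class_idx]`.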

find_last_conv_layer(model: Model) → str

Find the name of the last convolutional layer in a model.

Parameters:

model (Model) – The TensorFlow model to analyze

Returns:

Name of the last convolutional layer if found, otherwise an empty string

Return type:

str

Examples

> model: Model = ...
> last_conv_layer: str = Utils.find_last_conv_layer(model)
> print(last_conv_layer)
'conv5_block3_out'
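One common heuristic for locating the last convolutional layer is to walk the layers backwards and return the first one whose output is 4-D (batch, height, width, channels). The sketch below assumes that heuristic; it is not necessarily how `find_last_conv_layer` works, and `find_last_4d_layer` is a hypothetical name. It relies only on each layer exposing `.name` and `.output.shape`, so toy stand-ins replace a real Keras model.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def find_last_4d_layer(layers: Iterable[Any]) -> str:
    """Hypothetical helper: name of the last layer with a 4-D output, or '' if there is none."""
    for layer in reversed(list(layers)):
        shape = getattr(getattr(layer, "output", None), "shape", None)
        if shape is not None and len(shape) == 4:  # (batch, H, W, channels)
            return layer.name
    return ""


# Toy stand-ins for Keras layers (only .name and .output.shape are used here).
fake_layers = [
    SimpleNamespace(name="conv1", output=SimpleNamespace(shape=(None, 32, 32, 16))),
    SimpleNamespace(name="conv5_block3_out", output=SimpleNamespace(shape=(None, 7, 7, 2048))),
    SimpleNamespace(name="avg_pool", output=SimpleNamespace(shape=(None, 2048))),
    SimpleNamespace(name="predictions", output=SimpleNamespace(shape=(None, 1000))),
]
print(find_last_4d_layer(fake_layers))  # conv5_block3_out
```

With a real model, `model.layers` would be passed instead of the toy list. Checking `isinstance(layer, keras.layers.Conv2D)` is a stricter alternative, but it would skip blocks such as ResNet50's `conv5_block3_out`, which is an activation layer placed after the convolutions.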

create_visualization_overlay(original_img: ndarray[Any, dtype[Any]] | Image, heatmap: ndarray[Any, dtype[Any]], alpha: float = 0.4, colormap: str = 'jet') → ndarray[Any, dtype[Any]]

Create an overlay of the original image with a heatmap visualization.

Parameters:
  • original_img (NDArray[Any] | Image.Image) – The original image array or PIL Image

  • heatmap (NDArray[Any]) – The heatmap to overlay (normalized to 0-1)

  • alpha (float) – Transparency level of the overlay (0-1)

  • colormap (str) – Name of the Matplotlib colormap used to colorize the heatmap

Returns:

The overlaid image

Return type:

NDArray[Any]

Examples

> original: NDArray[Any] | Image.Image = ...
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
> overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
> Image.fromarray(overlay).save("overlay.png")
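The blending step can be sketched with NumPy alone. This assumes `alpha` is the heatmap's weight in a linear blend, `(1 - alpha) * image + alpha * heatmap`, and that the heatmap has already been converted to RGB floats in [0, 1] by a colormap; `blend_heatmap` is a hypothetical helper, not part of stouputils. Nearest-neighbour indexing stands in for whatever resizing the library actually uses.

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray


def blend_heatmap(original: NDArray[Any], colored_heatmap: NDArray[Any], alpha: float = 0.4) -> NDArray[np.uint8]:
    """Hypothetical helper: alpha-blend an RGB heatmap (floats in [0, 1]) onto a uint8 RGB image."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    h, w = original.shape[:2]
    hh, hw = colored_heatmap.shape[:2]
    # Nearest-neighbour upscaling of the (usually coarser) heatmap to the image size.
    rows = np.arange(h) * hh // h
    cols = np.arange(w) * hw // w
    resized = colored_heatmap[rows][:, cols]  # shape (h, w, 3)
    # Linear blend: alpha is the heatmap's share of each output pixel (assumption).
    blended = (1.0 - alpha) * original.astype(np.float64) + alpha * resized * 255.0
    return np.clip(blended, 0, 255).astype(np.uint8)
```

To colorize a [0, 1] heatmap with Matplotlib, `matplotlib.colormaps["jet"](heatmap)[..., :3]` returns an (h, w, 3) float array suitable for `colored_heatmap`; a PIL image can be converted first with `np.asarray(img.convert("RGB"))`.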

all_visualizations_for_image(model: Model, folder_path: str, img: ndarray[Any, dtype[Any]], base_name: str, class_idx: int, class_name: str, files: tuple[str, ...], data_type: str) → None

Process a single image to generate visualizations and determine prediction correctness.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • folder_path (str) – The path to the folder where the visualizations will be saved

  • img (NDArray[Any]) – The preprocessed image array (batch of 1)

  • base_name (str) – The base name of the image

  • class_idx (int) – The true class index for the image

  • class_name (str) – The true class name for organizing folders

  • files (tuple[str, ...]) – Tuple of the original file paths for the subject

  • data_type (str) – Type of data (“test” or “train”)
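Deciding whether a prediction is correct, and routing the outputs accordingly, can be illustrated as follows. The output layout used here (`folder_path/data_type/class_name/correct|incorrect/base_name`) is invented for the example and may not match what `all_visualizations_for_image` writes; `plan_visualization_output` and its `probs` argument (the model's class probabilities for the image) are hypothetical.

```python
import os
from typing import Any

import numpy as np
from numpy.typing import NDArray


def plan_visualization_output(
    folder_path: str,
    data_type: str,
    class_name: str,
    base_name: str,
    probs: NDArray[Any],
    class_idx: int,
) -> tuple[bool, str]:
    """Hypothetical helper: check the prediction and build an (illustrative) output path."""
    # The predicted class is the one with the highest probability.
    predicted_idx = int(np.argmax(probs))
    correct = predicted_idx == class_idx
    verdict = "correct" if correct else "incorrect"
    # Invented layout: group by data split, true class, then prediction outcome.
    return correct, os.path.join(folder_path, data_type, class_name, verdict, base_name)
```

Grouping by true class first, then by outcome, keeps misclassified examples of each class side by side, which is convenient when inspecting heatmaps for failure patterns.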