stouputils.data_science.models.keras_utils.visualizations module

Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability.

make_gradcam_heatmap(
model: Model,
img: ndarray[tuple[int, ...], dtype[Any]],
class_idx: int = 0,
last_conv_layer_name: str = '',
one_per_channel: bool = False,
) → list[ndarray[tuple[int, ...], dtype[Any]]]

Generate a Grad-CAM heatmap for a given image and model.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • img (NDArray[Any]) – The preprocessed image array: a single image (ndim=3) or a batch of one image (ndim=4, shape=(1, ?, ?, ?))

  • class_idx (int) – The class index to use for the Grad-CAM heatmap

  • last_conv_layer_name (str) – Name of the last convolutional layer in the model (optional; if left empty, the function tries to find it automatically)

  • one_per_channel (bool) – If True, return one heatmap per channel

Returns:

The Grad-CAM heatmap(s)

Return type:

list[NDArray[Any]]

Examples

> model: Model = ...
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
> last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer_name=last_conv_layer)[0]
> Image.fromarray((heatmap * 255).astype(np.uint8)).save("heatmap.png")
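Grad-CAM weights each feature map of the chosen convolutional layer by the spatial average of the class score's gradient with respect to that map, sums the weighted maps, and keeps only the positive part. The sketch below shows that final combination step in NumPy, assuming the activations and gradients (both of shape (H, W, K)) have already been computed, for example with a `tf.GradientTape` pass over the model. `gradcam_from_activations` is a hypothetical helper for illustration, not part of stouputils, and the library's own implementation may differ in details such as normalization.

```python
import numpy as np
from numpy.typing import NDArray


def gradcam_from_activations(
    activations: NDArray[np.floating],
    gradients: NDArray[np.floating],
) -> NDArray[np.floating]:
    """Combine conv-layer activations and their gradients into a heatmap.

    Both inputs have shape (H, W, K): the K feature maps of the chosen layer
    and the gradient of the class score with respect to those maps.
    """
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    # One importance weight per feature map: global-average-pool the gradients
    weights = gradients.mean(axis=(0, 1))  # shape (K,)
    # Weighted sum over the K maps, then ReLU to keep positive evidence only
    cam = np.maximum(activations @ weights, 0.0)  # shape (H, W)
    # Scale to [0, 1]; an all-zero map is returned unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Usage with random stand-in data (a real run would take these from the model)
rng = np.random.default_rng(0)
acts = rng.random((7, 7, 512))
grads = rng.standard_normal((7, 7, 512))
heatmap = gradcam_from_activations(acts, grads)
print(heatmap.shape, float(heatmap.min()), float(heatmap.max()))
```

The ReLU step is what distinguishes Grad-CAM from a plain weighted sum: regions that only lower the class score are set to zero instead of showing up as negative "hot spots".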
make_saliency_map(
model: Model,
img: ndarray[tuple[int, ...], dtype[Any]],
class_idx: int = 0,
one_per_channel: bool = False,
) → list[ndarray[tuple[int, ...], dtype[Any]]]

Generate a saliency map for a given image and model.

A saliency map shows which pixels in the input image have the greatest influence on the model’s prediction.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • img (NDArray[Any]) – The preprocessed image array (batch of 1)

  • class_idx (int) – The class index to use for the saliency map

  • one_per_channel (bool) – If True, return one saliency map per channel

Returns:

The saliency map(s) normalized to range [0,1]

Return type:

list[NDArray[Any]]

Examples

> model: Model = ...
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
> saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
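A vanilla gradient saliency map uses the magnitude of the gradient of the class score with respect to each input pixel. The following sketch covers only the post-processing, assuming the input gradients (shape (H, W, C)) are already available: it takes absolute values, collapses channels with a per-pixel maximum (or keeps them separate, mirroring `one_per_channel`), and min-max scales each map to [0, 1]. `saliency_from_gradients` is a hypothetical name, and the max-over-channels reduction is an assumption about how the library combines channels.

```python
import numpy as np
from numpy.typing import NDArray


def _minmax(m: NDArray[np.floating]) -> NDArray[np.floating]:
    """Min-max scale to [0, 1]; a constant map becomes all zeros."""
    span = m.max() - m.min()
    return (m - m.min()) / span if span > 0 else np.zeros_like(m)


def saliency_from_gradients(
    grads: NDArray[np.floating],
    one_per_channel: bool = False,
) -> list[NDArray[np.floating]]:
    """Turn input gradients of shape (H, W, C) into saliency map(s) in [0, 1]."""
    magnitude = np.abs(grads)  # sign is discarded: only influence strength matters
    if one_per_channel:
        return [_minmax(magnitude[..., c]) for c in range(magnitude.shape[-1])]
    # Collapse channels by keeping the strongest response at each pixel
    return [_minmax(magnitude.max(axis=-1))]


# Usage with random stand-in gradients for a 224x224 RGB input
grads = np.random.default_rng(0).standard_normal((224, 224, 3))
saliency = saliency_from_gradients(grads)[0]
per_channel = saliency_from_gradients(grads, one_per_channel=True)
print(saliency.shape, len(per_channel))
```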
find_last_conv_layer(
model: Model,
) → str

Find the name of the last convolutional layer in a model.

Parameters:

model (Model) – The TensorFlow model to analyze

Returns:

Name of the last convolutional layer if found, otherwise an empty string

Return type:

str

Examples

> model: Model = ...
> last_conv_layer: str = Utils.find_last_conv_layer(model)
> print(last_conv_layer)
conv5_block3_out
create_visualization_overlay(
original_img: ndarray[tuple[int, ...], dtype[Any]] | Image,
heatmap: ndarray[tuple[int, ...], dtype[Any]],
alpha: float = 0.4,
colormap: str = 'jet',
) → ndarray[tuple[int, ...], dtype[Any]]

Create an overlay of the original image with a heatmap visualization.

Parameters:
  • original_img (NDArray[Any] | Image.Image) – The original image array or PIL Image

  • heatmap (NDArray[Any]) – The heatmap to overlay (normalized to 0-1)

  • alpha (float) – Transparency level of overlay (0-1)

  • colormap (str) – Matplotlib colormap to use for heatmap

Returns:

The overlaid image

Return type:

NDArray[Any]

Examples

> original: NDArray[Any] | Image.Image = ...
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
> overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
> Image.fromarray(overlay).save("overlay.png")
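The overlay amounts to a colormap lookup followed by a linear blend. The sketch below assumes an 8-bit RGB (or grayscale) input, upsamples the heatmap with nearest-neighbour indexing (the library may use smoother interpolation), and treats `alpha` as the weight given to the coloured heatmap. `overlay_heatmap` is a hypothetical name, and `matplotlib.colormaps` requires Matplotlib 3.5 or later.

```python
from typing import Any

import matplotlib
import numpy as np
from numpy.typing import NDArray


def overlay_heatmap(
    original_img: Any,  # NDArray or PIL.Image.Image with 0-255 values
    heatmap: NDArray[np.floating],
    alpha: float = 0.4,
    colormap: str = "jet",
) -> NDArray[np.uint8]:
    """Blend a [0, 1] heatmap, coloured with a Matplotlib colormap, onto an image."""
    img = np.asarray(original_img, dtype=np.float64)
    if img.ndim == 2:  # grayscale -> 3 identical channels
        img = np.stack([img] * 3, axis=-1)
    img = img[..., :3]  # drop an alpha channel if present
    h, w = img.shape[:2]
    # Nearest-neighbour upsampling of the (usually small) heatmap to image size
    rows = np.arange(h) * heatmap.shape[0] // h
    cols = np.arange(w) * heatmap.shape[1] // w
    resized = np.clip(heatmap[rows[:, None], cols[None, :]], 0.0, 1.0)
    # Colormap lookup returns RGBA floats in [0, 1]; keep RGB and scale to 0-255
    colored = matplotlib.colormaps[colormap](resized)[..., :3] * 255.0
    # Linear blend: alpha weights the heatmap, (1 - alpha) the original image
    blended = (1.0 - alpha) * img + alpha * colored
    return np.clip(blended, 0, 255).astype(np.uint8)


# Usage: a flat grey 8x8 image with a 2x2 heatmap that is hot in one corner
image = np.full((8, 8, 3), 128, dtype=np.uint8)
small_heatmap = np.array([[1.0, 0.0], [0.0, 0.0]])
overlay = overlay_heatmap(image, small_heatmap)
print(overlay.shape, overlay.dtype)
```

Returning `uint8` keeps the result directly usable with `Image.fromarray`, as in the example above.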
all_visualizations_for_image(
model: Model,
folder_path: str,
img: ndarray[tuple[int, ...], dtype[Any]],
base_name: str,
class_idx: int,
class_name: str,
files: tuple[str, ...],
data_type: str,
) → None

Process a single image to generate visualizations and determine prediction correctness.

Parameters:
  • model (Model) – The pre-trained TensorFlow model

  • folder_path (str) – The path to the folder where the visualizations will be saved

  • img (NDArray[Any]) – The preprocessed image array (batch of 1)

  • base_name (str) – The base name of the image

  • class_idx (int) – The true class index for the image

  • class_name (str) – The true class name for organizing folders

  • files (tuple[str, ...]) – Tuple of original file paths for the subject

  • data_type (str) – Type of data (“test” or “train”)
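The description above mentions determining prediction correctness and organizing output by class name. A minimal sketch of that bookkeeping, assuming the model's output for the image is a probability vector: the predicted class is its argmax, and the output folder is chosen from the data type, the true class name, and whether the prediction matched. The helper `prediction_output_dir` and its `<data_type>/<class_name>/<correct|incorrect>` layout are hypothetical; the real function's folder structure is not documented here.

```python
import os

import numpy as np
from numpy.typing import NDArray


def prediction_output_dir(
    folder_path: str,
    data_type: str,
    class_name: str,
    class_idx: int,
    probabilities: NDArray[np.floating],
) -> tuple[bool, str]:
    """Return (is_correct, output_dir) for one image's prediction.

    The <folder_path>/<data_type>/<class_name>/<correct|incorrect> layout is a
    hypothetical choice for this sketch, not the library's documented one.
    """
    predicted_idx = int(np.argmax(probabilities))  # most probable class
    is_correct = predicted_idx == class_idx
    subfolder = "correct" if is_correct else "incorrect"
    return is_correct, os.path.join(folder_path, data_type, class_name, subfolder)


# Usage: true class 1 ("cat"), and the model is most confident in class 1
ok, out_dir = prediction_output_dir(
    "visualizations", "test", "cat", 1, np.array([0.2, 0.7, 0.1])
)
print(ok, out_dir)
```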