Source code for stouputils.data_science.data_processing.image_preprocess


# Imports
import os
import shutil
from typing import Any

import cv2
import numpy as np
from numpy.typing import NDArray

from ...decorators import handle_error
from ...parallel import multiprocessing, CPU_COUNT
from ...print import warning, error
from ...io import clean_path, super_copy
from .technique import ProcessingTechnique


# Image dataset augmentation class
class ImageDatasetPreprocess:
    """ Image dataset preprocessing class.

    Check the class constructor for more information.
    """

    # Class constructor (configuration)
    def __init__(self, techniques: list[ProcessingTechnique] | None = None) -> None:
        """ Initialize the image dataset augmentation class with the given parameters.

        Args:
            techniques (list[ProcessingTechnique]): List of processing techniques to apply.
        """
        if techniques is None:
            techniques = []
        assert all(isinstance(x, ProcessingTechnique) for x in techniques), (
            "All techniques must be ProcessingTechnique objects"
        )
        self.techniques: list[ProcessingTechnique] = [x.deterministic(use_default=True) for x in techniques]
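
    # Usage sketch (hypothetical; assumes `my_techniques` is a list of
    # ProcessingTechnique objects built elsewhere, see technique.py):
    #
    #   preprocess = ImageDatasetPreprocess(techniques=my_techniques)
    #
    # Each technique's ranges are fixed to deterministic values at construction.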
[docs] @handle_error(message="Error while getting files recursively") def get_files_recursively( self, source: str, destination: str, extensions: tuple[str,...] = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif") ) -> dict[str, str]: """ Recursively get all files in a directory and their destinations. Args: source (str): Path to the source directory destination (str): Path to the destination directory extensions (tuple[str,...]): Tuple of extensions to consider (e.g. (".jpg", ".png")) Returns: dict[str, str]: Dictionary mapping source paths to destination paths """ files: dict[str, str] = {} if os.path.isfile(source) and source.endswith(extensions): files[source] = destination elif os.path.isdir(source): for item in os.listdir(source): item_path: str = f"{source}/{item}" item_dest: str = f"{destination}/{item}" files.update(self.get_files_recursively(item_path, item_dest, extensions)) return files
[docs] @handle_error(message="Error while getting queue of files to process") def get_queue(self, dataset_path: str, destination_path: str) -> list[tuple[str, str, list[ProcessingTechnique]]]: """ Get the queue of images to process with their techniques. This method converts the processing techniques ranges to fixed values and builds a queue of files to process by recursively finding all images in the dataset path. Args: dataset_path (str): Path to the dataset directory destination_path (str): Path to the destination directory where processed images will be saved Returns: list[tuple[str, str, list[ProcessingTechnique]]]: Queue of (source_path, dest_path, techniques) tuples """ # Convert technique ranges to fixed values self.techniques = [x.deterministic(use_default=True) for x in self.techniques] # Build queue by recursively finding all images and their destinations return [ (path, dest, self.techniques) for path, dest in self.get_files_recursively(dataset_path, destination_path).items() ]
[docs] @handle_error(message="Error while processing the dataset") def process_dataset( self, dataset_path: str, destination_path: str, max_workers: int = CPU_COUNT, ignore_confirmation: bool = False ) -> None: """ Preprocess the dataset by applying the given processing techniques to the images. Args: dataset_path (str): Path to the dataset destination_path (str): Path to the destination dataset max_workers (int): Number of workers to use (Defaults to CPU_COUNT) ignore_confirmation (bool): If True, don't ask for confirmation """ # Clean paths dataset_path = clean_path(dataset_path) destination_path = clean_path(destination_path) # If destination folder exists, ask user if they want to delete it if os.path.isdir(destination_path): if not ignore_confirmation: warning(f"Destination folder '{destination_path}' already exists.\nDo you want to delete it? (y/N)") if input().lower() == "y": shutil.rmtree(destination_path) else: error("Aborting...", exit=False) return else: warning(f"Destination folder '{destination_path}' already exists.\nDeleting it...") shutil.rmtree(destination_path) # Prepare the multiprocessing arguments (image path, destination path, techniques) queue: list[tuple[str, str, list[ProcessingTechnique]]] = self.get_queue(dataset_path, destination_path) # Apply the processing techniques in parallel splitted: list[str] = dataset_path.split('/') short_path: str = f".../{splitted[-1]}" if len(splitted) > 2 else dataset_path multiprocessing( self.apply_techniques, queue, use_starmap=True, desc=f"Processing dataset '{short_path}'", max_workers=max_workers )

    @staticmethod
    def apply_techniques(path: str, dest: str, techniques: list[ProcessingTechnique], use_padding: bool = True) -> None:
        """ Apply the processing techniques to the image.

        Args:
            path (str): Path to the image
            dest (str): Path to the destination image
            techniques (list[ProcessingTechnique]): List of processing techniques to apply
            use_padding (bool): If True, add padding to the image before applying techniques
        """
        if not techniques:
            super_copy(path, dest)
            return

        # Read the image
        img: NDArray[Any] = cv2.imread(path, cv2.IMREAD_UNCHANGED)

        if use_padding:
            # Add a padding (to avoid cutting the image)
            previous_shape: tuple[int, ...] = img.shape[:2]
            padding: int = max(previous_shape[0], previous_shape[1]) // 2
            img = np.pad(  # pyright: ignore [reportUnknownMemberType]
                img,
                pad_width=((padding, padding), (padding, padding), (0, 0)),
                mode="constant",
                constant_values=0
            )

            # Compute the dividers that will be used to adjust techniques parameters
            dividers: tuple[float, float] = (
                img.shape[0] / previous_shape[0],
                img.shape[1] / previous_shape[1]
            )
        else:
            dividers = (1.0, 1.0)
            padding = 0

        # Apply the processing techniques
        for technique in techniques:
            img = technique.apply(img, dividers)

        # Remove the padding
        if use_padding:
            img = img[padding:-padding, padding:-padding, :]

        # Save the image
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        cv2.imwrite(dest, img)
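
# Usage sketch (hypothetical; applies the techniques to a single image,
# bypassing the parallel queue; paths are illustrative):
#
#   ImageDatasetPreprocess.apply_techniques(
#       "data/raw/cats/img1.png",
#       "data/processed/cats/img1.png",
#       techniques=preprocess.techniques,
#   )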