# Source code for stouputils.parallel

"""
This module provides utility functions for parallel processing, such as:

- multiprocessing(): Execute a function in parallel using multiprocessing
- multithreading(): Execute a function in parallel using multithreading
- run_in_subprocess(): Execute a function in a subprocess with args and kwargs

I highly encourage you to read the function docstrings to understand when to use each method.

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/parallel_module.gif
  :alt: stouputils parallel examples
"""

# Imports
import os
import time
from collections.abc import Callable
from typing import Any, TypeVar, cast

from .decorators import LogLevels, handle_error
from .print import BAR_FORMAT, MAGENTA


# Small test functions for doctests
def doctest_square(x: int) -> int:
	""" Doctest helper: return ``x`` multiplied by itself. """
	squared: int = x * x
	return squared
def doctest_slow(x: int) -> int:
	""" Doctest helper: pause for a tenth of a second, then hand ``x`` back unchanged. """
	pause_seconds: float = 0.1
	time.sleep(pause_seconds)
	return x

# Constants
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 1 so CPU_COUNT is always a usable int (e.g. in `max_workers > 1` checks)
CPU_COUNT: int = os.cpu_count() or 1
T = TypeVar("T")
R = TypeVar("R")

# Functions
@handle_error(error_log=LogLevels.ERROR_TRACEBACK)
def multiprocessing(
	func: Callable[[T], R] | list[Callable[[T], R]],
	args: list[T],
	use_starmap: bool = False,
	chunksize: int = 1,
	desc: str = "",
	max_workers: int = CPU_COUNT,
	delay_first_calls: float = 0,
	color: str = MAGENTA,
	bar_format: str = BAR_FORMAT,
	ascii: bool = False,
) -> list[R]:
	r""" Method to execute a function in parallel using multiprocessing

	Use it:

	- For CPU-bound operations where the GIL (Global Interpreter Lock) is a bottleneck.
	- When the task can be divided into smaller, independent sub-tasks that can be executed concurrently.
	- For computationally intensive tasks like scientific simulations, data analysis, or machine learning workloads.

	A process pool is only used when ``max_workers > 1`` and there is more than one argument;
	otherwise every call runs sequentially in the current process.

	Args:
		func              (Callable | list[Callable]): Function to execute, or list of functions (one per argument)
		args              (list):                      List of arguments to pass to the function(s)
		use_starmap       (bool):                      Whether to use starmap or not (Defaults to False):
			True means the function will be called like func(\*args[i]) instead of func(args[i])
		chunksize         (int):                       Number of arguments to process at a time
			(Defaults to 1 for proper progress bar display)
		desc              (str):                       Description displayed in the progress bar
			(if not provided no progress bar will be displayed)
		max_workers       (int):                       Number of workers to use (Defaults to CPU_COUNT), -1 means CPU_COUNT
		delay_first_calls (float):                     Apply i*delay_first_calls seconds delay to the first "max_workers" calls.
			For instance, with a value of 1, the first process is delayed by 0 seconds, the second by 1 second, etc.
			(Defaults to 0): This can be useful to avoid functions being called in the same second.
		color             (str):                       Color of the progress bar (Defaults to MAGENTA)
		bar_format        (str):                       Format of the progress bar (Defaults to BAR_FORMAT)
		ascii             (bool):                      Whether to use ASCII or Unicode characters for the progress bar

	Returns:
		list[R]: Results of the function execution, one per argument, in the same order as ``args``

	Examples:
		.. code-block:: python

			> multiprocessing(doctest_square, args=[1, 2, 3])
			[1, 4, 9]

			> multiprocessing(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True)
			[2, 12, 30]

			> # Using a list of functions (one per argument)
			> multiprocessing([doctest_square, doctest_square, doctest_square], [1, 2, 3])
			[1, 4, 9]

			> # Will process in parallel with progress bar
			> multiprocessing(doctest_slow, range(10), desc="Processing")
			[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

			> # Will process in parallel with progress bar and delay the first processes
			> multiprocessing(
			. 	doctest_slow,
			. 	range(10),
			. 	desc="Processing with delay",
			. 	max_workers=2,
			. 	delay_first_calls=0.6
			. )
			[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
	"""
	# Imports (function-scoped; note that this module's `multiprocessing` name is this function,
	# so the stdlib package is bound locally as `mp`)
	import multiprocessing as mp
	from multiprocessing import Pool

	from tqdm.auto import tqdm
	from tqdm.contrib.concurrent import process_map  # pyright: ignore[reportUnknownVariableType]

	# Handle parameters (-1 is an alias for CPU_COUNT)
	if max_workers == -1:
		max_workers = CPU_COUNT
	# An empty description disables the progress bar; decide before _handle_parameters prefixes the color
	verbose: bool = desc != ""
	desc, func, args = _handle_parameters(func, args, use_starmap, delay_first_calls, max_workers, desc, color)
	# Only recolor the default bar format; a caller-supplied format is used untouched
	if bar_format == BAR_FORMAT:
		bar_format = bar_format.replace(MAGENTA, color)

	# Do multiprocessing only if there is more than 1 argument and more than 1 worker
	if max_workers > 1 and len(args) > 1:
		# Wrapped in a closure so the exact same run can be retried under another start method below
		def process() -> list[Any]:
			if verbose:
				return list(process_map(
					func, args, max_workers=max_workers, chunksize=chunksize, desc=desc, bar_format=bar_format, ascii=ascii
				))  # type: ignore
			else:
				with Pool(max_workers) as pool:
					return list(pool.map(func, args, chunksize=chunksize))  # type: ignore
		try:
			return process()
		except RuntimeError as e:
			if "SemLock created in a fork context is being shared with a process in a spawn context" in str(e):

				# Try with alternate start method: swap fork <-> spawn for the retry only
				old_method: str | None = mp.get_start_method(allow_none=True)
				new_method: str = "spawn" if old_method in (None, "fork") else "fork"
				mp.set_start_method(new_method, force=True)

				try:
					return process()
				finally:
					# Restore the caller's global start method (None resets it to "not yet chosen")
					mp.set_start_method(old_method, force=True)
			else:
				# Re-raise if it's not the SemLock error
				raise

	# Single process execution
	else:
		if verbose:
			return [func(arg) for arg in tqdm(args, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii)]
		else:
			return [func(arg) for arg in args]
@handle_error(error_log=LogLevels.ERROR_TRACEBACK)
def multithreading(
	func: Callable[[T], R] | list[Callable[[T], R]],
	args: list[T],
	use_starmap: bool = False,
	desc: str = "",
	max_workers: int = CPU_COUNT,
	delay_first_calls: float = 0,
	color: str = MAGENTA,
	bar_format: str = BAR_FORMAT,
	ascii: bool = False,
) -> list[R]:
	r""" Run a function over many arguments concurrently with a pool of threads.

	Threads share the GIL, so this is the right tool when the work mostly waits:

	- I/O-bound operations where the GIL is not a bottleneck, such as network requests or disk access.
	- Tasks that block on external resources, such as network responses or user input.
	- Workloads dominated by waiting, such as GUI event handling.

	A thread pool is only created when more than one worker is allowed and there are at least
	two arguments; otherwise the calls run one after another in the calling thread.

	Args:
		func              (Callable | list[Callable]): Function to run, or one function per argument
		args              (list):                      Arguments handed to the function(s)
		use_starmap       (bool):                      When True, call func(\*args[i]) instead of func(args[i])
			(Defaults to False)
		desc              (str):                       Progress bar label; an empty string disables the progress bar
		max_workers       (int):                       Size of the thread pool (Defaults to CPU_COUNT), -1 also means CPU_COUNT
		delay_first_calls (float):                     Delay the i-th of the first "max_workers" calls by i*delay_first_calls seconds
			(e.g. with a value of 1: 0 seconds for the first thread, 1 second for the second, ...). Defaults to 0.
			Handy to keep calls from firing within the same second.
		color             (str):                       Progress bar color (Defaults to MAGENTA)
		bar_format        (str):                       Progress bar format (Defaults to BAR_FORMAT)
		ascii             (bool):                      Draw the progress bar with ASCII instead of Unicode characters

	Returns:
		list[R]: One result per argument, in the same order as ``args``

	Examples:
		.. code-block:: python

			> multithreading(doctest_square, args=[1, 2, 3])
			[1, 4, 9]

			> multithreading(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True)
			[2, 12, 30]

			> # Using a list of functions (one per argument)
			> multithreading([doctest_square, doctest_square, doctest_square], [1, 2, 3])
			[1, 4, 9]

			> # Will process in parallel with progress bar
			> multithreading(doctest_slow, range(10), desc="Threading")
			[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

			> # Will process in parallel with progress bar and delay the first threads
			> multithreading(
			. 	doctest_slow,
			. 	range(10),
			. 	desc="Threading with delay",
			. 	max_workers=2,
			. 	delay_first_calls=0.6
			. )
			[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
	"""
	from concurrent.futures import ThreadPoolExecutor

	from tqdm.auto import tqdm

	# -1 is an alias for "use every CPU"
	pool_size: int = CPU_COUNT if max_workers == -1 else max_workers
	show_progress: bool = desc != ""
	desc, func, args = _handle_parameters(func, args, use_starmap, delay_first_calls, pool_size, desc, color)
	if bar_format == BAR_FORMAT:
		# Recolor only the default bar format so it follows the requested color
		bar_format = bar_format.replace(MAGENTA, color)

	def with_progress(iterable: Any) -> Any:
		# Decorate an iterable with a tqdm progress bar only when a description was given
		if not show_progress:
			return iterable
		return tqdm(iterable, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii)

	# Not enough workers or arguments to benefit from a pool: run sequentially
	if pool_size <= 1 or len(args) <= 1:
		return [func(arg) for arg in with_progress(args)]

	with ThreadPoolExecutor(pool_size) as executor:
		return list(with_progress(executor.map(func, args)))
@handle_error(error_log=LogLevels.ERROR_TRACEBACK)
def run_in_subprocess(
	func: Callable[..., R],
	*args: Any,
	**kwargs: Any
) -> R:
	""" Execute a function in a subprocess with positional and keyword arguments.

	This is useful when you need to run a function in isolation to avoid memory leaks,
	resource conflicts, or to ensure a clean execution environment.
	The subprocess will be created, run the function with the provided arguments,
	and return the result.

	The result is read from the queue while the child is still running, and only then is the
	child joined. This avoids the deadlock that occurs when the parent joins first: the child
	cannot exit until its queued (possibly large) result has been consumed from the pipe.

	Args:
		func     (Callable): The function to execute in a subprocess.
			(SHOULD BE A TOP-LEVEL FUNCTION TO BE PICKLABLE)
		*args    (Any):      Positional arguments to pass to the function.
		**kwargs (Any):      Keyword arguments to pass to the function.

	Returns:
		R: The return value of the function.

	Raises:
		RuntimeError: If the subprocess exits with a non-zero exit code,
			or if it exits cleanly without sending back any result.
		Exception: Any exception raised by ``func`` in the subprocess is re-raised here.
			(If ``func`` *returns* an Exception instance, it is raised as well.)

	Examples:
		.. code-block:: python

			> # Simple function execution
			> run_in_subprocess(doctest_square, 5)
			25

			> # Function with multiple arguments
			> def add(a: int, b: int) -> int:
			. 	return a + b
			> run_in_subprocess(add, 10, 20)
			30

			> # Function with keyword arguments
			> def greet(name: str, greeting: str = "Hello") -> str:
			. 	return f"{greeting}, {name}!"
			> run_in_subprocess(greet, "World", greeting="Hi")
			'Hi, World!'
	"""
	import multiprocessing as mp
	import queue
	from multiprocessing import Queue

	# Create a queue to get the result from the subprocess
	result_queue: Queue[R | Exception] = Queue()

	# Create and start the subprocess using the module-level wrapper
	process: mp.Process = mp.Process(
		target=_subprocess_wrapper,
		args=(result_queue, func, args, kwargs)
	)
	process.start()

	# Drain the queue BEFORE joining: a process that put data on a multiprocessing.Queue
	# only terminates once its feeder thread has flushed that data into the pipe, so
	# join()-then-get() deadlocks as soon as the pickled result exceeds the pipe buffer.
	result_or_exception: Any = None
	received: bool = False
	while not received:
		try:
			result_or_exception = result_queue.get(timeout=0.1)
			received = True
		except queue.Empty:
			# Stop waiting once the child is gone (it crashed, or exited without a result)
			if not process.is_alive():
				break
	process.join()

	# Check exit code
	if process.exitcode != 0:
		raise RuntimeError(f"Subprocess failed with exit code {process.exitcode}")

	# The child exited cleanly, so anything it queued is already fully in the pipe:
	# one last non-blocking read catches a result that arrived right after our last poll.
	if not received:
		try:
			result_or_exception = result_queue.get_nowait()
			received = True
		except queue.Empty:
			pass
	if not received:
		raise RuntimeError("Subprocess did not return any result")

	# The wrapper sends back either the return value or the exception raised by func
	if isinstance(result_or_exception, Exception):
		raise result_or_exception
	return result_or_exception
# "Private" function for subprocess wrapper (must be at module level for pickling on Windows)
def _subprocess_wrapper(
	result_queue: Any,
	func: Callable[..., R],
	args: tuple[Any, ...],
	kwargs: dict[str, Any]
) -> None:
	""" Subprocess entry point: call the target function and report its outcome through a queue.

	Defined at module level so it stays picklable under the "spawn" start method (the Windows default).

	Args:
		result_queue (multiprocessing.Queue): Receives either the return value or the raised exception.
		func         (Callable):              The target function to call.
		args         (tuple):                 Positional arguments for the function.
		kwargs       (dict):                  Keyword arguments for the function.
	"""
	try:
		result_queue.put(func(*args, **kwargs))
	except Exception as error:
		# Ship the exception itself back to the parent so it can be re-raised there
		result_queue.put(error)
# "Private" function to use starmap
def _starmap(args: tuple[Callable[[T], R], list[T]]) -> R:
	r""" Call a function with an unpacked argument sequence, i.e. args[0](\*args[1]).

	Args:
		args (tuple): A (function, arguments) pair; the arguments are spread positionally

	Returns:
		object: Whatever the function returns
	"""
	(target, positional) = args
	return target(*positional)
# "Private" function to apply delay before calling the target function
def _delayed_call(args: tuple[Callable[[T], R], float, T]) -> R:
	""" Sleep for a given number of seconds, then call the function with a single argument.

	Args:
		args (tuple): A (function, delay in seconds, argument) triple

	Returns:
		object: Whatever the function returns
	"""
	target, wait_seconds, argument = args
	time.sleep(wait_seconds)
	return target(argument)
# "Private" function to handle parameters for multiprocessing or multithreading functions
def _handle_parameters(
	func: Callable[[T], R] | list[Callable[[T], R]],
	args: list[T],
	use_starmap: bool,
	delay_first_calls: float,
	max_workers: int,
	desc: str,
	color: str
) -> tuple[str, Callable[[T], R], list[T]]:
	r""" Private function to handle the parameters for multiprocessing or multithreading functions

	It normalizes every calling mode into a single callable applied to each element of the returned
	argument list:

	- a list of functions becomes (function, argument-tuple) pairs dispatched through ``_starmap``
	- ``use_starmap`` becomes (function, argument-sequence) pairs dispatched through ``_starmap``
	- ``delay_first_calls > 0`` wraps the result in (function, delay, argument) triples for ``_delayed_call``

	Args:
		func              (Callable | list[Callable]): Function to execute, or list of functions (one per argument)
		args              (list):                      List of arguments to pass to the function(s)
		use_starmap       (bool):                      Whether to use starmap or not:
			True means the function will be called like func(\*args[i]) instead of func(args[i])
		delay_first_calls (float):                     Apply i*delay_first_calls seconds delay to the first "max_workers" calls.
			For instance, with a value of 1, the first call is delayed by 0 seconds, the second by 1 second, etc.
			This can be useful to avoid functions being called in the same second.
		max_workers       (int):                       Number of workers; only the first "max_workers" calls are delayed
		desc              (str):                       Description of the function execution displayed in the progress bar
		color             (str):                       Color of the progress bar

	Returns:
		tuple[str, Callable[[T], R], list[T]]: Tuple containing the description, function, and arguments

	Raises:
		ValueError: If a list of functions is given whose length differs from the number of arguments
	"""
	desc = color + desc

	# Handle list of functions: validate and convert to starmap format
	if isinstance(func, list):
		funcs = cast(list[Callable[[T], R]], func)
		# Explicit check (not `assert`) so it still runs under `python -O`
		if len(funcs) != len(args):
			raise ValueError(f"Length mismatch: {len(funcs)} functions but {len(args)} arguments")
		if use_starmap:
			# Each args[i] is already an argument sequence: call funcs[i](*args[i])
			args = [(f, arg) for f, arg in zip(funcs, args, strict=True)]  # type: ignore
		else:
			# Wrap each argument in a 1-tuple so _starmap calls funcs[i](args[i]) rather than funcs[i](*args[i])
			args = [(f, (arg,)) for f, arg in zip(funcs, args, strict=True)]  # type: ignore
		func = _starmap  # type: ignore

	# If use_starmap is True, we use the _starmap function
	elif use_starmap:
		args = [(func, arg) for arg in args]  # type: ignore
		func = _starmap  # type: ignore

	# Prepare delayed function calls if delay_first_calls is set
	if delay_first_calls > 0:
		args = [
			(func, i * delay_first_calls if i < max_workers else 0, arg)  # type: ignore
			for i, arg in enumerate(args)
		]
		func = _delayed_call  # type: ignore

	return desc, func, args  # type: ignore