Source code for stouputils.parallel.multi


# Imports
import time
from collections.abc import Callable, Iterable
from typing import Any

from ..ctx import SetMPStartMethod
from ..print import BAR_FORMAT, MAGENTA
from .capturer import CaptureOutput
from .common import CPU_COUNT, handle_parameters, nice_wrapper


# Small test functions for doctests
def doctest_square(x: int) -> int:
	""" Return the square of ``x`` (tiny helper used by the doctest examples below). """
	return x ** 2
def doctest_slow(x: int) -> int:
	""" Sleep for 0.1s then return ``x`` unchanged (simulates a slow task for the doctest examples). """
	time.sleep(0.1)
	return x

# Functions
[docs] def multiprocessing[T, R]( func: Callable[..., R] | list[Callable[..., R]], args: Iterable[T], use_starmap: bool = False, chunksize: int = 1, desc: str = "", max_workers: int | float = CPU_COUNT, capture_output: bool = False, delay_first_calls: float = 0, nice: int | None = None, color: str = MAGENTA, bar_format: str = BAR_FORMAT, ascii: bool = False, smooth_tqdm: bool = True, **tqdm_kwargs: Any ) -> list[R]: r""" Method to execute a function in parallel using multiprocessing - For CPU-bound operations where the GIL (Global Interpreter Lock) is a bottleneck. - When the task can be divided into smaller, independent sub-tasks that can be executed concurrently. - For computationally intensive tasks like scientific simulations, data analysis, or machine learning workloads. Args: func (Callable | list[Callable]): Function to execute, or list of functions (one per argument) args (Iterable): Iterable of arguments to pass to the function(s) use_starmap (bool): Whether to use starmap or not (Defaults to False): True means the function will be called like func(\*args[i]) instead of func(args[i]) chunksize (int): Number of arguments to process at a time (Defaults to 1 for proper progress bar display) desc (str): Description displayed in the progress bar (if not provided no progress bar will be displayed) max_workers (int | float): Number of workers to use (Defaults to CPU_COUNT), -1 means CPU_COUNT. If float between 0 and 1, it's treated as a percentage of CPU_COUNT. If negative float between -1 and 0, it's treated as a percentage of len(args). capture_output (bool): Whether to capture stdout/stderr from the worker processes (Defaults to True) delay_first_calls (float): Apply i*delay_first_calls seconds delay to the first "max_workers" calls. For instance, the first process will be delayed by 0 seconds, the second by 1 second, etc. (Defaults to 0): This can be useful to avoid functions being called in the same second. 
nice (int | None): Adjust the priority of worker processes (Defaults to None). Use Unix-style values: -20 (highest priority) to 19 (lowest priority). Positive values reduce priority, negative values increase it. Automatically converted to appropriate priority class on Windows. If None, no priority adjustment is made. color (str): Color of the progress bar (Defaults to MAGENTA) bar_format (str): Format of the progress bar (Defaults to BAR_FORMAT) ascii (bool): Whether to use ASCII or Unicode characters for the progress bar smooth_tqdm (bool): Whether to enable smooth progress bar updates by setting miniters and mininterval (Defaults to True) **tqdm_kwargs (Any): Additional keyword arguments to pass to tqdm Returns: list[object]: Results of the function execution Examples: .. code-block:: python > multiprocessing(doctest_square, args=[1, 2, 3]) [1, 4, 9] > multiprocessing(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True) [2, 12, 30] > # Using a list of functions (one per argument) > multiprocessing([doctest_square, doctest_square, doctest_square], [1, 2, 3]) [1, 4, 9] > # Will process in parallel with progress bar > multiprocessing(doctest_slow, range(10), desc="Processing") [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] > # Will process in parallel with progress bar and delay the first threads > multiprocessing( . doctest_slow, . range(10), . desc="Processing with delay", . max_workers=2, . delay_first_calls=0.6 . 
) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ # Imports import multiprocessing as mp from multiprocessing import Pool from tqdm.auto import tqdm from tqdm.contrib.concurrent import process_map # pyright: ignore[reportUnknownVariableType] # Handle parameters args = list(args) # Ensure we have a list (not other iterable) if max_workers == -1: max_workers = CPU_COUNT if isinstance(max_workers, float): if max_workers > 0: assert max_workers <= 1, "max_workers as positive float must be between 0 and 1 (percentage of CPU_COUNT)" max_workers = int(max_workers * CPU_COUNT) else: assert -1 <= max_workers < 0, "max_workers as negative float must be between -1 and 0 (percentage of len(args))" max_workers = int(-max_workers * len(args)) verbose: bool = desc != "" desc, func, args = handle_parameters(func, args, use_starmap, delay_first_calls, max_workers, desc, color) if bar_format == BAR_FORMAT: bar_format = bar_format.replace(MAGENTA, color) if smooth_tqdm: tqdm_kwargs.setdefault("mininterval", 0.0) try: total = len(args) # type: ignore import shutil width = shutil.get_terminal_size().columns tqdm_kwargs.setdefault("miniters", max(1, total // width)) except (TypeError, OSError): tqdm_kwargs.setdefault("miniters", 1) # Do multiprocessing only if there is more than 1 argument and more than 1 CPU if max_workers > 1 and len(args) > 1: # Wrap function with nice if specified if nice is not None: wrapped_args = [(nice, func, arg) for arg in args] wrapped_func = nice_wrapper else: wrapped_args = args wrapped_func = func # Capture output if specified capturer: CaptureOutput | None = None if capture_output: capturer = CaptureOutput() capturer.start_listener() wrapped_args = [(capturer, wrapped_func, arg) for arg in wrapped_args] wrapped_func = capture_subprocess_output def process() -> list[Any]: if verbose: return list(process_map( wrapped_func, wrapped_args, max_workers=max_workers, chunksize=chunksize, desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs )) # type: ignore else: with 
Pool(max_workers) as pool: return list(pool.map(wrapped_func, wrapped_args, chunksize=chunksize)) # type: ignore try: return process() except RuntimeError as e: if "SemLock created in a fork context is being shared with a process in a spawn context" in str(e): # Try with alternate start method with SetMPStartMethod("spawn" if mp.get_start_method() != "spawn" else "fork"): return process() else: # Re-raise if it's not the SemLock error raise finally: if capturer is not None: capturer.parent_close_write() capturer.join_listener(timeout=5.0) # Single process execution else: if verbose: return [func(arg) for arg in tqdm(args, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs)] else: return [func(arg) for arg in args]
[docs] def multithreading[T, R]( func: Callable[..., R] | list[Callable[..., R]], args: Iterable[T], use_starmap: bool = False, desc: str = "", max_workers: int | float = CPU_COUNT, delay_first_calls: float = 0, color: str = MAGENTA, bar_format: str = BAR_FORMAT, ascii: bool = False, smooth_tqdm: bool = True, **tqdm_kwargs: Any ) -> list[R]: r""" Method to execute a function in parallel using multithreading, you should use it: - For I/O-bound operations where the GIL is not a bottleneck, such as network requests or disk operations. - When the task involves waiting for external resources, such as network responses or user input. - For operations that involve a lot of waiting, such as GUI event handling or handling user input. Args: func (Callable | list[Callable]): Function to execute, or list of functions (one per argument) args (Iterable): Iterable of arguments to pass to the function(s) use_starmap (bool): Whether to use starmap or not (Defaults to False): True means the function will be called like func(\*args[i]) instead of func(args[i]) desc (str): Description displayed in the progress bar (if not provided no progress bar will be displayed) max_workers (int | float): Number of workers to use (Defaults to CPU_COUNT), -1 means CPU_COUNT. If float between 0 and 1, it's treated as a percentage of CPU_COUNT. If negative float between -1 and 0, it's treated as a percentage of len(args). delay_first_calls (float): Apply i*delay_first_calls seconds delay to the first "max_workers" calls. For instance with value to 1, the first thread will be delayed by 0 seconds, the second by 1 second, etc. (Defaults to 0): This can be useful to avoid functions being called in the same second. 
color (str): Color of the progress bar (Defaults to MAGENTA) bar_format (str): Format of the progress bar (Defaults to BAR_FORMAT) ascii (bool): Whether to use ASCII or Unicode characters for the progress bar smooth_tqdm (bool): Whether to enable smooth progress bar updates by setting miniters and mininterval (Defaults to True) **tqdm_kwargs (Any): Additional keyword arguments to pass to tqdm Returns: list[object]: Results of the function execution Examples: .. code-block:: python > multithreading(doctest_square, args=[1, 2, 3]) [1, 4, 9] > multithreading(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True) [2, 12, 30] > # Using a list of functions (one per argument) > multithreading([doctest_square, doctest_square, doctest_square], [1, 2, 3]) [1, 4, 9] > # Will process in parallel with progress bar > multithreading(doctest_slow, range(10), desc="Threading") [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] > # Will process in parallel with progress bar and delay the first threads > multithreading( . doctest_slow, . range(10), . desc="Threading with delay", . max_workers=2, . delay_first_calls=0.6 . 
) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ # Imports from concurrent.futures import ThreadPoolExecutor from tqdm.auto import tqdm # Handle parameters args = list(args) # Ensure we have a list (not other iterable) if max_workers == -1: max_workers = CPU_COUNT if isinstance(max_workers, float): if max_workers > 0: assert max_workers <= 1, "max_workers as positive float must be between 0 and 1 (percentage of CPU_COUNT)" max_workers = int(max_workers * CPU_COUNT) else: assert -1 <= max_workers < 0, "max_workers as negative float must be between -1 and 0 (percentage of len(args))" max_workers = int(-max_workers * len(args)) verbose: bool = desc != "" desc, func, args = handle_parameters(func, args, use_starmap, delay_first_calls, max_workers, desc, color) if bar_format == BAR_FORMAT: bar_format = bar_format.replace(MAGENTA, color) if smooth_tqdm: tqdm_kwargs.setdefault("mininterval", 0.0) try: total = len(args) # type: ignore import shutil width = shutil.get_terminal_size().columns tqdm_kwargs.setdefault("miniters", max(1, total // width)) except (TypeError, OSError): tqdm_kwargs.setdefault("miniters", 1) # Do multithreading only if there is more than 1 argument and more than 1 CPU if max_workers > 1 and len(args) > 1: if verbose: with ThreadPoolExecutor(max_workers) as executor: return list(tqdm(executor.map(func, args), total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs)) else: with ThreadPoolExecutor(max_workers) as executor: return list(executor.map(func, args)) # Single process execution else: if verbose: return [func(arg) for arg in tqdm(args, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs)] else: return [func(arg) for arg in args]
# "Private" function for capturing multiprocessing subprocess
[docs] def capture_subprocess_output[T, R](args: tuple[CaptureOutput, Callable[[T], R], T]) -> R: """ Wrapper function to execute the target function in a subprocess with optional output capture. Args: tuple[CaptureOutput,Callable,T]: Tuple containing: CaptureOutput: Capturer object to redirect stdout/stderr Callable: Target function to execute T: Argument to pass to the target function """ capturer, func, arg = args capturer.redirect() return func(arg)