Source code for stouputils.decorators

"""
This module provides decorators for various purposes:

- silent(): Make a function silent (disable stdout, and stderr if specified) (alternative to stouputils.ctx.Muffle)
- measure_time(): Measure the execution time of a function and print it with the given print function
- handle_error(): Handle an error with different log levels
- simple_cache(): Easy cache function with parameter caching method
- deprecated(): Mark a function as deprecated

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/decorators_module_1.gif
  :alt: stouputils decorators examples

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/decorators_module_2.gif
  :alt: stouputils decorators examples
"""

# Imports
import os
import sys
import time
from enum import Enum
from pickle import dumps as pickle_dumps
from traceback import format_exc
from typing import Callable, Literal, Any
from functools import wraps
from .print import debug, warning, error


# Decorator that make a function silent (disable stdout)
[docs] def silent( func: Callable[..., Any], mute_stderr: bool = False ) -> Callable[..., Any]: """ Decorator that makes a function silent (disable stdout, and stderr if specified). Alternative to stouputils.ctx.Muffle. Args: func (Callable[..., Any]): Function to make silent mute_stderr (bool): Whether to mute stderr or not Examples: >>> @silent ... def test(): ... print("Hello, world!") >>> test() >>> silent(print)("Hello, world!") """ @wraps(func) def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: # Disable stdout and stderr _original_stdout: Any = sys.stdout _original_stderr: Any = None sys.stdout = open(os.devnull, "w", encoding="utf-8") if mute_stderr: _original_stderr = sys.stderr sys.stderr = open(os.devnull, "w", encoding="utf-8") # Call the function result: Any = func(*args, **kwargs) # Re-Enable stdout and stderr sys.stdout.close() sys.stdout = _original_stdout if mute_stderr: sys.stderr.close() sys.stderr = _original_stderr return result return wrapper
# Execution time decorator
[docs] def measure_time( print_func: Callable[..., None] = debug, message: str = "", perf_counter: bool = True ) -> Callable[..., Any]: """ Decorator that will measure the execution time of a function and print it with the given print function Args: print_func (Callable): Function to use to print the execution time (e.g. debug, info, warning, error, etc.) message (str): Message to display with the execution time (e.g. "Execution time of Something"), defaults to "Execution time of {func.__name__}" perf_counter (bool): Whether to use time.perf_counter_ns or time.time_ns Returns: Callable: Decorator to measure the time of the function. Examples: .. code-block:: python > @measure_time(info) > def test(): > pass > test() # [INFO HH:MM:SS] Execution time of test: 0.000ms (400ns) """ ns: Callable[[], int] = time.perf_counter_ns if perf_counter else time.time_ns def decorator(func: Callable[..., Any]) -> Callable[..., Any]: # Set the message if not specified nonlocal message if not message: message = f"Execution time of {func.__name__}" @wraps(func) def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: # Measure the execution time (nanoseconds and seconds) start_ns: int = ns() result = func(*args, **kwargs) total_ns: int = ns() - start_ns total_ms: float = total_ns / 1_000_000 total_s: float = total_ns / 1_000_000_000 # Print the execution time (nanoseconds if less than 0.3s, seconds otherwise) if total_ms < 300: print_func(f"{message}: {total_ms:.3f}ms ({total_ns}ns)") elif total_s < 60: print_func(f"{message}: {(total_s):.5f}s") else: minutes: int = int(total_s) // 60 seconds: int = int(total_s) % 60 if minutes < 60: print_func(f"{message}: {minutes}m {seconds}s") else: hours: int = minutes // 60 minutes: int = minutes % 60 if hours < 24: print_func(f"{message}: {hours}h {minutes}m {seconds}s") else: days: int = hours // 24 hours: int = hours % 24 print_func(f"{message}: {days}d {hours}h {minutes}m {seconds}s") return result return wrapper return decorator
# Decorator that handle an error with different log levels
[docs] class LogLevels(Enum): """ Log level for the errors in the decorator handle_error() """ NONE = 0 """ Do nothing """ WARNING = 1 """ Show as warning """ WARNING_TRACEBACK = 2 """ Show as warning with traceback """ ERROR_TRACEBACK = 3 """ Show as error with traceback """ RAISE_EXCEPTION = 4 """ Raise exception """
force_raise_exception: bool = False """ If true, the error_log parameter will be set to RAISE_EXCEPTION for every next handle_error calls, useful for doctests """
[docs] def handle_error( exceptions: tuple[type[BaseException], ...] | type[BaseException] = (Exception,), message: str = "", error_log: LogLevels = LogLevels.WARNING_TRACEBACK ) -> Callable[..., Any]: """ Decorator that handle an error with different log levels. Args: exceptions (tuple[type[BaseException]], ...): Exceptions to handle message (str): Message to display with the error. (e.g. "Error during something") error_log (LogLevels): Log level for the errors LogLevels.NONE: None LogLevels.WARNING: Show as warning LogLevels.WARNING_TRACEBACK: Show as warning with traceback LogLevels.ERROR_TRACEBACK: Show as error with traceback LogLevels.RAISE_EXCEPTION: Raise exception (as if the decorator didn't exist) Examples: .. code-block:: python > @handle_error(error_log=LogLevels.WARNING) > def test(): > raise ValueError("Let's fail") > test() # [WARNING HH:MM:SS] Error during test: (ValueError) Let's fail """ # Update error_log if needed if force_raise_exception: error_log = LogLevels.RAISE_EXCEPTION def decorator(func: Callable[..., Any]) -> Callable[..., Any]: if message != "": msg: str = f"{message}, " else: msg: str = message @wraps(func) def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: try: return func(*args, **kwargs) except exceptions as e: if error_log == LogLevels.WARNING: warning(f"{msg}Error during {func.__name__}: ({type(e).__name__}) {e}") elif error_log == LogLevels.WARNING_TRACEBACK: warning(f"{msg}Error during {func.__name__}:\n{format_exc()}") elif error_log == LogLevels.ERROR_TRACEBACK: error(f"{msg}Error during {func.__name__}:\n{format_exc()}", exit=True) elif error_log == LogLevels.RAISE_EXCEPTION: raise e return wrapper return decorator
# Easy cache function with parameter caching method
[docs] def simple_cache(method: Literal["str", "pickle"] = "str") -> Callable[..., Callable[..., Any]]: """ Decorator that caches the result of a function based on its arguments. The str method is often faster than the pickle method (by a little). Args: method (Literal["str", "pickle"]): The method to use for caching. Returns: Callable[..., Callable[..., Any]]: A decorator that caches the result of a function. Examples: >>> @simple_cache(method="str") ... def test(a: int, b: int) -> int: ... return a + b >>> test(1, 2) # 3 3 >>> test(1, 2) # 3 3 >>> test(3, 4) # 7 7 """ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: # Create the cache dict cache_dict: dict[bytes, Any] = {} # Create the wrapper @wraps(func) def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: # Get the hashed key if method == "str": hashed: bytes = str(args).encode() + str(kwargs).encode() elif method == "pickle": hashed: bytes = pickle_dumps((args, kwargs)) else: raise ValueError("Invalid caching method. Supported methods are 'str' and 'pickle'.") # If the key is in the cache, return it if hashed in cache_dict: return cache_dict[hashed] # Else, call the function and add the result to the cache else: result: Any = func(*args, **kwargs) cache_dict[hashed] = result return result # Return the wrapper return wrapper # Return the decorator return decorator
[docs] def deprecated( message: str = "", error_log: LogLevels = LogLevels.WARNING ) -> Callable[..., Any]: """ Decorator that marks a function as deprecated. Args: message (str): Additional message to display with the deprecation warning error_log (LogLevels): Log level for the deprecation warning LogLevels.NONE: None LogLevels.WARNING: Show as warning LogLevels.WARNING_TRACEBACK: Show as warning with traceback LogLevels.ERROR_TRACEBACK: Show as error with traceback LogLevels.RAISE_EXCEPTION: Raise exception Returns: Callable[..., Any]: Decorator that marks a function as deprecated Examples: .. code-block:: python > @deprecated(message="Use 'this_function()' instead", error_log=LogLevels.WARNING) > def test(): > pass > test() # [WARNING HH:MM:SS] Function 'test()' is deprecated. Use 'this_function()' instead """ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: # Build deprecation message msg: str = f"Function '{func.__name__}()' is deprecated" if message: msg += f". {message}" # Handle deprecation warning based on log level if error_log == LogLevels.WARNING: warning(msg) elif error_log == LogLevels.WARNING_TRACEBACK: warning(f"{msg}\n{format_exc()}") elif error_log == LogLevels.ERROR_TRACEBACK: error(f"{msg}\n{format_exc()}", exit=True) elif error_log == LogLevels.RAISE_EXCEPTION: raise DeprecationWarning(msg) # Call the original function return func(*args, **kwargs) return wrapper return decorator