Source code for stouputils.decorators

"""
This module provides decorators for various purposes:

- silent(): Make a function silent (disable stdout, and stderr if specified) (alternative to stouputils.ctx.Muffle)
- measure_time(): Measure the execution time of a function and print it with the given print function
- handle_error(): Handle an error with different log levels
- simple_cache(): Easy cache function with parameter caching method
- deprecated(): Mark a function as deprecated
- abstract(): Mark a function as abstract (raises NotImplementedError when called, by default)

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/decorators_module_1.gif
  :alt: stouputils decorators examples

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/decorators_module_2.gif
  :alt: stouputils decorators examples
"""

# Imports
from collections.abc import Callable
from enum import Enum
from functools import wraps
from pickle import dumps as pickle_dumps
from traceback import format_exc
from typing import Any, Literal
import time

from .print import debug, error, warning
from .ctx import Muffle, MeasureTime


def get_func_name(func: Callable[..., Any]) -> str:
	""" Return the ``__name__`` of a callable, falling back to ``"<unknown>"``.

	Args:
		func (Callable[..., Any]): Callable whose name is wanted
	Returns:
		str: The ``__name__`` attribute of the callable, or ``"<unknown>"`` when it has none
	"""
	# getattr with a default absorbs the AttributeError raised by callables without __name__
	return getattr(func, "__name__", "<unknown>")

def get_wrapper_name(decorator_name: str, func: Callable[..., Any]) -> str:
	""" Build the name assigned to a wrapper function created by a decorator.

	Args:
		decorator_name (str): Name of the decorator creating the wrapper
		func (Callable[..., Any]): Function being decorated
	Returns:
		str: The decorator name and the function name, joined by an underscore
	"""
	func_name: str = get_func_name(func)
	return f"{decorator_name}_{func_name}"
# Decorator that makes a function silent (disables stdout, and optionally stderr)
def silent(
	func: Callable[..., Any] | None = None,
	*,
	mute_stderr: bool = False
) -> Callable[..., Any]:
	""" Decorator that runs a function with its output muffled (stdout, and stderr if requested).

	Alternative to the stouputils.ctx.Muffle context manager, which it uses internally.
	Usable bare (``@silent``), with arguments (``@silent(mute_stderr=True)``),
	or called directly on a function (``silent(print)``).

	Args:
		func (Callable[..., Any] | None): Function to make silent (None when used with arguments)
		mute_stderr (bool): Whether to mute stderr or not
	Returns:
		Callable[..., Any]: The silenced function, or a decorator when ``func`` is None

	Examples:
		>>> @silent
		... def test():
		...     print("Hello, world!")
		>>> test()

		>>> @silent(mute_stderr=True)
		... def test2():
		...     print("Hello, world!")
		>>> test2()

		>>> silent(print)("Hello, world!")
	"""
	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			# A fresh Muffle is created for every call and is active only while func runs
			muffler = Muffle(mute_stderr=mute_stderr)
			with muffler:
				return func(*args, **kwargs)

		# Replace the name copied by wraps() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.silent", func)
		return wrapper

	# Support both @silent (func given) and @silent(mute_stderr=...) (func is None)
	return decorator if func is None else decorator(func)
# Execution time decorator
def measure_time(
	print_func: Callable[..., None] = debug,
	message: str = "",
	perf_counter: bool = True
) -> Callable[..., Any]:
	""" Decorator factory that times each call of a function and reports it through a print function.

	The timing and reporting are delegated to the MeasureTime context manager.

	Args:
		print_func (Callable): Function used to print the execution time (e.g. debug, info, warning, error, etc.)
		message (str): Message displayed with the execution time (e.g. "Execution time of Something"),
			defaults to "Execution time of {func.__name__}"
		perf_counter (bool): Whether to use time.perf_counter_ns or time.time_ns
	Returns:
		Callable: Decorator to measure the time of the function.

	Examples:
		.. code-block:: python

			> @measure_time(info)
			> def test():
			>     pass
			> test()  # [INFO HH:MM:SS] Execution time of test: 0.000ms (400ns)
	"""
	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
		# An empty message falls back to a default that mentions the function name
		label: str = message or f"Execution time of {get_func_name(func)}"

		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			timer = MeasureTime(print_func=print_func, message=label, perf_counter=perf_counter)
			with timer:
				return func(*args, **kwargs)

		# Replace the name copied by wraps() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.measure_time", func)
		return wrapper

	return decorator
# Decorator that handles an error with different log levels
class LogLevels(Enum):
	""" Log levels accepted by the ``error_log`` parameter of the decorators in this module (e.g. handle_error()) """

	NONE = 0
	""" Do nothing """

	WARNING = 1
	""" Show as warning """

	WARNING_TRACEBACK = 2
	""" Show as warning with traceback """

	ERROR_TRACEBACK = 3
	""" Show as error with traceback """

	RAISE_EXCEPTION = 4
	""" Raise exception """
# Module-level switch read by handle_error() each time it is invoked to decorate a function:
# while it is true, the requested error_log is replaced by LogLevels.RAISE_EXCEPTION.
# Wrappers created before the flag was changed keep the error_log they were built with.
force_raise_exception: bool = False
""" If true, error_log parameter will be set to RAISE_EXCEPTION for every next handle_error calls, useful for doctests """
def handle_error(
	func: Callable[..., Any] | None = None,
	*,
	exceptions: tuple[type[BaseException], ...] | type[BaseException] = (Exception,),
	message: str = "",
	error_log: LogLevels = LogLevels.WARNING_TRACEBACK,
	sleep_time: float = 0.0
) -> Callable[..., Any]:
	""" Decorator that handles an error with different log levels.

	When the decorated function raises one of the handled exceptions, the error is reported
	according to ``error_log``. Unless the exception is re-raised, the wrapper then returns None.

	Args:
		func (Callable[..., Any] | None): Function to decorate (None when used with arguments)
		exceptions (tuple[type[BaseException], ...] | type[BaseException]): Exceptions to handle
		message (str): Message to display with the error. (e.g. "Error during something")
		error_log (LogLevels): Log level for the errors
			LogLevels.NONE: None
			LogLevels.WARNING: Show as warning
			LogLevels.WARNING_TRACEBACK: Show as warning with traceback
			LogLevels.ERROR_TRACEBACK: Show as error with traceback
			LogLevels.RAISE_EXCEPTION: Raise exception
		sleep_time (float): Time to sleep after the error (e.g. 0.0 to not sleep, 1.0 to sleep for 1 second)
	Returns:
		Callable[..., Any]: The wrapped function, or a decorator when ``func`` is None

	Examples:
		>>> @handle_error
		... def might_fail():
		...     raise ValueError("Let's fail")

		>>> @handle_error(error_log=LogLevels.WARNING)
		... def test():
		...     raise ValueError("Let's fail")
		>>> # test()  # [WARNING HH:MM:SS] Error during test: (ValueError) Let's fail
	"""
	# The module-level override is read here, i.e. when the decorator is created (not on each call)
	if force_raise_exception:
		error_log = LogLevels.RAISE_EXCEPTION

	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
		# Prefix put in front of every log line; empty when no custom message is given
		msg: str = f"{message}, " if message != "" else message

		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			try:
				return func(*args, **kwargs)
			except exceptions as e:
				if error_log == LogLevels.WARNING:
					warning(f"{msg}Error during {get_func_name(func)}: ({type(e).__name__}) {e}")
				elif error_log == LogLevels.WARNING_TRACEBACK:
					warning(f"{msg}Error during {get_func_name(func)}:\n{format_exc()}")
				elif error_log == LogLevels.ERROR_TRACEBACK:
					error(f"{msg}Error during {get_func_name(func)}:\n{format_exc()}", exit=True)
				elif error_log == LogLevels.RAISE_EXCEPTION:
					# Bare raise re-raises the active exception without adding this line to its traceback
					raise

				# Sleep for the specified time, only if the error_log is not ERROR_TRACEBACK (because it's blocking)
				if sleep_time > 0.0 and error_log != LogLevels.ERROR_TRACEBACK:
					time.sleep(sleep_time)

		# Replace the name copied by wraps() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.handle_error", func)
		return wrapper

	# Handle both @handle_error and @handle_error(exceptions=..., message=..., error_log=...)
	if func is None:
		return decorator
	return decorator(func)
# Easy cache function with parameter caching method
def simple_cache(
	func: Callable[..., Any] | None = None,
	*,
	method: Literal["str", "pickle"] = "str"
) -> Callable[..., Any]:
	""" Decorator that memoizes the results of a function, keyed on its call arguments.

	The key is built either from the string form of the arguments ("str") or from their
	pickled form ("pickle"). The str method is often faster than the pickle method (by a little).

	Args:
		func (Callable[..., Any] | None): Function to cache (None when used with arguments)
		method (Literal["str", "pickle"]): The method to use for caching.
	Returns:
		Callable[..., Any]: A decorator that caches the result of a function.
	Raises:
		ValueError: When the wrapped function is called and ``method`` is neither "str" nor "pickle"

	Examples:
		>>> @simple_cache
		... def test1(a: int, b: int) -> int:
		...     return a + b

		>>> @simple_cache(method="str")
		... def test2(a: int, b: int) -> int:
		...     return a + b
		>>> test2(1, 2)
		3
		>>> test2(1, 2)
		3
		>>> test2(3, 4)
		7
	"""
	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
		# One result store per decorated function, keyed on the serialized arguments
		results: dict[bytes, Any] = {}

		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			# Serialize the call arguments into the lookup key
			if method == "str":
				positional: bytes = str(args).encode()
				named: bytes = str(kwargs).encode()
				key: bytes = positional + named
			elif method == "pickle":
				key = pickle_dumps((args, kwargs))
			else:
				raise ValueError("Invalid caching method. Supported methods are 'str' and 'pickle'.")

			# Cache hit: hand back the stored result
			if key in results:
				return results[key]

			# Cache miss: compute, remember, return
			value: Any = func(*args, **kwargs)
			results[key] = value
			return value

		# Replace the name copied by wraps() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.simple_cache", func)
		return wrapper

	# Support both @simple_cache (func given) and @simple_cache(method=...) (func is None)
	return decorator if func is None else decorator(func)
def deprecated(
	func: Callable[..., Any] | None = None,
	*,
	message: str = "",
	error_log: LogLevels = LogLevels.WARNING
) -> Callable[..., Any]:
	""" Decorator that marks a function as deprecated.

	Every call of the decorated function first reports the deprecation according to
	``error_log``, then (unless an exception was raised) calls the original function.

	Args:
		func (Callable[..., Any] | None): Function to mark as deprecated (None when used with arguments)
		message (str): Additional message to display with the deprecation warning
		error_log (LogLevels): Log level for the deprecation warning
			LogLevels.NONE: None
			LogLevels.WARNING: Show as warning
			LogLevels.WARNING_TRACEBACK: Show as warning with the call stack of the deprecated call
			LogLevels.ERROR_TRACEBACK: Show as error with the call stack of the deprecated call
			LogLevels.RAISE_EXCEPTION: Raise exception
	Returns:
		Callable[..., Any]: Decorator that marks a function as deprecated
	Raises:
		DeprecationWarning: On call of the decorated function, when error_log is LogLevels.RAISE_EXCEPTION

	Examples:
		>>> @deprecated
		... def old_function():
		...     pass

		>>> @deprecated(message="Use 'new_function()' instead", error_log=LogLevels.WARNING)
		... def another_old_function():
		...     pass
	"""
	# Local import: keeps this block self-contained without touching the module imports
	from traceback import format_stack

	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			# Build deprecation message
			msg: str = f"Function '{get_func_name(func)}()' is deprecated"
			if message:
				msg += f". {message}"

			# Handle deprecation warning based on log level.
			# No exception is being handled at this point, so format_exc() would only produce
			# "NoneType: None" (or an unrelated in-flight exception). The useful "traceback" is the
			# call stack leading to the deprecated call; [:-1] drops this wrapper's own frame.
			if error_log == LogLevels.WARNING:
				warning(msg)
			elif error_log == LogLevels.WARNING_TRACEBACK:
				stack: str = "".join(format_stack()[:-1])
				warning(f"{msg}\n{stack}")
			elif error_log == LogLevels.ERROR_TRACEBACK:
				stack = "".join(format_stack()[:-1])
				error(f"{msg}\n{stack}", exit=True)
			elif error_log == LogLevels.RAISE_EXCEPTION:
				raise DeprecationWarning(msg)

			# Call the original function
			return func(*args, **kwargs)

		# Replace the name copied by wraps() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.deprecated", func)
		return wrapper

	# Handle both @deprecated and @deprecated(message=..., error_log=...)
	if func is None:
		return decorator
	return decorator(func)
def abstract(
	func: Callable[..., Any] | None = None,
	*,
	error_log: LogLevels = LogLevels.RAISE_EXCEPTION
) -> Callable[..., Any]:
	""" Decorator that marks a function as abstract.

	Contrary to the abstractmethod decorator from the abc module that raises a TypeError
	when you try to instantiate a class that has abstract methods, this decorator raises
	a NotImplementedError ONLY when the decorated function is called, indicating that the function
	must be implemented by a subclass.

	The NotImplementedError goes through handle_error(), so with an ``error_log`` other than
	LogLevels.RAISE_EXCEPTION it is reported instead of propagated.
	The body of the decorated function is never executed.
	If the decorated function has no docstring, the error message is assigned as its docstring.

	Args:
		func (Callable[..., Any] | None): The function to mark as abstract (None when used with arguments)
		error_log (LogLevels): Log level for the error handling
			LogLevels.NONE: None
			LogLevels.WARNING: Show as warning
			LogLevels.WARNING_TRACEBACK: Show as warning with traceback
			LogLevels.ERROR_TRACEBACK: Show as error with traceback
			LogLevels.RAISE_EXCEPTION: Raise exception
	Returns:
		Callable[..., Any]: Decorator that raises NotImplementedError when called

	Examples:
		>>> class Base:
		...     @abstract
		...     def method(self):
		...         pass
		>>> Base().method()
		Traceback (most recent call last):
			...
		NotImplementedError: Function 'method' is abstract and must be implemented by a subclass
	"""
	def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
		message: str = f"Function '{get_func_name(func)}' is abstract and must be implemented by a subclass"

		# Give undocumented abstract functions a docstring (wraps() below copies it to the wrapper)
		if not func.__doc__:
			func.__doc__ = message

		# wraps() is applied first (innermost) so that handle_error() sees the abstract function's
		# name and logs "Error during <function name>" instead of "Error during wrapper".
		@handle_error(exceptions=NotImplementedError, error_log=error_log)
		@wraps(func)
		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
			raise NotImplementedError(message)

		# Replace the name set by handle_error() with a descriptive wrapper name
		wrapper.__name__ = get_wrapper_name("stouputils.decorators.abstract", func)
		return wrapper

	# Handle both @abstract and @abstract(error_log=...)
	if func is None:
		return decorator
	return decorator(func)