# Imports
import time
from collections.abc import Callable
from typing import Any
from ..typing import JsonDict
from .capturer import CaptureOutput
[docs]
class RemoteSubprocessError(RuntimeError):
""" Raised in the parent when the child raised an exception - contains the child's formatted traceback. """
def __init__(self, exc_type: str, exc_repr: str, traceback_str: str):
msg = f"Exception in subprocess ({exc_type}): {exc_repr}\n\nRemote traceback:\n{traceback_str}"
super().__init__(msg)
self.remote_type = exc_type
self.remote_repr = exc_repr
self.remote_traceback = traceback_str
[docs]
def run_in_subprocess[R](
func: Callable[..., R],
*args: Any,
timeout: float | None = None,
no_join: bool = False,
capture_output: bool = False,
**kwargs: Any
) -> R:
""" Execute a function in a subprocess with positional and keyword arguments.
This is useful when you need to run a function in isolation to avoid memory leaks,
resource conflicts, or to ensure a clean execution environment. The subprocess will
be created, run the function with the provided arguments, and return the result.
Args:
func (Callable): The function to execute in a subprocess.
(SHOULD BE A TOP-LEVEL FUNCTION TO BE PICKLABLE)
*args (Any): Positional arguments to pass to the function.
timeout (float | None): Maximum time in seconds to wait for the subprocess.
If None, wait indefinitely. If the subprocess exceeds this time, it will be terminated.
no_join (bool): If True, do not wait for the subprocess to finish (fire-and-forget).
capture_output (bool): If True, capture the subprocess' stdout/stderr and relay it
in real time to the parent's stdout. This enables seeing print() output
from the subprocess in the main process.
**kwargs (Any): Keyword arguments to pass to the function.
Returns:
R: The return value of the function.
Raises:
RemoteSubprocessError: If the child raised an exception - contains the child's formatted traceback.
RuntimeError: If the subprocess exits with a non-zero exit code or did not return a result.
TimeoutError: If the subprocess exceeds the specified timeout.
Examples:
.. code-block:: python
> # Simple function execution
> run_in_subprocess(doctest_square, 5)
25
> # Function with multiple arguments
> def add(a: int, b: int) -> int:
. return a + b
> run_in_subprocess(add, 10, 20)
30
> # Function with keyword arguments
> def greet(name: str, greeting: str = "Hello") -> str:
. return f"{greeting}, {name}!"
> run_in_subprocess(greet, "World", greeting="Hi")
'Hi, World!'
> # With timeout to prevent hanging
> run_in_subprocess(some_gpu_func, data, timeout=300.0)
"""
import multiprocessing as mp
from multiprocessing import Queue
# Create a queue to get the result from the subprocess (only if we need to wait)
result_queue: Queue[JsonDict] | None = None if no_join else Queue()
# Optionally setup output capture pipe and listener
capturer: CaptureOutput | None = None
if capture_output:
capturer = CaptureOutput()
# Create and start the subprocess using the module-level wrapper
process: mp.Process = mp.Process(
target=_subprocess_wrapper,
args=(result_queue, func, args, kwargs),
kwargs={"_capturer": capturer}
)
process.start()
# For capture_output we must close the parent's copy of the write fd and start listener
if capturer is not None:
capturer.parent_close_write()
capturer.start_listener()
# Detach process if no_join (fire-and-forget)
if result_queue is None:
# If capturing, leave listener running in background (daemon)
return None # type: ignore
# Use a single try/finally to ensure we always drain the listener once
# and avoid repeating join calls in multiple branches.
try:
process.join(timeout=timeout)
# Check if process is still alive (timed out)
if process.is_alive():
process.terminate()
time.sleep(0.5) # Give it a moment to terminate gracefully
if process.is_alive():
process.kill()
process.join()
raise TimeoutError(f"Subprocess exceeded timeout of {timeout} seconds and was terminated")
# Retrieve the payload if present
result_payload: JsonDict | None = result_queue.get_nowait() if not result_queue.empty() else None
# If the child sent a structured exception, raise it with the formatted traceback
if isinstance(result_payload, dict):
if result_payload.pop("ok", False) is False:
raise RemoteSubprocessError(**result_payload)
else:
return result_payload["result"]
# Raise an error according to the exit code presence
if process.exitcode != 0:
raise RuntimeError(f"Subprocess failed with exit code {process.exitcode}")
raise RuntimeError("Subprocess did not return any result")
# Finally, ensure we drain/join the listener if capturing output
finally:
if capturer is not None:
capturer.join_listener(timeout=5.0)
# "Private" function for subprocess wrapper (must be at module level for pickling on Windows)
[docs]
def _subprocess_wrapper[R](
result_queue: Any,
func: Callable[..., R],
args: tuple[Any, ...],
kwargs: dict[str, Any],
_capturer: CaptureOutput | None = None
) -> None:
""" Wrapper function to execute the target function and store the result in the queue.
Must be at module level to be pickable on Windows (spawn context).
Args:
result_queue (multiprocessing.Queue | None): Queue to store the result or exception (None if detached).
func (Callable): The target function to execute.
args (tuple): Positional arguments for the function.
kwargs (dict): Keyword arguments for the function.
_capturer (CaptureOutput | None): Optional CaptureOutput instance for stdout capture.
"""
try:
# If a CaptureOutput instance was passed, redirect stdout/stderr to the pipe.
if _capturer is not None:
_capturer.redirect()
# Execute the target function and put the result in the queue
result: R = func(*args, **kwargs)
if result_queue is not None:
result_queue.put({"ok": True, "result": result})
# Handle cleanup and exceptions
except Exception as e:
if result_queue is not None:
try:
import traceback
tb = traceback.format_exc()
result_queue.put({
"ok": False,
"exc_type": e.__class__.__name__,
"exc_repr": repr(e),
"traceback_str": tb,
})
except Exception:
# Nothing we can do if even this fails
pass