"""
Async utility functions for running blocking operations
Global thread pool manager for the application
"""
import asyncio
import functools
import os
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Iterator, Optional

from src.common.logging_config import get_logger

logger = get_logger(__name__)


class ThreadPoolManager:
    """
    Global thread pool manager (singleton).

    Provides three kinds of thread pools:
    - io: IO-bound operations (network requests, file reads/writes)
    - cpu: CPU-bound operations
    - parser: parsing tasks (image/PDF/OCR/QA generation)
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    # Thread pool configuration. Each value is passed verbatim as keyword
    # arguments to ThreadPoolExecutor.
    _pool_configs = {
        "io": {"max_workers": 20, "thread_name_prefix": "io_"},
        "cpu": {"max_workers": os.cpu_count() or 4, "thread_name_prefix": "cpu_"},
        "parser": {"max_workers": 8, "thread_name_prefix": "parser_"},
    }

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """
        Create the configured pools if they do not currently exist.

        ``__init__`` runs on every ``ThreadPoolManager()`` call, so the
        ``_initialized`` flag guards against rebuilding live pools. After
        ``shutdown_all`` resets the flag, the next call builds fresh pools.
        """
        if not ThreadPoolManager._initialized:
            with ThreadPoolManager._lock:
                if not ThreadPoolManager._initialized:
                    # Build into a local dict and publish it in one assignment
                    # so other threads never observe a partially filled mapping.
                    pools = {}
                    for pool_type, config in self._pool_configs.items():
                        pools[pool_type] = ThreadPoolExecutor(**config)
                        logger.info(f"Created thread pool: {pool_type} (max_workers={config['max_workers']})")
                    self._pools = pools
                    ThreadPoolManager._initialized = True

    @classmethod
    def get_pool(cls, pool_type: str = "io") -> ThreadPoolExecutor:
        """
        Return the thread pool of the given type.

        Args:
            pool_type: Pool type (io/cpu/parser)

        Returns:
            The ThreadPoolExecutor instance

        Raises:
            ValueError: If ``pool_type`` is not a known pool type
        """
        instance = cls()
        if pool_type not in instance._pools:
            raise ValueError(f"Unknown pool type: {pool_type}. Available: {list(instance._pools.keys())}")
        return instance._pools[pool_type]

    @classmethod
    def submit(cls, pool_type: str, fn: Callable, *args, **kwargs) -> Future:
        """
        Submit a task to the given thread pool.

        Args:
            pool_type: Pool type
            fn: Function to execute
            *args, **kwargs: Arguments forwarded to ``fn``

        Returns:
            Future for the submitted call
        """
        pool = cls.get_pool(pool_type)
        return pool.submit(fn, *args, **kwargs)

    @classmethod
    def map(cls, pool_type: str, fn: Callable, *iterables, timeout: Optional[float] = None) -> Iterator:
        """
        Map a function over the given iterables in parallel.

        Args:
            pool_type: Pool type
            fn: Function to execute
            *iterables: Argument iterables
            timeout: Timeout forwarded to ``ThreadPoolExecutor.map``

        Returns:
            Iterator over the results
        """
        pool = cls.get_pool(pool_type)
        return pool.map(fn, *iterables, timeout=timeout)

    @classmethod
    def shutdown_all(cls, wait: bool = True):
        """
        Shut down all thread pools.

        Pools are only shut down if they currently exist; calling this on a
        manager that was never initialized (or was already shut down) does
        not create pools first.

        Args:
            wait: Whether to wait for all tasks to finish
        """
        # Detach the pools while holding the lock so a concurrent
        # re-initialization builds a fresh mapping instead of sharing the one
        # being torn down. The (possibly blocking) shutdown calls run after
        # the lock has been released.
        with ThreadPoolManager._lock:
            instance = cls._instance
            if instance is not None and ThreadPoolManager._initialized:
                pools = instance._pools
                instance._pools = {}
            else:
                pools = {}
            ThreadPoolManager._initialized = False

        for pool_type, pool in pools.items():
            try:
                pool.shutdown(wait=wait)
                logger.info(f"Thread pool '{pool_type}' shutdown successfully")
            except Exception as e:
                # Best effort: keep shutting down the remaining pools.
                logger.error(f"Error shutting down pool '{pool_type}': {e}")

        logger.info("All thread pools have been shutdown")

    @classmethod
    def get_status(cls) -> dict:
        """
        Return status information for all thread pools.

        Returns:
            Dict mapping pool type to its ``max_workers`` and ``active`` flag
        """
        instance = cls()
        status = {}
        for pool_type, config in cls._pool_configs.items():
            pool = instance._pools.get(pool_type)
            if pool is not None:
                status[pool_type] = {
                    "max_workers": config["max_workers"],
                    # _shutdown is a private executor attribute; when it is
                    # absent the pool is reported as active.
                    "active": not getattr(pool, "_shutdown", False),
                }
        return status


# Backward-compatible global executor (uses the io thread pool)
def get_executor() -> ThreadPoolExecutor:
    """Return the global IO thread pool (backward compatible)."""
    return ThreadPoolManager.get_pool("io")


# Backward-compatible executor variable
executor = None


def _get_compat_executor():
    """
    Return the lazily initialized backward-compatible executor.

    The cached executor is refreshed when it is unset or has been shut down,
    so callers are not handed a pool that rejects new work.
    """
    global executor
    if executor is None or getattr(executor, "_shutdown", False):
        executor = ThreadPoolManager.get_pool("io")
    return executor


async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
    """
    Run a blocking function in a thread pool

    Args:
        func: Blocking function to run
        *args: Arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns:
        Result of the function
    """
    loop = asyncio.get_running_loop()
    pool = ThreadPoolManager.get_pool("io")
    # run_in_executor only forwards positional arguments, so bind everything
    # (including keyword arguments) into a single callable.
    call = functools.partial(func, *args, **kwargs)
    return await loop.run_in_executor(pool, call)


def close_executor():
    """
    Close all thread pools (backward compatible)
    """
    global executor
    ThreadPoolManager.shutdown_all()
    # Drop the cached compat executor so it is not reused after shutdown.
    executor = None