|
|
@@ -1,13 +1,158 @@
|
|
|
"""
|
|
|
Async utility functions for running blocking operations
|
|
|
+Global thread pool manager for the application
|
|
|
"""
|
|
|
import asyncio
|
|
|
-from concurrent.futures import ThreadPoolExecutor
|
|
|
-from typing import Any, Callable
|
|
|
+import threading
|
|
|
+import os
|
|
|
+from concurrent.futures import ThreadPoolExecutor, Future
|
|
|
+from typing import Any, Callable, Iterator, Optional
|
|
|
+from src.common.logging_config import get_logger
|
|
|
|
|
|
+logger = get_logger(__name__)
|
|
|
|
|
|
-# Create a thread pool executor
|
|
|
-executor = ThreadPoolExecutor(max_workers=10)
|
|
|
+
|
|
|
+class ThreadPoolManager:
|
|
|
+ """
|
|
|
+ 全局线程池管理器(单例模式)
|
|
|
+
|
|
|
+ 提供三种类型的线程池:
|
|
|
+ - io: IO密集型操作(网络请求、文件读写)
|
|
|
+ - cpu: CPU密集型操作
|
|
|
+ - parser: 解析任务(图像/PDF/OCR/QA生成)
|
|
|
+ """
|
|
|
+
|
|
|
+ _instance = None
|
|
|
+ _lock = threading.Lock()
|
|
|
+ _initialized = False
|
|
|
+
|
|
|
+ # 线程池配置
|
|
|
+ _pool_configs = {
|
|
|
+ "io": {"max_workers": 20, "thread_name_prefix": "io_"},
|
|
|
+ "cpu": {"max_workers": os.cpu_count() or 4, "thread_name_prefix": "cpu_"},
|
|
|
+ "parser": {"max_workers": 8, "thread_name_prefix": "parser_"}
|
|
|
+ }
|
|
|
+
|
|
|
+ def __new__(cls):
|
|
|
+ if cls._instance is None:
|
|
|
+ with cls._lock:
|
|
|
+ if cls._instance is None:
|
|
|
+ cls._instance = super().__new__(cls)
|
|
|
+ return cls._instance
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ if not ThreadPoolManager._initialized:
|
|
|
+ with ThreadPoolManager._lock:
|
|
|
+ if not ThreadPoolManager._initialized:
|
|
|
+ self._pools = {}
|
|
|
+ for pool_type, config in self._pool_configs.items():
|
|
|
+ self._pools[pool_type] = ThreadPoolExecutor(**config)
|
|
|
+ logger.info(f"Created thread pool: {pool_type} (max_workers={config['max_workers']})")
|
|
|
+ ThreadPoolManager._initialized = True
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_pool(cls, pool_type: str = "io") -> ThreadPoolExecutor:
|
|
|
+ """
|
|
|
+ 获取指定类型的线程池
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pool_type: 线程池类型 (io/cpu/parser)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ ThreadPoolExecutor 实例
|
|
|
+ """
|
|
|
+ instance = cls()
|
|
|
+ if pool_type not in instance._pools:
|
|
|
+ raise ValueError(f"Unknown pool type: {pool_type}. Available: {list(instance._pools.keys())}")
|
|
|
+ return instance._pools[pool_type]
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def submit(cls, pool_type: str, fn: Callable, *args, **kwargs) -> Future:
|
|
|
+ """
|
|
|
+ 提交任务到指定线程池
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pool_type: 线程池类型
|
|
|
+ fn: 要执行的函数
|
|
|
+ *args, **kwargs: 函数参数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Future 对象
|
|
|
+ """
|
|
|
+ pool = cls.get_pool(pool_type)
|
|
|
+ return pool.submit(fn, *args, **kwargs)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def map(cls, pool_type: str, fn: Callable, *iterables, timeout: Optional[float] = None) -> Iterator:
|
|
|
+ """
|
|
|
+ 并行映射函数到多个参数
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pool_type: 线程池类型
|
|
|
+ fn: 要执行的函数
|
|
|
+ *iterables: 参数迭代器
|
|
|
+ timeout: 超时时间
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 结果迭代器
|
|
|
+ """
|
|
|
+ pool = cls.get_pool(pool_type)
|
|
|
+ return pool.map(fn, *iterables, timeout=timeout)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def shutdown_all(cls, wait: bool = True):
|
|
|
+ """
|
|
|
+ 关闭所有线程池
|
|
|
+
|
|
|
+ Args:
|
|
|
+ wait: 是否等待所有任务完成
|
|
|
+ """
|
|
|
+ instance = cls()
|
|
|
+ for pool_type, pool in instance._pools.items():
|
|
|
+ try:
|
|
|
+ pool.shutdown(wait=wait)
|
|
|
+ logger.info(f"Thread pool '{pool_type}' shutdown successfully")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error shutting down pool '{pool_type}': {e}")
|
|
|
+ instance._pools.clear()
|
|
|
+ ThreadPoolManager._initialized = False
|
|
|
+ logger.info("All thread pools have been shutdown")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_status(cls) -> dict:
|
|
|
+ """
|
|
|
+ 获取所有线程池的状态信息
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 状态字典
|
|
|
+ """
|
|
|
+ instance = cls()
|
|
|
+ status = {}
|
|
|
+ for pool_type, config in cls._pool_configs.items():
|
|
|
+ pool = instance._pools.get(pool_type)
|
|
|
+ if pool:
|
|
|
+ status[pool_type] = {
|
|
|
+ "max_workers": config["max_workers"],
|
|
|
+ "active": not pool._shutdown if hasattr(pool, '_shutdown') else True
|
|
|
+ }
|
|
|
+ return status
|
|
|
+
|
|
|
+
|
|
|
+# 保持向后兼容的全局执行器(使用 io 线程池)
|
|
|
+def get_executor() -> ThreadPoolExecutor:
|
|
|
+ """获取全局 IO 线程池(向后兼容)"""
|
|
|
+ return ThreadPoolManager.get_pool("io")
|
|
|
+
|
|
|
+
|
|
|
+# 向后兼容的 executor 变量
|
|
|
+executor = None
|
|
|
+
|
|
|
+def _get_compat_executor():
|
|
|
+ """延迟初始化的兼容性执行器"""
|
|
|
+ global executor
|
|
|
+ if executor is None:
|
|
|
+ executor = ThreadPoolManager.get_pool("io")
|
|
|
+ return executor
|
|
|
|
|
|
|
|
|
async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
|
|
|
@@ -23,11 +168,12 @@ async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
|
|
|
Result of the function
|
|
|
"""
|
|
|
loop = asyncio.get_event_loop()
|
|
|
- return await loop.run_in_executor(executor, func, *args, **kwargs)
|
|
|
+ pool = ThreadPoolManager.get_pool("io")
|
|
|
+ return await loop.run_in_executor(pool, func, *args)
|
|
|
|
|
|
|
|
|
def close_executor():
|
|
|
"""
|
|
|
- Close the thread pool executor
|
|
|
+ Close all thread pools (向后兼容)
|
|
|
"""
|
|
|
- executor.shutdown(wait=True)
|
|
|
+ ThreadPoolManager.shutdown_all()
|