| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- """
- Async utility functions for running blocking operations
- Global thread pool manager for the application
- """
import asyncio
import functools
import os
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Iterator, Optional

from src.common.logging_config import get_logger
# Module-level logger obtained from the project's logging configuration helper.
logger = get_logger(__name__)
class ThreadPoolManager:
    """
    Global thread pool manager (singleton).

    Provides three kinds of thread pools:
    - io: IO-bound work (network requests, file reads/writes)
    - cpu: CPU-bound work
    - parser: parsing tasks (image/PDF/OCR/QA generation)

    Pools are created on first instantiation. ``shutdown_all`` tears them down,
    and the next instantiation (e.g. via ``get_pool``) creates fresh ones.
    """

    _instance = None
    # Guards singleton creation, pool (re)initialisation and shutdown.
    _lock = threading.Lock()
    # True while the instance's ``_pools`` holds live executors.
    _initialized = False

    # Keyword arguments passed to ThreadPoolExecutor for each pool type.
    _pool_configs = {
        "io": {"max_workers": 20, "thread_name_prefix": "io_"},
        "cpu": {"max_workers": os.cpu_count() or 4, "thread_name_prefix": "cpu_"},
        "parser": {"max_workers": 8, "thread_name_prefix": "parser_"},
    }

    def __new__(cls):
        # Check once without the lock (fast path), then again under it so only
        # one thread ever creates the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ThreadPoolManager() call; only build the pools
        # when they are not live (first use, or after shutdown_all).
        if not ThreadPoolManager._initialized:
            with ThreadPoolManager._lock:
                if not ThreadPoolManager._initialized:
                    self._pools = {}
                    for pool_type, config in self._pool_configs.items():
                        self._pools[pool_type] = ThreadPoolExecutor(**config)
                        logger.info(f"Created thread pool: {pool_type} (max_workers={config['max_workers']})")
                    ThreadPoolManager._initialized = True

    @classmethod
    def get_pool(cls, pool_type: str = "io") -> ThreadPoolExecutor:
        """
        Return the thread pool of the given type, creating pools if needed.

        Args:
            pool_type: Pool type (io/cpu/parser)

        Returns:
            The ThreadPoolExecutor instance for ``pool_type``.

        Raises:
            ValueError: If ``pool_type`` is not a configured pool type.
        """
        instance = cls()
        if pool_type not in instance._pools:
            raise ValueError(f"Unknown pool type: {pool_type}. Available: {list(instance._pools.keys())}")
        return instance._pools[pool_type]

    @classmethod
    def submit(cls, pool_type: str, fn: Callable, *args, **kwargs) -> Future:
        """
        Submit a task to the given thread pool.

        Args:
            pool_type: Pool type (io/cpu/parser)
            fn: Function to execute
            *args, **kwargs: Arguments forwarded to ``fn``

        Returns:
            The Future for the submitted call.

        Raises:
            ValueError: If ``pool_type`` is not a configured pool type.
        """
        pool = cls.get_pool(pool_type)
        return pool.submit(fn, *args, **kwargs)

    @classmethod
    def map(cls, pool_type: str, fn: Callable, *iterables, timeout: Optional[float] = None) -> Iterator:
        """
        Apply ``fn`` over the argument iterables in parallel on the given pool.

        Args:
            pool_type: Pool type (io/cpu/parser)
            fn: Function to execute
            *iterables: Argument iterables, as for the built-in ``map``
            timeout: Forwarded to ``ThreadPoolExecutor.map``

        Returns:
            Iterator over the results, in input order.

        Raises:
            ValueError: If ``pool_type`` is not a configured pool type.
        """
        pool = cls.get_pool(pool_type)
        return pool.map(fn, *iterables, timeout=timeout)

    @classmethod
    def shutdown_all(cls, wait: bool = True):
        """
        Shut down all thread pools.

        Does nothing to the pools if they are not currently live. The manager
        state is reset under the lock, so a later ``get_pool`` call builds fresh
        pools. The actual ``shutdown`` calls run outside the lock so that
        ``wait=True`` does not block other threads from re-initialising.

        Args:
            wait: Whether to wait for pending tasks to finish
        """
        with cls._lock:
            instance = cls._instance
            if instance is None or not ThreadPoolManager._initialized:
                # Never initialised (or already shut down): avoid creating
                # pools just to tear them down again.
                pools = {}
            else:
                pools = dict(instance._pools)
                instance._pools.clear()
                ThreadPoolManager._initialized = False

        for pool_type, pool in pools.items():
            try:
                pool.shutdown(wait=wait)
                logger.info(f"Thread pool '{pool_type}' shutdown successfully")
            except Exception as e:
                logger.error(f"Error shutting down pool '{pool_type}': {e}")
        logger.info("All thread pools have been shutdown")

    @classmethod
    def get_status(cls) -> dict:
        """
        Return status information for every configured thread pool.

        Note: this goes through ``cls()``, so it creates the pools if they are
        not currently initialised.

        Returns:
            Mapping of pool type -> {"max_workers": int, "active": bool}.
        """
        instance = cls()
        status = {}
        for pool_type, config in cls._pool_configs.items():
            pool = instance._pools.get(pool_type)
            if pool is not None:
                status[pool_type] = {
                    "max_workers": config["max_workers"],
                    # ``_shutdown`` is a private ThreadPoolExecutor attribute;
                    # treat the pool as active if it is missing.
                    "active": not getattr(pool, "_shutdown", False),
                }
        return status


# Global executor accessor kept for backward compatibility (backed by the io pool).
def get_executor() -> ThreadPoolExecutor:
    """
    Return the shared IO thread pool.

    Retained for callers written before ThreadPoolManager existed. It simply
    delegates to ``ThreadPoolManager.get_pool("io")``.
    """
    io_pool = ThreadPoolManager.get_pool("io")
    return io_pool


# Backward-compatible module-level executor; populated by _get_compat_executor().
executor = None


def _get_compat_executor():
    """
    Return the IO thread pool and mirror it into the module-level ``executor``.

    The pool is re-fetched from ThreadPoolManager on every call rather than
    cached once. After ``ThreadPoolManager.shutdown_all()`` the manager builds
    fresh pools on next access. A one-time cache would keep handing out the
    shut-down executor, which rejects new submissions.

    Returns:
        The current ``"io"`` ThreadPoolExecutor.
    """
    global executor
    executor = ThreadPoolManager.get_pool("io")
    return executor


async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
    """
    Run a blocking function in the shared "io" thread pool and await its result.

    Args:
        func: Blocking function to run
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns:
        Result of the function

    Raises:
        Any exception raised by ``func`` is re-raised in the awaiting coroutine.
    """
    # Inside a coroutine get_running_loop() is the preferred, unambiguous way
    # to obtain the current loop.
    loop = asyncio.get_running_loop()
    pool = ThreadPoolManager.get_pool("io")
    # run_in_executor only forwards positional arguments, so bind everything
    # (including kwargs, which were previously dropped) with functools.partial.
    call = functools.partial(func, *args, **kwargs)
    return await loop.run_in_executor(pool, call)


def close_executor(wait: bool = True):
    """
    Close all thread pools (kept for backward compatibility).

    Also clears the cached module-level ``executor`` so it no longer references
    a pool that has just been shut down.

    Args:
        wait: Forwarded to ``ThreadPoolManager.shutdown_all``; whether to wait
            for pending tasks to finish. Defaults to True, as before.
    """
    global executor
    ThreadPoolManager.shutdown_all(wait=wait)
    executor = None
|