async_utils.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. """
Async utility functions for running blocking operations in thread pools.
Also provides the global thread pool manager for the application.
  4. """
import asyncio
import functools
import os
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Iterator, Optional

from src.common.logging_config import get_logger
  11. logger = get_logger(__name__)
  12. class ThreadPoolManager:
  13. """
  14. 全局线程池管理器(单例模式)
  15. 提供三种类型的线程池:
  16. - io: IO密集型操作(网络请求、文件读写)
  17. - cpu: CPU密集型操作
  18. - parser: 解析任务(图像/PDF/OCR/QA生成)
  19. """
  20. _instance = None
  21. _lock = threading.Lock()
  22. _initialized = False
  23. # 线程池配置
  24. _pool_configs = {
  25. "io": {"max_workers": 20, "thread_name_prefix": "io_"},
  26. "cpu": {"max_workers": os.cpu_count() or 4, "thread_name_prefix": "cpu_"},
  27. "parser": {"max_workers": 8, "thread_name_prefix": "parser_"}
  28. }
  29. def __new__(cls):
  30. if cls._instance is None:
  31. with cls._lock:
  32. if cls._instance is None:
  33. cls._instance = super().__new__(cls)
  34. return cls._instance
  35. def __init__(self):
  36. if not ThreadPoolManager._initialized:
  37. with ThreadPoolManager._lock:
  38. if not ThreadPoolManager._initialized:
  39. self._pools = {}
  40. for pool_type, config in self._pool_configs.items():
  41. self._pools[pool_type] = ThreadPoolExecutor(**config)
  42. logger.info(f"Created thread pool: {pool_type} (max_workers={config['max_workers']})")
  43. ThreadPoolManager._initialized = True
  44. @classmethod
  45. def get_pool(cls, pool_type: str = "io") -> ThreadPoolExecutor:
  46. """
  47. 获取指定类型的线程池
  48. Args:
  49. pool_type: 线程池类型 (io/cpu/parser)
  50. Returns:
  51. ThreadPoolExecutor 实例
  52. """
  53. instance = cls()
  54. if pool_type not in instance._pools:
  55. raise ValueError(f"Unknown pool type: {pool_type}. Available: {list(instance._pools.keys())}")
  56. return instance._pools[pool_type]
  57. @classmethod
  58. def submit(cls, pool_type: str, fn: Callable, *args, **kwargs) -> Future:
  59. """
  60. 提交任务到指定线程池
  61. Args:
  62. pool_type: 线程池类型
  63. fn: 要执行的函数
  64. *args, **kwargs: 函数参数
  65. Returns:
  66. Future 对象
  67. """
  68. pool = cls.get_pool(pool_type)
  69. return pool.submit(fn, *args, **kwargs)
  70. @classmethod
  71. def map(cls, pool_type: str, fn: Callable, *iterables, timeout: Optional[float] = None) -> Iterator:
  72. """
  73. 并行映射函数到多个参数
  74. Args:
  75. pool_type: 线程池类型
  76. fn: 要执行的函数
  77. *iterables: 参数迭代器
  78. timeout: 超时时间
  79. Returns:
  80. 结果迭代器
  81. """
  82. pool = cls.get_pool(pool_type)
  83. return pool.map(fn, *iterables, timeout=timeout)
  84. @classmethod
  85. def shutdown_all(cls, wait: bool = True):
  86. """
  87. 关闭所有线程池
  88. Args:
  89. wait: 是否等待所有任务完成
  90. """
  91. instance = cls()
  92. for pool_type, pool in instance._pools.items():
  93. try:
  94. pool.shutdown(wait=wait)
  95. logger.info(f"Thread pool '{pool_type}' shutdown successfully")
  96. except Exception as e:
  97. logger.error(f"Error shutting down pool '{pool_type}': {e}")
  98. instance._pools.clear()
  99. ThreadPoolManager._initialized = False
  100. logger.info("All thread pools have been shutdown")
  101. @classmethod
  102. def get_status(cls) -> dict:
  103. """
  104. 获取所有线程池的状态信息
  105. Returns:
  106. 状态字典
  107. """
  108. instance = cls()
  109. status = {}
  110. for pool_type, config in cls._pool_configs.items():
  111. pool = instance._pools.get(pool_type)
  112. if pool:
  113. status[pool_type] = {
  114. "max_workers": config["max_workers"],
  115. "active": not pool._shutdown if hasattr(pool, '_shutdown') else True
  116. }
  117. return status
  118. # 保持向后兼容的全局执行器(使用 io 线程池)
  119. def get_executor() -> ThreadPoolExecutor:
  120. """获取全局 IO 线程池(向后兼容)"""
  121. return ThreadPoolManager.get_pool("io")
  122. # 向后兼容的 executor 变量
  123. executor = None
  124. def _get_compat_executor():
  125. """延迟初始化的兼容性执行器"""
  126. global executor
  127. if executor is None:
  128. executor = ThreadPoolManager.get_pool("io")
  129. return executor
  130. async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
  131. """
  132. Run a blocking function in a thread pool
  133. Args:
  134. func: Blocking function to run
  135. *args: Arguments to pass to the function
  136. **kwargs: Keyword arguments to pass to the function
  137. Returns:
  138. Result of the function
  139. """
  140. loop = asyncio.get_event_loop()
  141. pool = ThreadPoolManager.get_pool("io")
  142. return await loop.run_in_executor(pool, func, *args)
  143. def close_executor():
  144. """
  145. Close all thread pools (向后兼容)
  146. """
  147. ThreadPoolManager.shutdown_all()