alair 3 mesi fa
parent
commit
be1b66dc18

+ 8 - 6
main.py

@@ -62,20 +62,22 @@ async def main_lifespan(app: FastAPI):
     shutdown_scheduler()
     logger.info("✅ Chunk update scheduler shutdown")
 
-    # 2. 关闭MinIO全局客户端
+    # 2. 关闭全局线程池
+    from src.utils.async_utils import ThreadPoolManager
+    ThreadPoolManager.shutdown_all()
+    logger.info("✅ 全局线程池已关闭")
+
+    # 3. 关闭MinIO全局客户端
     close_minio_client()
     logger.info("✅ MinIO客户端已关闭")
 
-    # 3. 关闭MySQL全局客户端
+    # 4. 关闭MySQL全局客户端
     close_global_mysql_client()
     logger.info("✅ MySQL客户端已关闭")
 
-    # 4. 关闭向量数据库客户端
+    # 5. 关闭向量数据库客户端
     close_vector_db_client()
     logger.info(f"✅ 向量数据库客户端已关闭 (类型: {db_type})")
-    
-
-    
 
 
 # 创建主应用

+ 17 - 17
src/datasets/parser/image_parser/image_parser_workflow.py

@@ -3,7 +3,6 @@
 图片解析工作流
 """
 import concurrent.futures
-from concurrent.futures import ThreadPoolExecutor
 from PIL import Image
 import requests
 from langgraph.graph import StateGraph, START, END
@@ -15,6 +14,7 @@ from src.model.multimodal_embedding import Embedding
 from src.utils.file.image_util import image_util
 from src.conf.settings import model_settings
 from src.utils.infinity import get_client
+from src.utils.async_utils import ThreadPoolManager
 
 # 定义工作流状态类
 class ImageParsingState(BaseModel):
@@ -163,22 +163,22 @@ class ImageParsingWorkflow:
         
         parsed_results = []
         
-        # 使用ThreadPoolExecutor实现并行处理
-        with ThreadPoolExecutor(max_workers=4) as executor:
-            # 提交所有图片解析任务
-            future_to_image = {
-                executor.submit(self._parse_single_page, image_info, self.model_name): image_info
-                for image_info in state.split_images
-            }
-            
-            # 收集解析结果
-            for future in concurrent.futures.as_completed(future_to_image):
-                try:
-                    result = future.result()
-                    parsed_results.append(result)
-                except Exception as e:
-                    image_info = future_to_image[future]
-                    print(f"解析第 {image_info['page_number']} 页时出错: {str(e)}")
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")  # NOTE(review): shared bounded pool — nested submits that block on this same pool can deadlock; confirm nodes/workflows never wait on each other here
+        # 提交所有图片解析任务
+        future_to_image = {
+            pool.submit(self._parse_single_page, image_info, self.model_name): image_info
+            for image_info in state.split_images
+        }
+        
+        # 收集解析结果
+        for future in concurrent.futures.as_completed(future_to_image):
+            try:
+                result = future.result()
+                parsed_results.append(result)
+            except Exception as e:
+                image_info = future_to_image[future]
+                print(f"解析第 {image_info['page_number']} 页时出错: {str(e)}")
         
         # 按页码排序结果
         parsed_results.sort(key=lambda x: x["page_number"])

+ 17 - 17
src/datasets/parser/nodes/image_parse_node.py

@@ -5,13 +5,13 @@
 """
 
 import concurrent.futures
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Any, List, Optional
 from src.datasets.parser.core.base import BaseNode, BaseState
 from src.datasets.parser.core.registry import register_node
 from src.model.qwen_vl import QWenVLParser
 from src.conf.settings import model_settings
 from src.common.logging_config import get_logger
+from src.utils.async_utils import ThreadPoolManager
 
 logger = get_logger(__name__)
 
@@ -42,11 +42,11 @@ class ImageParseNode(BaseNode):
         
         Args:
             model_name: VL模型名称
-            max_workers: 并行处理的最大工作线程数
+            max_workers: 并行处理的最大工作线程数(已废弃,使用全局线程池)
             prompt_template: 自定义提示词模板
         """
         self.model_name = model_name or model_settings.vl_model_name
-        self.max_workers = max_workers
+        self.max_workers = max_workers  # 保留兼容性但不再使用
         self.prompt_template = prompt_template or self._default_prompt_template()
     
     @property
@@ -119,20 +119,20 @@ class ImageParseNode(BaseNode):
         
         parsed_results = []
         
-        # 使用ThreadPoolExecutor实现并行处理
-        with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="parse_page_") as executor:
-            future_to_page = {
-                executor.submit(self._parse_single_page, page): page
-                for page in pages
-            }
-            
-            for future in concurrent.futures.as_completed(future_to_page):
-                try:
-                    result = future.result()
-                    parsed_results.append(result)
-                except Exception as e:
-                    page = future_to_page[future]
-                    logger.error(f"解析第 {page.get('page_number', '?')} 页时出错: {str(e)}")
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        future_to_page = {
+            pool.submit(self._parse_single_page, page): page
+            for page in pages
+        }
+        
+        for future in concurrent.futures.as_completed(future_to_page):
+            try:
+                result = future.result()
+                parsed_results.append(result)
+            except Exception as e:
+                page = future_to_page[future]
+                logger.error(f"解析第 {page.get('page_number', '?')} 页时出错: {str(e)}")
         
         # 按页码排序结果
         parsed_results.sort(key=lambda x: x.get("page_number", 0))

+ 20 - 20
src/datasets/parser/nodes/pdf_ocr_node.py

@@ -3,13 +3,13 @@ PDF OCR解析节点
 """
 
 import concurrent.futures
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Any, List
 from src.datasets.parser.core.base import BaseNode
 from src.datasets.parser.pdf_parser.pdf_splitter import PDFSplitter
 from src.model.qwen_vl import QWenVLParser
 from src.common.logging_config import get_logger
 from src.conf.settings import model_settings
+from src.utils.async_utils import ThreadPoolManager
 
 logger = get_logger(__name__)
 
@@ -78,25 +78,25 @@ class PDFOCRNode(BaseNode):
         
         parsed_results = []
         
-        # 使用ThreadPoolExecutor实现并行处理
-        with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="ocr_page_") as executor:
-            future_to_page = {
-                executor.submit(self._parse_single_page, page): page
-                for page in pages
-            }
-            
-            for future in concurrent.futures.as_completed(future_to_page):
-                try:
-                    result = future.result()
-                    parsed_results.append(result)
-                except Exception as e:
-                    page = future_to_page[future]
-                    page_number = page.get("page_number", 0)
-                    logger.error(f"解析第 {page_number} 页时出错: {str(e)}")
-                    parsed_results.append({
-                        "page_number": page_number,
-                        "content": ""
-                    })
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        future_to_page = {
+            pool.submit(self._parse_single_page, page): page
+            for page in pages
+        }
+        
+        for future in concurrent.futures.as_completed(future_to_page):
+            try:
+                result = future.result()
+                parsed_results.append(result)
+            except Exception as e:
+                page = future_to_page[future]
+                page_number = page.get("page_number", 0)
+                logger.error(f"解析第 {page_number} 页时出错: {str(e)}")
+                parsed_results.append({
+                    "page_number": page_number,
+                    "content": ""
+                })
         
         # 按页码排序结果,确保顺序性
         parsed_results.sort(key=lambda x: x.get("page_number", 0))

+ 17 - 16
src/datasets/parser/nodes/qa_generate_node.py

@@ -6,13 +6,13 @@ QA对生成节点
 
 import json
 import concurrent.futures
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Any, List, Optional
 from src.datasets.parser.core.base import BaseNode, BaseState
 from src.datasets.parser.core.registry import register_node
 from src.model.qwen_vl import QWenVLParser
 from src.conf.settings import model_settings
 from src.common.logging_config import get_logger
+from src.utils.async_utils import ThreadPoolManager
 
 logger = get_logger(__name__)
 
@@ -75,7 +75,8 @@ class QAGenerateNode(BaseNode):
             2. 答案应该准确、完整,直接来源于文本
             3. 问题应该自然,像真实用户会问的问题
             4. 避免过于简单或过于复杂的问题
-            
+            5. 如果json不完整,则去除不完整的元素,只返回完整的json数组
+
             文本内容:
             {chunk}
             
@@ -122,7 +123,7 @@ class QAGenerateNode(BaseNode):
                     return json.loads(json_match.group())
                 except json.JSONDecodeError:
                     pass
-            logger.warning("无法解析QA响应为JSON")
+            logger.warning(f"无法解析QA响应为JSON: {response}")
             return []
     
     def execute(self, state: BaseState) -> Dict[str, Any]:
@@ -146,19 +147,19 @@ class QAGenerateNode(BaseNode):
         
         all_qa_pairs = []
         
-        # 使用ThreadPoolExecutor实现并行处理
-        with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="qa_gen_") as executor:
-            futures = [
-                executor.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
-                for i, chunk in enumerate(text_chunks)
-            ]
-            
-            for future in concurrent.futures.as_completed(futures):
-                try:
-                    qa_pairs = future.result()
-                    all_qa_pairs.extend(qa_pairs)
-                except Exception as e:
-                    logger.error(f"QA生成任务失败: {str(e)}")
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        futures = [
+            pool.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
+            for i, chunk in enumerate(text_chunks)
+        ]
+        
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                qa_pairs = future.result()
+                all_qa_pairs.extend(qa_pairs)
+            except Exception as e:
+                logger.error(f"QA生成任务失败: {str(e)}")
         
         logger.info(f"QA对生成完成,共 {len(all_qa_pairs)} 个")
         

+ 17 - 17
src/datasets/parser/pdf_parser/pdf_parser_workflow.py

@@ -1,7 +1,6 @@
 import os
 import concurrent.futures
 import time
-from concurrent.futures import ThreadPoolExecutor
 from langgraph.graph import StateGraph, START, END
 from typing import List, Dict, Any
 from pydantic import BaseModel, Field, ConfigDict
@@ -14,6 +13,7 @@ from src.conf.settings import model_settings, vector_db_settings, minio_settings
 from src.utils.infinity import get_client
 from langfuse.langchain import CallbackHandler
 from src.conf.rag_parser_config import RagParserDefaults
+from src.utils.async_utils import ThreadPoolManager
 
 # 定义工作流状态类
 class PDFParsingState(BaseModel):
@@ -334,22 +334,22 @@ class PDFParsingWorkflow:
         
         parsed_results = []
         
-        # 使用ThreadPoolExecutor实现并行处理
-        with ThreadPoolExecutor(max_workers=5, thread_name_prefix="parse_page_") as executor:
-            # 提交所有页面解析任务
-            future_to_page = {
-                executor.submit(self._parse_single_page, page, self.model_name): page
-                for page in state.split_pages
-            }
-            
-            # 收集解析结果
-            for future in concurrent.futures.as_completed(future_to_page):
-                try:
-                    result = future.result()
-                    parsed_results.append(result)
-                except Exception as e:
-                    page = future_to_page[future]
-                    print(f"解析第 {page['page_number']} 页时出错: {str(e)}")
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        # 提交所有页面解析任务
+        future_to_page = {
+            pool.submit(self._parse_single_page, page, self.model_name): page
+            for page in state.split_pages
+        }
+        
+        # 收集解析结果
+        for future in concurrent.futures.as_completed(future_to_page):
+            try:
+                result = future.result()
+                parsed_results.append(result)
+            except Exception as e:
+                page = future_to_page[future]
+                print(f"解析第 {page['page_number']} 页时出错: {str(e)}")
         
         # 按页码排序结果
         parsed_results.sort(key=lambda x: x["page_number"])

+ 21 - 21
src/datasets/parser/question_answer_parser/question_answer_parser.py

@@ -10,7 +10,6 @@ import json
 import csv
 import tempfile
 import concurrent.futures
-from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Any, Optional
 from pydantic import BaseModel, Field, ConfigDict
 from langgraph.graph import StateGraph, START, END
@@ -22,6 +21,7 @@ from src.model.qwen_vl import QWenVLParser
 from src.utils.ragflow.ragflow_service import RAGFlowService
 from src.conf.settings import model_settings
 from langfuse.langchain import CallbackHandler
+from src.utils.async_utils import ThreadPoolManager
 
 
 class QAParsingState(BaseModel):
@@ -230,26 +230,26 @@ class QuestionAnswerParserWorkflow:
         
         all_qa_pairs = []
         
-        # 使用ThreadPoolExecutor并行处理
-        with ThreadPoolExecutor(max_workers=5, thread_name_prefix="qa_gen_") as executor:
-            futures = {
-                executor.submit(
-                    self._generate_qa_for_chunk, 
-                    chunk, 
-                    state.qa_count_per_chunk,
-                    i
-                ): i for i, chunk in enumerate(state.chunks)
-            }
-            
-            for future in concurrent.futures.as_completed(futures):
-                chunk_index = futures[future]
-                try:
-                    qa_list = future.result()
-                    for qa in qa_list:
-                        qa["chunk_index"] = chunk_index
-                    all_qa_pairs.extend(qa_list)
-                except Exception as e:
-                    print(f"分块 {chunk_index + 1} 处理异常: {str(e)}")
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        futures = {
+            pool.submit(
+                self._generate_qa_for_chunk, 
+                chunk, 
+                state.qa_count_per_chunk,
+                i
+            ): i for i, chunk in enumerate(state.chunks)
+        }
+        
+        for future in concurrent.futures.as_completed(futures):
+            chunk_index = futures[future]
+            try:
+                qa_list = future.result()
+                for qa in qa_list:
+                    qa["chunk_index"] = chunk_index
+                all_qa_pairs.extend(qa_list)
+            except Exception as e:
+                print(f"分块 {chunk_index + 1} 处理异常: {str(e)}")
         
         # 按chunk_index排序
         all_qa_pairs.sort(key=lambda x: x.get("chunk_index", 0))

+ 35 - 10
src/job/chunk_update_job.py

@@ -12,8 +12,9 @@ from datetime import datetime
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.interval import IntervalTrigger
 from src.utils.mysql import get_global_mysql_client
-from src.utils.infinity import get_client
+from src.utils.vector_db import get_vector_db_client
 from src.common.logging_config import get_logger
+from src.conf.settings import vector_db_settings, es_settings
 
 # 获取日志器
 logger = get_logger(__name__)
@@ -27,7 +28,12 @@ class ChunkUpdateJob:
     def __init__(self):
         """初始化定时任务服务"""
         self.mysql_client = get_global_mysql_client()
-        self.infinity_client = get_client()
+        self.vector_db_type = vector_db_settings.vector_db_type
+        
+        # 使用 FastAPI 全局生命周期管理的向量数据库客户端
+        self.vector_db_client = get_vector_db_client()
+        logger.info("Using global vector_db_client for ChunkUpdateJob (type=%s)", 
+                   self.vector_db_type)
     
     def process_due_tasks(self):
         """处理到期的任务"""
@@ -78,21 +84,40 @@ class ChunkUpdateJob:
         
         Args:
             database_name: 数据库名称
-            table_name: 表名称
-            chunk_id: Chunk ID
+            table_name: 表名称 (ES 模式下作为 index_name)
+            chunk_id: Chunk ID (ES 模式下作为 document_id)
             cond: 条件字符串
-            data: 数据字典
+            data: 更新的数据字典
         """
         try:
-            # 使用 Infinity 客户端执行更新操作
-            # 这里需要根据实际的 Infinity API 进行调整
-            if cond and data:
-                self.infinity_client.update(
+            # 解析数据
+            update_data = json.loads(data) if isinstance(data, str) else data
+            
+            if self.vector_db_client is None:
+                logger.warning("No vector database client available (vector_db_type=%s)", 
+                              self.vector_db_type)
+                return
+            
+            # 使用统一的 vector_db_client.update() 接口
+            # ES 模式下:cond 作为 document_id 使用
+            # Infinity 模式下:cond 作为条件字符串使用
+            if self.vector_db_type == "es":
+                # ES 模式:使用 chunk_id 作为 cond(document_id)
+                self.vector_db_client.update(
+                    table_name=table_name or es_settings.es_index_name,
+                    cond=chunk_id,
+                    data=update_data,
+                    database_name=database_name
+                )
+            else:
+                # Infinity 模式:使用原始的 cond
+                self.vector_db_client.update(
                     table_name=table_name,
                     cond=cond,
-                    data=json.loads(data),
+                    data=update_data,
                     database_name=database_name
                 )
+                
             logger.info(f"Updated chunk {chunk_id} in {database_name}.{table_name}")
         except Exception as e:
             raise Exception(f"Failed to update chunk {chunk_id}: {e}")

+ 100 - 5
src/model/qwen_vl.py

@@ -2,8 +2,13 @@ from typing import Dict, Any
 from PIL import Image
 import base64
 import io
+import time
 from langchain.chat_models import init_chat_model
 from src.conf.settings import model_settings
+from src.common.logging_config import get_logger
+
+# 获取日志器
+logger = get_logger(__name__)
 
 class QWenVLParser:
     """QWEN VL模型图像解析工具"""
@@ -26,6 +31,40 @@ class QWenVLParser:
             base_url=self.base_url,
             api_key=self.api_key
         )
+        # 重试配置
+        self.max_retries = 3
+        self.retry_delay = 1.0  # 初始重试延迟(秒)
+    
+    def _invoke_with_retry(self, messages: list) -> Any:
+        """
+        带重试的模型调用
+        
+        Args:
+            messages: 消息列表
+            
+        Returns:
+            模型响应
+        """
+        last_exception = None
+        
+        for attempt in range(self.max_retries):
+            try:
+                response = self.model.invoke(input=messages)
+                return response
+            except Exception as e:
+                last_exception = e
+                if attempt < self.max_retries - 1:
+                    # 指数退避重试
+                    delay = self.retry_delay * (2 ** attempt)
+                    logger.warning(
+                        f"模型调用失败(尝试 {attempt + 1}/{self.max_retries}),"
+                        f"{delay:.1f}秒后重试: {str(e)}"
+                    )
+                    time.sleep(delay)
+                else:
+                    logger.error(f"模型调用失败,已达到最大重试次数: {str(e)}")
+        
+        raise last_exception
     
     def image_to_base64(self, image: Image.Image) -> str:
         """
@@ -79,8 +118,8 @@ class QWenVLParser:
                 }
             ]
             
-            # 使用langchain模型调用OpenAI API
-            response = self.model.invoke(input=messages)
+            # 使用带重试的模型调用
+            response = self._invoke_with_retry(messages)
             
             # 提取解析结果
             content = response.content
@@ -128,8 +167,8 @@ class QWenVLParser:
                 }
             ]
             
-            # 使用langchain模型调用OpenAI API
-            response = self.model.invoke(messages)
+            # 使用带重试的模型调用
+            response = self._invoke_with_retry(messages)
             
             # 提取解析结果
             content = response.content
@@ -156,4 +195,60 @@ class QWenVLParser:
         """
         # 将字节流转换为PIL图像
         image = Image.open(image_bytes)
-        return self.parse_image(image, page_number, prompt)
+        return self.parse_image(image, page_number, prompt)
+    
+    def chat(self, prompt: str, system_prompt: str = None) -> str:
+        """
+        纯文本生成/对话
+        
+        Args:
+            prompt: 用户提示词
+            system_prompt: 系统提示词(可选)
+            
+        Returns:
+            str: 生成的文本内容
+        """
+        try:
+            messages = []
+            
+            # 添加系统提示词(如果有)
+            if system_prompt:
+                messages.append({
+                    "role": "system",
+                    "content": system_prompt
+                })
+            
+            # 添加用户消息
+            messages.append({
+                "role": "user",
+                "content": prompt
+            })
+            
+            # 使用带重试的模型调用
+            response = self._invoke_with_retry(messages)
+            
+            return response.content
+        except Exception as e:
+            raise Exception(f"文本生成失败: {str(e)}")
+    
+    def generate(self, prompt: str, system_prompt: str = None) -> Dict[str, Any]:
+        """
+        文本生成(返回详细信息)
+        
+        Args:
+            prompt: 用户提示词
+            system_prompt: 系统提示词(可选)
+            
+        Returns:
+            Dict: 包含生成结果的字典,包含:
+                - content: 生成的文本内容
+                - model: 使用的模型名称
+        """
+        try:
+            content = self.chat(prompt, system_prompt)
+            return {
+                "content": content,
+                "model": self.model_name
+            }
+        except Exception as e:
+            raise Exception(f"文本生成失败: {str(e)}")

+ 153 - 7
src/utils/async_utils.py

@@ -1,13 +1,158 @@
 """
 Async utility functions for running blocking operations
+Global thread pool manager for the application
 """
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable
+import threading
+import os
+from concurrent.futures import ThreadPoolExecutor, Future
+from typing import Any, Callable, Iterator, Optional
+from src.common.logging_config import get_logger
 
+logger = get_logger(__name__)
 
-# Create a thread pool executor
-executor = ThreadPoolExecutor(max_workers=10)
+
+class ThreadPoolManager:
+    """
+    全局线程池管理器(单例模式)
+    
+    提供三种类型的线程池:
+    - io: IO密集型操作(网络请求、文件读写)
+    - cpu: CPU密集型操作
+    - parser: 解析任务(图像/PDF/OCR/QA生成)
+    """
+    
+    _instance = None
+    _lock = threading.Lock()
+    _initialized = False
+    
+    # 线程池配置
+    _pool_configs = {
+        "io": {"max_workers": 20, "thread_name_prefix": "io_"},
+        "cpu": {"max_workers": os.cpu_count() or 4, "thread_name_prefix": "cpu_"},
+        "parser": {"max_workers": 8, "thread_name_prefix": "parser_"}
+    }
+    
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+    
+    def __init__(self):
+        if not ThreadPoolManager._initialized:
+            with ThreadPoolManager._lock:
+                if not ThreadPoolManager._initialized:
+                    self._pools = {}
+                    for pool_type, config in self._pool_configs.items():
+                        self._pools[pool_type] = ThreadPoolExecutor(**config)
+                        logger.info(f"Created thread pool: {pool_type} (max_workers={config['max_workers']})")
+                    ThreadPoolManager._initialized = True
+    
+    @classmethod
+    def get_pool(cls, pool_type: str = "io") -> ThreadPoolExecutor:
+        """
+        获取指定类型的线程池
+        
+        Args:
+            pool_type: 线程池类型 (io/cpu/parser)
+            
+        Returns:
+            ThreadPoolExecutor 实例
+        """
+        instance = cls()
+        if pool_type not in instance._pools:
+            raise ValueError(f"Unknown pool type: {pool_type}. Available: {list(instance._pools.keys())}")
+        return instance._pools[pool_type]
+    
+    @classmethod
+    def submit(cls, pool_type: str, fn: Callable, *args, **kwargs) -> Future:
+        """
+        提交任务到指定线程池
+        
+        Args:
+            pool_type: 线程池类型
+            fn: 要执行的函数
+            *args, **kwargs: 函数参数
+            
+        Returns:
+            Future 对象
+        """
+        pool = cls.get_pool(pool_type)
+        return pool.submit(fn, *args, **kwargs)
+    
+    @classmethod
+    def map(cls, pool_type: str, fn: Callable, *iterables, timeout: Optional[float] = None) -> Iterator:
+        """
+        并行映射函数到多个参数
+        
+        Args:
+            pool_type: 线程池类型
+            fn: 要执行的函数
+            *iterables: 参数迭代器
+            timeout: 超时时间
+            
+        Returns:
+            结果迭代器
+        """
+        pool = cls.get_pool(pool_type)
+        return pool.map(fn, *iterables, timeout=timeout)
+    
+    @classmethod
+    def shutdown_all(cls, wait: bool = True):
+        """
+        关闭所有线程池
+        
+        Args:
+            wait: 是否等待所有任务完成
+        """
+        instance = cls()
+        for pool_type, pool in instance._pools.items():
+            try:
+                pool.shutdown(wait=wait)
+                logger.info(f"Thread pool '{pool_type}' shutdown successfully")
+            except Exception as e:
+                logger.error(f"Error shutting down pool '{pool_type}': {e}")
+        instance._pools.clear()
+        ThreadPoolManager._initialized = False
+        logger.info("All thread pools have been shutdown")
+    
+    @classmethod
+    def get_status(cls) -> dict:
+        """
+        获取所有线程池的状态信息
+        
+        Returns:
+            状态字典
+        """
+        instance = cls()
+        status = {}
+        for pool_type, config in cls._pool_configs.items():
+            pool = instance._pools.get(pool_type)
+            if pool:
+                status[pool_type] = {
+                    "max_workers": config["max_workers"],
+                    "active": not pool._shutdown if hasattr(pool, '_shutdown') else True
+                }
+        return status
+
+
+# 保持向后兼容的全局执行器(使用 io 线程池)
+def get_executor() -> ThreadPoolExecutor:
+    """获取全局 IO 线程池(向后兼容)"""
+    return ThreadPoolManager.get_pool("io")
+
+
+# 向后兼容的 executor 变量
+executor = None  # NOTE(review): breaking change — `from src.utils.async_utils import executor` now binds None; callers must switch to get_executor()
+
+def _get_compat_executor():
+    """延迟初始化的兼容性执行器"""
+    global executor
+    if executor is None:
+        executor = ThreadPoolManager.get_pool("io")
+    return executor
 
 
 async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
@@ -23,11 +168,12 @@ async def run_in_threadpool(func: Callable[..., Any], *args, **kwargs) -> Any:
         Result of the function
     """
     loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(executor, func, *args, **kwargs)
+    pool = ThreadPoolManager.get_pool("io")
+    return await loop.run_in_executor(pool, lambda: func(*args, **kwargs))  # run_in_executor is positional-only; closure preserves kwargs
 
 
 def close_executor():
     """
-    Close the thread pool executor
+    Close all thread pools (向后兼容)
     """
-    executor.shutdown(wait=True)
+    ThreadPoolManager.shutdown_all()