Parcourir la source

19维度图书解析流程调整-新增滑动窗口模式

yingge il y a 3 mois
Parent
commit
d16b46a27d

+ 7 - 2
src/api/db/models/search_request_models.py

@@ -34,6 +34,11 @@ class SearchRequest(BaseModel):
         description="用于图像搜索的图片URL",
         examples=["https://example.com/image.jpg"]
     )
+    image_base64: Optional[str] = Field(
+        default=None,
+        description="用于图像搜索的图片base64编码",
+        examples=["base64编码字符串"]
+    )
     matching_text: Optional[str] = Field(
         default=None,
         description="用于文本匹配搜索的查询文本",
@@ -79,5 +84,5 @@ class SearchRequest(BaseModel):
     
     def model_post_init(self, __context) -> None:
         """模型初始化后验证:确保至少提供一个搜索条件"""
-        if self.image_url is None and self.matching_text is None:
-            raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
+        if self.image_url is None and self.image_base64 is None and self.matching_text is None:
+            raise ValueError("必须提供 image_url 或 image_base64 或 matching_text 其中至少一个")

+ 7 - 1
src/datasets/parser/nodes/result_aggregation_node.py

@@ -104,7 +104,13 @@ class ResultAggregationNode(BaseNode):
             parsed_content = result.get("parsed_content", {})
             
             results_text.append(f"## 窗口 {i} (中心页: {center_page}, 范围: {page_range})")
-            results_text.append(json.dumps(parsed_content, ensure_ascii=False, indent=2))
+            if isinstance(parsed_content, (dict, list)):
+                # 如果是对象,转为格式化字符串
+                content_str = json.dumps(parsed_content, ensure_ascii=False, indent=2)
+            else:
+                # 如果是字符串或其他类型,直接转为字符串即可
+                content_str = str(parsed_content)
+            results_text.append(content_str)
             results_text.append("")
         
         prompt = f"""你是一个专业的内容分析专家。现在需要你汇总和整合多个滑动窗口的解析结果。

+ 66 - 5
src/datasets/parser/nodes/sliding_window_parse_node_async.py

@@ -2,9 +2,11 @@
 滑动窗口图像解析节点(异步版本)
 
 使用VL模型异步解析滑动窗口拼接后的图像内容,大幅提升并行效率。
+支持结果缓存和图片编码缓存优化。
 """
 
 import asyncio
+import hashlib
 from typing import Dict, Any, List, Optional
 from src.datasets.parser.core.base import BaseNode, BaseState
 from src.datasets.parser.core.registry import register_node
@@ -15,6 +17,18 @@ from src.utils.json_utils import parse_json_response
 
 logger = get_logger(__name__)
 
+# 全局结果缓存(与同步版本共享)
+_result_cache = {}
+_result_cache_lock = None
+
+def _get_result_cache_lock():
+    """获取结果缓存锁"""
+    global _result_cache_lock
+    if _result_cache_lock is None:
+        import threading
+        _result_cache_lock = threading.Lock()
+    return _result_cache_lock
+
 
 @register_node()
 class SlidingWindowParseNodeAsync(BaseNode):
@@ -39,7 +53,8 @@ class SlidingWindowParseNodeAsync(BaseNode):
     def __init__(
         self,
         model_name: Optional[str] = None,
-        max_concurrent: int = 50
+        max_concurrent: int = 50,
+        enable_result_cache: bool = True
     ):
         """
         初始化滑动窗口解析节点(异步版本)
@@ -47,12 +62,16 @@ class SlidingWindowParseNodeAsync(BaseNode):
         Args:
             model_name: VL模型名称
             max_concurrent: 最大并发数(默认50,可根据API限制调整)
+            enable_result_cache: 是否启用结果缓存(默认True)
         """
         self.model_name = model_name or model_settings.vl_model_name
         self.max_concurrent = max_concurrent
+        self.enable_result_cache = enable_result_cache
+        
+        # 创建共享的Parser实例(复用模型连接,启用图片编码缓存)
+        self.parser = QWenVLParser(self.model_name, reuse_model=True, enable_image_cache=True)
         
-        # 创建共享的Parser实例(复用模型连接)
-        self.parser = QWenVLParser(self.model_name, reuse_model=True)
+        logger.info(f"滑动窗口异步解析节点初始化: 最大并发={self.max_concurrent}, 结果缓存={'启用' if enable_result_cache else '禁用'}")
     
     @property
     def name(self) -> str:
@@ -65,7 +84,7 @@ class SlidingWindowParseNodeAsync(BaseNode):
         semaphore: asyncio.Semaphore
     ) -> Dict[str, Any]:
         """
-        异步解析单个滑动窗口页面
+        异步解析单个滑动窗口页面(优化版:结果缓存 + 类型检查)
         
         Args:
             windowed_page: 窗口页面信息
@@ -89,6 +108,28 @@ class SlidingWindowParseNodeAsync(BaseNode):
             logger.warning(f"格式化提示词失败: {str(e)},使用原始模板")
             prompt = prompt_template
         
+        # === 结果缓存检查 ===
+        if self.enable_result_cache:
+            try:
+                # 生成缓存键:图片哈希 + 提示词哈希
+                image_hash = hashlib.md5(image.tobytes()).hexdigest()
+                prompt_hash = hashlib.md5(prompt.encode('utf-8')).hexdigest()
+                cache_key = f"{image_hash}:{prompt_hash}"
+                
+                # 检查缓存
+                lock = _get_result_cache_lock()
+                with lock:
+                    if cache_key in _result_cache:
+                        logger.debug(f"结果缓存命中: 窗口 (中心页: {center_page})")
+                        cached_result = _result_cache[cache_key].copy()
+                        # 更新页码信息
+                        cached_result["center_page"] = center_page
+                        cached_result["page_range"] = page_range
+                        cached_result["page_numbers"] = windowed_page.get("page_numbers", [])
+                        return cached_result
+            except Exception as e:
+                logger.warning(f"结果缓存检查失败: {str(e)}")
+        
         logger.debug(f"开始异步解析窗口 (中心页: {center_page}, 范围: {page_range})")
         
         # 使用信号量控制并发
@@ -101,14 +142,34 @@ class SlidingWindowParseNodeAsync(BaseNode):
                 # JSON解析(在协程中完成)
                 parsed_content = parse_json_response(content, expected_type=dict)
                 
+                # 如果解析失败返回字符串,使用空字典并记录警告
+                if isinstance(parsed_content, str):
+                    logger.warning(f"窗口 (中心页: {center_page}) 异步解析结果无法转换为JSON,使用原始字符串")
+                    parsed_content = {"原始内容": parsed_content}
+                
                 logger.debug(f"窗口 (中心页: {center_page}) 异步解析完成")
                 
-                return {
+                result_dict = {
                     "center_page": center_page,
                     "page_range": page_range,
                     "page_numbers": windowed_page.get("page_numbers", []),
                     "parsed_content": parsed_content
                 }
+                
+                # === 存入缓存 ===
+                if self.enable_result_cache:
+                    try:
+                        lock = _get_result_cache_lock()
+                        with lock:
+                            # 限制缓存大小,避免内存溢出
+                            if len(_result_cache) < 500:
+                                _result_cache[cache_key] = result_dict.copy()
+                            else:
+                                logger.debug("结果缓存已满,跳过缓存")
+                    except Exception as e:
+                        logger.warning(f"结果缓存存储失败: {str(e)}")
+                
+                return result_dict
             except Exception as e:
                 logger.error(f"异步解析窗口 (中心页: {center_page}) 时出错: {str(e)}")
                 return {

+ 23 - 7
src/datasets/parser/workflow_nodes/dimension_sliding_window_node.py

@@ -18,6 +18,7 @@ from src.datasets.parser.nodes import (
 )
 from src.datasets.parser.nodes.sliding_window_stitching_node import SlidingWindowStitchingNode
 from src.datasets.parser.nodes.sliding_window_parse_node import SlidingWindowParseNode
+from src.datasets.parser.nodes.sliding_window_parse_node_async import SlidingWindowParseNodeAsync
 from src.datasets.parser.nodes.result_aggregation_node import ResultAggregationNode
 from src.datasets.parser.states.parser_states import DynamicDimensionState
 from src.common.logging_config import get_logger
@@ -51,7 +52,9 @@ class DimensionSlidingWindowNode(BaseNode):
         model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
         max_workers: int = 5,
         window_size: int = 3,
-        skip_stitching: bool = False
+        skip_stitching: bool = False,
+        use_async: bool = False,
+        max_concurrent: int = 50
     ):
         """
         初始化维度滑动窗口节点
@@ -59,15 +62,19 @@ class DimensionSlidingWindowNode(BaseNode):
         Args:
             dimension_id: 维度ID
             model_name: VL模型名称
-            max_workers: 并行处理的最大工作线程数
+            max_workers: 并行处理的最大工作线程数(同步版本使用)
             window_size: 滑动窗口大小,默认3(当前页+前1页+后1页)
             skip_stitching: 是否跳过滑动窗口拆分(复用已有结果),默认False
+            use_async: 是否使用异步版本解析节点,默认False
+            max_concurrent: 异步版本的最大并发数,默认50
         """
         self.dimension_id = dimension_id
         self.model_name = model_name
         self.max_workers = max_workers
         self.window_size = window_size
         self.skip_stitching = skip_stitching
+        self.use_async = use_async
+        self.max_concurrent = max_concurrent
     
     @property
     def name(self) -> str:
@@ -84,7 +91,7 @@ class DimensionSlidingWindowNode(BaseNode):
         Returns:
             编译后的 LangGraph 工作流
         """
-        logger.info(f"[维度滑动窗口-{self.dimension_id}] 开始构建子工作流 (跳过拆分: {self.skip_stitching})")
+        logger.info(f"[维度滑动窗口-{self.dimension_id}] 开始构建子工作流 (跳过拆分: {self.skip_stitching}, 异步模式: {self.use_async})")
         rag_flow_api_key = getattr(state, 'rag_flow_api_key', '')
         decomposition_methods = getattr(state, 'decomposition_methods', None)
         is_preced = decomposition_methods[self.dimension_id]["is_preced"] or 0 if decomposition_methods else 0
@@ -102,10 +109,19 @@ class DimensionSlidingWindowNode(BaseNode):
         if not self.skip_stitching:
             stitching_node = SlidingWindowStitchingNode(window_size=self.window_size)
         
-        parse_node = SlidingWindowParseNode(
-            model_name=self.model_name,
-            max_workers=self.max_workers
-        )
+        # 根据use_async选择解析节点
+        if self.use_async:
+            parse_node = SlidingWindowParseNodeAsync(
+                model_name=self.model_name,
+                max_concurrent=self.max_concurrent
+            )
+            logger.info(f"[维度滑动窗口-{self.dimension_id}] 使用异步解析节点,最大并发: {self.max_concurrent}")
+        else:
+            parse_node = SlidingWindowParseNode(
+                model_name=self.model_name,
+                max_workers=self.max_workers
+            )
+            logger.info(f"[维度滑动窗口-{self.dimension_id}] 使用同步解析节点,线程数: {self.max_workers}")
         aggregation_node = ResultAggregationNode(
             model_name=self.model_name,
             dimension_id=self.dimension_id

+ 115 - 17
src/datasets/parser/workflows/dynamic_dimension_workflow.py

@@ -10,6 +10,8 @@
 """
 
 from typing import Dict, Any, List
+import os
+import psutil
 from langfuse.langchain import CallbackHandler
 
 from src.datasets.parser.core.workflow_builder import WorkflowBuilder
@@ -45,36 +47,117 @@ class DynamicDimensionWorkflow:
         ... )
     """
     
-    def __init__(self, model_name: str = "Qwen/Qwen3-VL-8B-Instruct", max_workers: int = 5):
+    def __init__(
+        self, 
+        model_name: str = "Qwen/Qwen3-VL-8B-Instruct", 
+        max_workers: int = 5,
+        max_concurrent_dimensions: int = 2,  # 最大并发维度数
+        memory_threshold_percent: float = 80.0,  # 内存使用阈值(百分比)
+        use_async: bool = False,  # 是否使用异步解析节点
+        max_concurrent: int = 50  # 异步版本的最大并发数
+    ):
         """
         初始化工作流
         
         Args:
             model_name: VL模型名称
-            max_workers: 每个维度内部图片解析的并行线程数
+            max_workers: 每个维度内部图片解析的并行线程数(同步版本)
+            max_concurrent_dimensions: 最大并发维度数,默认2(避免资源耗尽)
+            memory_threshold_percent: 内存使用阈值百分比,超过则降级
+            use_async: 是否使用异步解析节点,默认False
+            max_concurrent: 异步版本的最大并发数,默认50
         """
         self.model_name = model_name
         self.max_workers = max_workers
+        self.max_concurrent_dimensions = max_concurrent_dimensions
+        self.memory_threshold_percent = memory_threshold_percent
+        self.use_async = use_async
+        self.max_concurrent = max_concurrent
         self.langfuse_handler = CallbackHandler()
         # 懒加载维度分解方法(初始化时加载一次)
         self.decomposition_methods = self._get_decomposition_method()
+        
+        logger.info(f"工作流初始化: 解析模式={'异步' if use_async else '同步'}, "
+                   f"{'最大并发=' + str(max_concurrent) if use_async else '线程数=' + str(max_workers)}")
 
     def _get_decomposition_method(self):
         prompt_service = get_prompt_service()
         decomposition_methods = prompt_service.get_decomposition_method()
         return decomposition_methods
     
-    def _build_workflow_for_dimensions(self, dimension_ids: List[int]):
+    def _get_memory_usage(self) -> float:
+        """获取当前内存使用百分比"""
+        try:
+            memory = psutil.virtual_memory()
+            return memory.percent
+        except Exception as e:
+            logger.warning(f"获取内存使用失败: {e}")
+            return 0.0
+    
+    def _calculate_optimal_workers(self, dimension_count: int) -> int:
+        """
+        根据维度数量和系统资源动态计算最优线程数
+        
+        保证工作流完整执行:
+        - 最小线程数为2,确保基本并行能力
+        - 只调整并行度,不影响任务完整性
+        - 所有页面都会被处理,只是速度可能变慢
+        
+        Args:
+            dimension_count: 维度数量
+            
+        Returns:
+            优化后的线程数(最小为2)
+        """
+        # 获取CPU核心数
+        cpu_count = os.cpu_count() or 4
+        
+        # 获取内存使用情况
+        memory_percent = self._get_memory_usage()
+        
+        # 基础线程数:CPU核心数 * 2
+        base_workers = cpu_count * 2
+        
+        # 根据维度数量调整(维度越多,每个维度的线程数越少)
+        if dimension_count <= 2:
+            adjusted_workers = base_workers
+        elif dimension_count <= 4:
+            adjusted_workers = max(base_workers // 2, 4)  # 最小4个线程
+        else:
+            adjusted_workers = max(base_workers // 3, 3)  # 最小3个线程
+        
+        # 根据内存使用情况降级
+        if memory_percent > self.memory_threshold_percent:
+            logger.warning(f"内存使用率 {memory_percent:.1f}% 超过阈值 {self.memory_threshold_percent}%,降低线程数")
+            adjusted_workers = max(adjusted_workers // 2, 2)  # 极端情况最小2个线程
+        elif memory_percent > self.memory_threshold_percent - 10:
+            logger.info(f"内存使用率 {memory_percent:.1f}% 接近阈值,适度降低线程数")
+            adjusted_workers = max(int(adjusted_workers * 0.7), 3)  # 最小3个线程
+        
+        # 限制最大值
+        max_allowed = min(self.max_workers, 10)
+        final_workers = min(adjusted_workers, max_allowed)
+        
+        # 确保最小值为2,保证工作流能够完整执行
+        final_workers = max(final_workers, 2)
+        
+        logger.info(f"动态计算线程数: CPU核心={cpu_count}, 维度数={dimension_count}, "
+                   f"内存使用={memory_percent:.1f}%, 最终线程数={final_workers}")
+        
+        return final_workers
+    
+    def _build_workflow_for_dimensions(self, dimension_ids: List[int], optimal_workers: int):
         """
         根据维度ID列表动态构建 LangGraph 工作流
         
         Args:
             dimension_ids: 维度ID列表,决定节点和执行顺序
+            optimal_workers: 优化后的线程数
             
         Returns:
             编译后的 LangGraph 工作流
         """
-        logger.info(f"动态构建工作流,维度: {dimension_ids}")
+        logger.info(f"动态构建工作流,维度: {dimension_ids}, 线程数: {optimal_workers}")
         
         # 创建固定节点
         split_node = PDFSplitNode()
@@ -113,34 +196,36 @@ class DynamicDimensionWorkflow:
                 skill_node = DimensionBookSplitNode(
                     dimension_id=dim_id,
                     model_name=self.model_name,
-                    max_workers=self.max_workers
+                    max_workers=optimal_workers  # 使用优化后的线程数
                 )
             elif decomposition_method == 1:
                 skill_node = DimensionPageSplitNode(
                     dimension_id=dim_id,
                     model_name=self.model_name,
-                    max_workers=self.max_workers
+                    max_workers=optimal_workers  # 使用优化后的线程数
                 )
             elif decomposition_method == 2:
                 # 只有第一个滑动窗口维度需要执行拆分,后续维度复用结果
                 skill_node = DimensionSlidingWindowNode(
                     dimension_id=dim_id,
                     model_name=self.model_name,
-                    max_workers=self.max_workers,
+                    max_workers=optimal_workers,  # 使用优化后的线程数(同步版本)
+                    max_concurrent=self.max_concurrent,  # 异步版本的并发数
                     window_size=3,  # 滑动窗口大小:当前页+前后各1页
-                    skip_stitching=not first_sliding_window  # 非第一个维度跳过拆分
+                    skip_stitching=not first_sliding_window,  # 非第一个维度跳过拆分
+                    use_async=self.use_async  # 是否使用异步版本
                 )
                 if first_sliding_window:
                     first_sliding_window = False
-                    logger.info(f"维度 {dim_id} 将执行滑动窗口拆分")
+                    logger.info(f"维度 {dim_id} 将执行滑动窗口拆分 ({'异步' if self.use_async else '同步'}模式)")
                 else:
-                    logger.info(f"维度 {dim_id} 将复用已有的滑动窗口拆分结果")
+                    logger.info(f"维度 {dim_id} 将复用已有的滑动窗口拆分结果 ({'异步' if self.use_async else '同步'}模式)")
             else:
                 logger.warning(f"未知的分解方法: {decomposition_method},使用默认分页模式")
                 skill_node = DimensionPageSplitNode(
                     dimension_id=dim_id,
                     model_name=self.model_name,
-                    max_workers=self.max_workers
+                    max_workers=optimal_workers  # 使用优化后的线程数
                 )
             builder.add_node(skill_node)
             builder.add_edge(prev_node, skill_node.name)
@@ -185,6 +270,16 @@ class DynamicDimensionWorkflow:
         
         logger.info(f"开始运行动态多维度解析: {pdf_path}")
         logger.info(f"维度执行顺序: {dimension_ids}")
+        
+        # 检查内存使用情况(仅警告,不阻止执行)
+        memory_percent = self._get_memory_usage()
+        logger.info(f"当前内存使用: {memory_percent:.1f}%")
+        
+        if memory_percent > self.memory_threshold_percent:
+            logger.warning(f"内存使用率 {memory_percent:.1f}% 超过阈值 {self.memory_threshold_percent}%,将使用最小线程数执行")
+            logger.warning("工作流将继续执行,但速度可能较慢,请监控系统资源")
+        elif memory_percent > self.memory_threshold_percent - 10:
+            logger.info(f"内存使用率 {memory_percent:.1f}% 接近阈值,将适度降低线程数")
 
         # 查询维度知识库对应的user_id 和 api-key
         ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(3)
@@ -197,10 +292,13 @@ class DynamicDimensionWorkflow:
             logger.error("未找到维度知识库对应的user_id和api-key")
             return {"success": False, "error": "ragflow_user_not_found"}
         
-        # 1. 根据维度列表动态构建工作流
-        workflow = self._build_workflow_for_dimensions(dimension_ids)
+        # 1. 动态计算最优线程数
+        optimal_workers = self._calculate_optimal_workers(len(dimension_ids))
+        
+        # 2. 根据维度列表动态构建工作流
+        workflow = self._build_workflow_for_dimensions(dimension_ids, optimal_workers)
         
-        # 2. 创建初始状态
+        # 3. 创建初始状态
         initial_state = DynamicDimensionState(
             pdf_path=pdf_path,
             dimension_ids=dimension_ids,
@@ -214,19 +312,19 @@ class DynamicDimensionWorkflow:
             total_vectorized_pages=0
         )
         
-        # 3. 执行工作流
+        # 4. 执行工作流
         result = workflow.invoke(
             initial_state,
             config={"callbacks": [self.langfuse_handler]}
         )
         
-        # 4. 处理结果
+        # 5. 处理结果
         if isinstance(result, dict):
             final_result = result
         else:
             final_result = result.dict() if hasattr(result, 'dict') else dict(result)
         
-        # 5. 添加统计信息
+        # 6. 添加统计信息
         dim_results = final_result.get('dimension_results', {})
         success_count = sum(1 for r in dim_results.values() if r.get("success"))
         

+ 2 - 2
src/utils/task_queue.py

@@ -56,7 +56,7 @@ class TaskQueue:
             cls._instance._initialized = False
         return cls._instance
     
-    def __init__(self, max_concurrent: int = 3):
+    def __init__(self, max_concurrent: int = 1):
         if self._initialized:
             return
         
@@ -200,7 +200,7 @@ class TaskQueue:
 
 
 # 全局任务队列实例
-task_queue = TaskQueue(max_concurrent=3)
+task_queue = TaskQueue(max_concurrent=1)
 
 
 def get_task_queue() -> TaskQueue:

+ 3 - 3
src/utils/vector_db/elasticsearch_adapter.py

@@ -266,7 +266,7 @@ class ElasticsearchAdapter(VectorDBClient):
         #     query_vector = []
 
         # 将query_vector的值写入D:\project\work\graph_rag_server\logs\app.log
-        logger.info(f"混合搜索: query_vector={query_vector}")
+        # logger.info(f"混合搜索: query_vector={query_vector}")
         # 获取阈值
         knn_params = query.get("knn_params", {})
         threshold = float(knn_params.get("threshold", 0.0))
@@ -282,7 +282,7 @@ class ElasticsearchAdapter(VectorDBClient):
         topn = query.get("topn", 10)
         
         # match 的 query 不能为 None,否则 ES 报 VALUE_NULL。无文案时仅用向量(match_all)
-        if matching_text and match_field == "content":
+        if matching_text and match_field != "content":
             # 处理中文文本
             text_clause = self._build_chinese_query(match_field, matching_text)
         else:
@@ -306,7 +306,7 @@ class ElasticsearchAdapter(VectorDBClient):
                 "k": topn,
                 "num_candidates": topn * 10,
                 "similarity": threshold,
-                "boost": 8.0,
+                "boost": 5.0,
                 "filter": kb_id_filter  # kb_id 精准匹配(支持单个或数组)
             },
             size=topn,