Kaynağa Gözat

19维度图书解析流程调整

yingge 3 ay önce
ebeveyn
işleme
36d9b30c8a

+ 6 - 0
src/api/db/models/prompt_dimensions_models.py

@@ -0,0 +1,6 @@
+
+class DecompositionMethod(BaseModel):
+    dimension_id: int
+    decomposition_method: str
+    is_preced: bool
+    preced_node: List[int]

+ 4 - 3
src/api/db/services/prompt_service.py

@@ -239,15 +239,16 @@ class PromptService:
         result = self._db.fetch_one(sql, [dimension_id])
         return result['dataset_id'] if result else None
 
-    def get_decomposition_method(self) -> Dict[str, Dict[str, Any]]:
+    def get_decomposition_method(self) -> Dict[int, Dict[str, Any]]:
         """获取所有维度的decomposition method"""
         sql = "SELECT id, decomposition_method, is_preced, preced_node FROM prompt_dimensions"
         results = self._db.fetch_all(sql)
         return {
-            str(row['id']): {
+            row['id']: {
                 "decomposition_method": row['decomposition_method'],
                 "is_preced": row['is_preced'],
-                "preced_node": json.loads(row['preced_node'])
+                # preced_node的格式为[2,3],将其解析为数组[2,3]
+                "preced_node": json.loads(row['preced_node']) if row['preced_node'] else []
             }
             for row in results
         }

+ 3 - 1
src/datasets/parser/nodes/pdf_split_node.py

@@ -45,7 +45,6 @@ class PDFSplitNode(BaseNode):
         
         pdf_path = getattr(state, 'pdf_path', None)
         pdf_content = getattr(state, 'pdf_content', None)
-        original_filename = getattr(state, 'original_filename', None)
         
         if not pdf_path and not pdf_content:
             raise ValueError("State must contain either 'pdf_path' or 'pdf_content' field")
@@ -55,6 +54,9 @@ class PDFSplitNode(BaseNode):
         else:
             logger.info(f"开始拆分PDF: {pdf_path}")
         
+        # 根据pdf_path获取原始文件名(去除后缀)
+        original_filename = pdf_path.split('/')[-1].split('.')[0] if pdf_path else None
+
         # 拆分PDF
         splitter = PDFSplitter()
         split_pages = splitter.split_pdf(

+ 26 - 1
src/datasets/parser/nodes/picture_stitching_node.py

@@ -5,9 +5,12 @@
 """
 
 from typing import Dict, Any
+from io import BytesIO
 from PIL import Image
 from src.datasets.parser.core.base import BaseNode, BaseState
 from src.datasets.parser.core.registry import register_node
+from src.utils.file.image_util import ImageUtil
+
 from src.common.logging_config import get_logger
 
 logger = get_logger(__name__)
@@ -100,6 +103,28 @@ class PictureStitchingNode(BaseNode):
         
         logger.info(f"图片拼接完成,最终尺寸: {stitched_image.size}")
         
+        # 将合成后的图片进行压缩
+        image_util = ImageUtil()
+        
+        # 检查像素数量是否超过Pillow安全限制
+        max_pixels = Image.MAX_IMAGE_PIXELS
+        total_pixels = stitched_image.width * stitched_image.height
+        if total_pixels > max_pixels:
+            logger.warning(f"图片像素数 ({total_pixels}) 超过安全限制 ({max_pixels}),进行缩放处理")
+            # 计算缩放比例,将像素数降到安全限制的80%
+            target_pixels = max_pixels * 0.8
+            scale_ratio = (target_pixels / total_pixels) ** 0.5
+            new_width = int(stitched_image.width * scale_ratio)
+            new_height = int(stitched_image.height * scale_ratio)
+            stitched_image = stitched_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+            logger.info(f"图片已缩放至: {stitched_image.size}")
+        
+        image_stream = BytesIO()
+        stitched_image.save(image_stream, format='JPEG')
+        image_stream.seek(0)
+        compressed_bytes = image_util._compress_image_to_bytes(image_stream)
+        compressed_image = Image.open(BytesIO(compressed_bytes))
+        
         return {
-            "book_image": stitched_image
+            "book_image": compressed_image
         }

+ 7 - 1
src/datasets/parser/nodes/prompt_retrieval_node.py

@@ -100,7 +100,13 @@ class PromptRetrievalNode(BaseNode):
             with open(f"temp/{original_filename}_{node}.md", "r", encoding="utf-8") as f:
                 current_content = f.read()
             content.append(current_content)
-        chat_model = QWenVLParser(self.model_name)
+
+        # 判断content是否为空
+        if not content:
+            logger.info(f"[Prompt-{self.dimension_id}] preceding_node中没有内容")
+            return dimension_prompt
+        
+        chat_model = QWenVLParser()
         # 使用大语言模型对content列表中的content进行整合、压缩
         system_prompt = """
         你是一个高级数据处理助手。请对输入的 内容 进行结构化压缩。

+ 7 - 11
src/datasets/parser/nodes/summary_node.py

@@ -6,8 +6,6 @@
 
 import os
 from typing import Dict, Any, List, Optional
-from langchain_core.prompts import ChatPromptTemplate
-
 from src.datasets.parser.core.base import BaseNode, BaseState
 from src.datasets.parser.core.registry import register_node
 from src.conf.settings import model_settings
@@ -25,7 +23,7 @@ class SummaryNode(BaseNode):
     对图像解析结果进行总结提要,并将结果保存到文件。
     """
     
-    def __init__(self, dimension_id: int, model_name: Optional[str] = None):
+    def __init__(self, dimension_id: int = 0, model_name: Optional[str] = None):
         """
         初始化总结节点
         
@@ -36,8 +34,7 @@ class SummaryNode(BaseNode):
         self.dimension_id = dimension_id
         self.model_name = model_name or model_settings.chat_model_name
         # 总结提示模板
-        self.summary_prompt = ChatPromptTemplate.from_template(
-            """
+        self.summary_prompt = """
             你是一位专业的内容总结专家,擅长从长篇文本中提取核心内容并生成简洁的总结。
             
             请对以下内容进行总结,要求:
@@ -51,7 +48,6 @@ class SummaryNode(BaseNode):
             
             总结:
             """
-        )
     
     @property
     def name(self) -> str:
@@ -69,11 +65,11 @@ class SummaryNode(BaseNode):
         """
         try:
             # 构建提示
-            messages = self.summary_prompt.format_messages(content=content)
+            messages = self.summary_prompt.format(content=content)
             # 调用模型生成总结
             chat_model = QWenVLParser(self.model_name)
-            response = chat_model.invoke(messages)
-            return response.content
+            response = chat_model.chat(prompt=messages)
+            return response
         except Exception as e:
             logger.error(f"总结内容时出错: {str(e)}")
             return f"总结失败: {str(e)}"
@@ -90,7 +86,7 @@ class SummaryNode(BaseNode):
         """
         # 获取解析结果
         parsed_results = getattr(state, 'parsed_results', [])
-        
+        original_filename = getattr(state, 'original_filename', '')
         if not parsed_results:
             logger.warning("没有解析结果可总结")
             return {
@@ -128,7 +124,7 @@ class SummaryNode(BaseNode):
         os.makedirs(temp_dir, exist_ok=True)
         
         # 保存总结到文件
-        file_path = os.path.join(temp_dir, f"{state.original_filename}_{self.dimension_id}.md")
+        file_path = os.path.join(temp_dir, f"{original_filename}_{self.dimension_id}.md")
         try:
             with open(file_path, 'w', encoding='utf-8') as f:
                 f.write(summary)

+ 0 - 1
src/datasets/parser/workflow_nodes/dimension_book_split_node.py

@@ -88,7 +88,6 @@ class DimensionBookSplitNode(BaseNode):
         table_name_node = TableNameGenerationNode(self.dimension_id)
         stitching_node = PictureStitchingNode()
         parse_node = ImageParseNode(
-            dimension_id=self.dimension_id,
             model_name=self.model_name,
             max_workers=self.max_workers,
             use_book_image=True,  # 标记使用book_image模式

+ 0 - 1
src/datasets/parser/workflow_nodes/dimension_page_split_node.py

@@ -81,7 +81,6 @@ class DimensionPageSplitNode(BaseNode):
         document_upload_node = RAGFlowDocumentUploadNode(api_key=rag_flow_api_key)
         table_name_node = TableNameGenerationNode(self.dimension_id)
         parse_node = ImageParseNode(
-            dimension_id=self.dimension_id,
             model_name=self.model_name,
             max_workers=self.max_workers,
         )

+ 9 - 4
src/datasets/parser/workflows/dynamic_dimension_workflow.py

@@ -57,7 +57,12 @@ class DynamicDimensionWorkflow:
         self.max_workers = max_workers
         self.langfuse_handler = CallbackHandler()
         # 懒加载维度分解方法(初始化时加载一次)
-        self._decomposition_methods = get_prompt_service().get_decomposition_method()
+        self.decomposition_methods = self._get_decomposition_method()
+
+    def _get_decomposition_method(self):
+        prompt_service = get_prompt_service()
+        decomposition_methods = prompt_service.get_decomposition_method()
+        return decomposition_methods
     
     def _build_workflow_for_dimensions(self, dimension_ids: List[int]):
         """
@@ -76,7 +81,7 @@ class DynamicDimensionWorkflow:
         complete_node = CompleteNode(message_template="动态多维度解析完成")
         
         # 使用预加载的维度分解方法
-        decomposition_methods = self._decomposition_methods
+        decomposition_methods = self.decomposition_methods
 
         # 使用 WorkflowBuilder 构建
         builder = WorkflowBuilder(DynamicDimensionState)
@@ -94,7 +99,7 @@ class DynamicDimensionWorkflow:
         prev_node = "pdf_split"
         
         for dim_id in dimension_ids:
-            decomposition_method = decomposition_methods[str(dim_id)]["decomposition_method"]
+            decomposition_method = decomposition_methods[dim_id]["decomposition_method"]
             # 为每个维度创建分页拆分节点
             if decomposition_method == 0:
                 skill_node = DimensionBookSplitNode(
@@ -172,7 +177,7 @@ class DynamicDimensionWorkflow:
             dimension_ids=dimension_ids,
             ragflow_id=ragflow_id,
             rag_flow_api_key=ragflow_api_key,
-            decomposition_methods=self._decomposition_methods,
+            decomposition_methods=self.decomposition_methods,
             # dataset_name=dataset_name or pdf_path.split("/")[-1].split("\\")[-1].replace(".pdf", ""),
             # ragflow_api_url=ragflow_api_url,
             # rag_flow_api_key=rag_flow_api_key,

+ 2 - 0
src/model/qwen_vl.py

@@ -126,6 +126,7 @@ class QWenVLParser:
             
             return {
                 "page_number": page_number,
+                "chunk_id": "",
                 "content": content,
                 "model": self.model_name
             }
@@ -175,6 +176,7 @@ class QWenVLParser:
             
             return {
                 "page_number": page_number,
+                "chunk_id": "",
                 "content": content,
                 "model": self.model_name
             }