Prechádzať zdrojové kódy

问答对生成增加步骤-图书基本问题生成

yingge 3 mesiacov pred
rodič
commit
674e1c5e97

BIN
requirements.txt


+ 1 - 1
src/api/sdk/dataset_manage.py

@@ -169,7 +169,7 @@ async def pdf_parse_v2(
 async def qa_parse_v2(
     file: UploadFile = File(..., description="PDF文件"),
     dataset_id: str = Form(..., description="RAGFlow数据集ID"),
-    qa_count_per_chunk: int = Form(default=50, description="每块生成的QA数量"),
+    qa_count_per_chunk: int = Form(default=50, gt=0, ge=50, description="每块生成的QA数量"),
     chunk_size: int = Form(default=1000, description="文本分块大小"),
     chunk_overlap: int = Form(default=200, description="分块重叠大小")
 ):

+ 9 - 40
src/api/sdk/search_infinity.py

@@ -16,7 +16,7 @@ logger = get_logger(__name__)
 _search_service = None
 
 
-def get_search_service():
+def get_ragflow_search_service():
     """获取搜索服务实例(单例模式)"""
     global _search_service
     if _search_service is None:
@@ -107,57 +107,26 @@ async def hybrid_search(request: SearchRequest):
         return Result.error(code=500, message=f"混合搜索失败: {str(e)}")
 
 # 4. 问答对检索
-@app.post("/question")
+@app.post("/retrieve")
 async def question_search(request: SearchRequest):
     """
-    问答对检索接口
+    检索接口
 
     - **output_fields**: 要返回的字段列表
     - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
     - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
     """
     try:
-        logger.info(f"Received question search request: {request.matching_text[:100]}...")
+        logger.info(f"Retrieve request: {request.matching_text[:100]}...")
         
         # 获取搜索服务实例
-        search_service = get_search_service()
+        search_service = get_ragflow_search_service()
         
         # 执行搜索
         result = await run_in_threadpool(search_service.hybrid_search, request)
         
-        logger.info(f"Question search completed successfully, found {len(result) if result else 0} results")
-        return Result.success(data=result, message="问答对检索成功")
+        logger.info(f"Retrieve completed successfully, found {len(result) if result else 0} results")
+        return Result.success(data=result, message="检索成功")
     except Exception as e:
-        logger.error(f"问答对检索失败: {str(e)}")
-        return Result.error(code=500, message=f"问答对检索失败: {str(e)}")
-
-@app.post("/resource")
-async def resource_search(request: SearchRequest):
-    """
-    资源检索接口
-
-    - **output_fields**: 要返回的字段列表
-    - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
-    - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
-    """
-    try:
-        logger.info(f"Received resource search request: {request.matching_text[:100]}...")
-        
-        # 获取搜索服务实例
-        search_service = get_search_service()
-        
-        # 获取资源表名
-        resource_table_name = await run_in_threadpool(_get_question_table_name)
-        
-        # 设置请求参数
-        request.kb_id = "f59936fbfb1f11f095280242ac180002"
-        request.database_ids = [resource_table_name]
-        
-        # 执行搜索
-        result = await run_in_threadpool(search_service.hybrid_search, request)
-        
-        logger.info(f"Resource search completed successfully, found {len(result) if result else 0} results")
-        return Result.success(data=result, message="资源检索成功")
-    except Exception as e:
-        logger.error(f"资源检索失败: {str(e)}")
-        return Result.error(code=500, message=f"资源检索失败: {str(e)}")
+        logger.error(f"检索失败: {str(e)}")
+        return Result.error(code=500, message=f"检索失败: {str(e)}")

+ 5 - 1
src/datasets/parser/nodes/__init__.py

@@ -14,7 +14,10 @@ from src.datasets.parser.nodes.ragflow_nodes import (
     RAGFlowDocumentParseNode,
     RAGFlowChunkNode,
 )
-from src.datasets.parser.nodes.qa_generate_node import QAGenerateNode
+from src.datasets.parser.nodes.qa_generate_node import (
+    QABaseGenerateNode,
+    QAGenerateNode
+)
 from src.datasets.parser.nodes.complete_node import CompleteNode
 from src.datasets.parser.nodes.pdf_ocr_node import PDFOCRNode
 from src.datasets.parser.nodes.export_csv_node import ExportCSVNode
@@ -32,6 +35,7 @@ __all__ = [
     "RAGFlowDocumentUploadNode",
     "RAGFlowDocumentParseNode",
     "RAGFlowChunkNode",
+    "QABaseGenerateNode",
     "QAGenerateNode",
     "CompleteNode",
     "PDFOCRNode",

+ 138 - 5
src/datasets/parser/nodes/qa_generate_node.py

@@ -4,7 +4,6 @@ QA对生成节点
 从文本块生成问答对。
 """
 
-import json
 import concurrent.futures
 from typing import Dict, Any, List, Optional
 from src.datasets.parser.core.base import BaseNode, BaseState
@@ -118,6 +117,140 @@ class QAGenerateNode(BaseNode):
             return []
     
     
+    def execute(self, state: BaseState) -> Dict[str, Any]:
+        """
+        执行QA对生成
+        
+        Args:
+            state: 包含text_chunks的状态
+            
+        Returns:
+            包含qa_pairs的更新字典
+        """
+        text_chunks = getattr(state, 'text_chunks', [])
+        qa_count_per_chunk = getattr(state, 'qa_count_per_chunk', self.default_qa_count)
+        qa_pairs = getattr(state, 'qa_pairs', [])
+        qa_count = getattr(state, 'qa_count', 0)
+
+        if not text_chunks:
+            logger.warning("没有待处理的文本块")
+            return {"qa_pairs": qa_pairs, "qa_count": qa_count}
+        
+        logger.info(f"开始并行生成QA对,共 {len(text_chunks)} 块,每块 {qa_count_per_chunk} 个")
+        
+        all_qa_pairs = qa_pairs
+        
+        # 使用全局线程池
+        pool = ThreadPoolManager.get_pool("parser")
+        futures = [
+            pool.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
+            for i, chunk in enumerate(text_chunks)
+        ]
+        
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                current_qa_pairs = future.result()
+                all_qa_pairs.extend(current_qa_pairs)
+            except Exception as e:
+                logger.error(f"QA生成任务失败: {str(e)}")
+        
+        logger.info(f"QA对生成完成,共 {len(all_qa_pairs)} 个")
+
+        return {
+            "qa_pairs": all_qa_pairs,
+            "qa_count": len(all_qa_pairs)
+        }
+
+
+
+@register_node()
+class QABaseGenerateNode(BaseNode):
+    """
+    QA对基础生成节点
+    
+    从文本块并行生成问答对。
+    
+    需要的状态字段:
+        - text_chunks: 文本块列表
+        - qa_count_per_chunk: 每块生成的QA数量(可选)
+        
+    更新的状态字段:
+        - qa_pairs: QA对列表
+        - qa_count: 生成的QA总数
+    """
+    
+    def __init__(
+        self,
+        model_name: Optional[str] = None,
+        default_qa_count: int = 50,
+        max_workers: int = 5
+    ):
+        """
+        初始化QA生成节点
+        
+        Args:
+            model_name: 模型名称
+            default_qa_count: 每块默认生成的QA数量
+            max_workers: 并行处理的最大工作线程数
+        """
+        self.model_name = model_name or model_settings.chat_model_name
+        self.default_qa_count = default_qa_count
+        self.max_workers = max_workers
+    
+    @property
+    def name(self) -> str:
+        return "qa_base_generate"
+    
+    def _generate_base_qa_for_chunk(self, chunk: str, count: int, chunk_index: int) -> List[Dict[str, str]]:
+        """
+        为单个分块生成QA对
+        
+        Args:
+            chunk: 文本块内容
+            count: 生成的QA数量
+            chunk_index: 块索引
+            
+        Returns:
+            QA对列表
+        """
+        system_prompt = """你是一个专业的问答对生成专家。你的任务是根据给定的文本内容,生成高质量的问答对。
+
+            要求:
+            1. 问题应该覆盖内容的不同方面和细节
+            2. 答案应该准确、简洁,直接来自文本
+            3. 问答对应该有助于知识检索和理解
+            4. 避免生成过于简单或重复的问题
+            5. 确保问题是自包含的,不需要额外上下文即可理解"""
+
+        user_prompt = f"""请根据以下内容,生成 {count} 个高质量的问答对。
+
+            内容:
+                {chunk}
+
+            请严格以JSON格式输出,格式如下:
+            [
+                {{"question": "问题1", "answer": "答案1"}},
+                {{"question": "问题2", "answer": "答案2"}}
+            ]
+
+            只输出JSON数组,不要添加任何其他内容。"""
+
+        
+        logger.debug(f"开始为第 {chunk_index + 1} 块生成QA对")
+        
+        try:
+            parser = QWenVLParser(self.model_name)
+            result = parser.chat(prompt=user_prompt, system_prompt=system_prompt)
+            
+            # 解析JSON响应(使用通用工具函数)
+            qa_pairs = parse_qa_response(result)
+            logger.debug(f"第 {chunk_index + 1} 块生成 {len(qa_pairs)} 个QA对")
+            return qa_pairs
+        except Exception as e:
+            logger.error(f"第 {chunk_index + 1} 块QA生成失败: {str(e)}")
+            return []
+    
+    
     def execute(self, state: BaseState) -> Dict[str, Any]:
         """
         执行QA对生成
@@ -142,7 +275,7 @@ class QAGenerateNode(BaseNode):
         # 使用全局线程池
         pool = ThreadPoolManager.get_pool("parser")
         futures = [
-            pool.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
+            pool.submit(self._generate_base_qa_for_chunk, chunk, qa_count_per_chunk, i)
             for i, chunk in enumerate(text_chunks)
         ]
         
@@ -151,11 +284,11 @@ class QAGenerateNode(BaseNode):
                 qa_pairs = future.result()
                 all_qa_pairs.extend(qa_pairs)
             except Exception as e:
-                logger.error(f"QA生成任务失败: {str(e)}")
+                logger.error(f"QA基础生成任务失败: {str(e)}")
         
-        logger.info(f"QA对生成完成,共 {len(all_qa_pairs)} 个")
+        logger.info(f"QA对基础生成完成,共 {len(all_qa_pairs)} 个")
         
         return {
             "qa_pairs": all_qa_pairs,
             "qa_count": len(all_qa_pairs)
-        }
+        }

+ 4 - 0
src/datasets/parser/workflows/qa_workflow.py

@@ -13,6 +13,7 @@ from src.datasets.parser.states.parser_states import QAParsingState
 from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
 from src.datasets.parser.nodes import (
     TextSplitNode,
+    QABaseGenerateNode,
     QAGenerateNode,
     RAGFlowDocumentUploadNode,
     RAGFlowDocumentParseNode,
@@ -64,6 +65,7 @@ class QAParsingWorkflowV2:
         # 创建节点实例
         ocr_node = PDFOCRNode(model_name=self.model_name)
         split_node = TextSplitNode()
+        qa_base_node = QABaseGenerateNode(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507")
         qa_node = QAGenerateNode(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507")
         export_node = ExportCSVNode()
         upload_node = RAGFlowDocumentUploadNode(api_key=ragflow_api_key)
@@ -76,6 +78,7 @@ class QAParsingWorkflowV2:
         builder.add_nodes(
             ocr_node,
             split_node,
+            qa_base_node,
             qa_node,
             export_node,
             upload_node,
@@ -88,6 +91,7 @@ class QAParsingWorkflowV2:
         builder.add_sequence(
             "pdf_ocr",
             "text_split",
+            "qa_base_generate",
             "qa_generate",
             "export_csv",
             "ragflow_document_upload",