
QA pair parser & QA pair retriever

yingge 3 months ago
parent commit 434dc600f1

+ 1 - 1
src/api/dataset/services/dataset_manage_service.py

@@ -10,7 +10,7 @@
 import os
 import tempfile
 from typing import Dict, Any, Optional
-from src.parser.pdf_parser.pdf_parser_workflow import PDFParsingWorkflow
+from src.datasets.parser.pdf_parser.pdf_parser_workflow import PDFParsingWorkflow
 from src.conf.settings import vector_db_settings
 
 

+ 8 - 0
src/api/db/models/__init__.py

@@ -0,0 +1,8 @@
+from .search_request_models import SearchRequest
+# from .search_param_models import KnnParams, VectorSearchQuery
+
+__all__ = [
+    "SearchRequest",
+    # "KnnParams",
+    # "VectorSearchQuery",
+]

+ 79 - 0
src/api/db/models/search_param_models.py

@@ -0,0 +1,79 @@
+# """
+# 搜索相关的Pydantic模型定义
+
+# 该模块定义了用于图像和文本搜索的请求和响应模型。
+# """
+
+# from pydantic import BaseModel, Field, field_validator
+# from typing import List, Optional
+
+# class KnnParams(BaseModel):
+#     """
+#     KNN搜索参数模型
+    
+#     用于配置向量搜索的近邻搜索参数。
+    
+#     Attributes:
+#         ef: HNSW索引的搜索参数,影响搜索精度和速度
+#         threshold: 相似度分数阈值,低于该阈值的结果将被过滤
+#     """
+    
+#     ef: str = Field(
+#         default="100",
+#         description="HNSW索引的搜索参数 (ef值越大搜索越精确但速度越慢)"
+#     )
+#     threshold: str = Field(
+#         default="0.0",
+#         description="相似度分数阈值,范围 0.0-1.0"
+#     )
+    
+#     @classmethod
+#     def from_retrieval_setting(cls, top_k: int, score_threshold: float) -> "KnnParams":
+#         """
+#         从检索设置创建KNN参数
+        
+#         Args:
+#             top_k: 返回的结果数量
+#             score_threshold: 分数阈值
+            
+#         Returns:
+#             KnnParams实例
+#         """
+#         return cls(
+#             ef=str(top_k * 10),
+#             threshold=str(score_threshold)
+#         )
+
+
+# class VectorSearchQuery(BaseModel):
+#     """
+#     向量搜索查询参数模型
+    
+#     封装向量检索所需的所有参数。
+    
+#     Attributes:
+#         vector_field: 向量字段名称
+#         query_vector: 查询向量
+#         topn: 返回的最大结果数量
+#         knn_params: KNN搜索参数
+#     """
+    
+#     vector_field: str = Field(
+#         ...,
+#         description="向量字段名称",
+#         examples=["dense_vector_1024"]
+#     )
+#     query_vector: List[float] = Field(
+#         ...,
+#         description="查询向量"
+#     )
+#     topn: int = Field(
+#         default=10,
+#         ge=1,
+#         le=100,
+#         description="返回的最大结果数量"
+#     )
+#     knn_params: Optional[KnnParams] = Field(
+#         default=None,
+#         description="KNN搜索参数"
+#     )

+ 65 - 0
src/api/db/models/search_request_models.py

@@ -0,0 +1,65 @@
+"""
+搜索相关的Pydantic模型定义
+
+该模块定义了用于图像和文本搜索的请求和响应模型。
+"""
+from pydantic import BaseModel, Field, field_validator
+from typing import List, Optional
+
+
+class SearchRequest(BaseModel):
+    """
+    搜索请求模型
+    
+    支持通过图像URL或文本进行相似性搜索。
+    至少需要提供 image_url 或 matching_text 其中之一。
+    
+    Attributes:
+        image_url: 用于图像搜索的图片URL (可选)
+        matching_text: 用于文本匹配搜索的查询文本 (可选)
+        topn: 返回的最大结果数量,默认为10,范围1-100
+    """
+    
+    image_url: Optional[str] = Field(
+        default=None,
+        description="用于图像搜索的图片URL",
+        examples=["https://example.com/image.jpg"]
+    )
+    matching_text: Optional[str] = Field(
+        default=None,
+        description="用于文本匹配搜索的查询文本",
+        examples=["搜索关键词"]
+    )
+    topn: int = Field(
+        default=10,
+        ge=1,
+        le=100,
+        description="返回的最大结果数量"
+    ),
+    score_threshold: float = Field(
+        default=0.0,
+        ge=0.0,
+        le=1.0,
+        description="相似度分数阈值"
+    )
+    
+    @field_validator('image_url')
+    @classmethod
+    def validate_image_url(cls, v: Optional[str]) -> Optional[str]:
+        """验证图片URL格式"""
+        if v is not None and v.strip() == "":
+            return None
+        return v
+    
+    @field_validator('matching_text')
+    @classmethod
+    def validate_matching_text(cls, v: Optional[str]) -> Optional[str]:
+        """验证匹配文本"""
+        if v is not None and v.strip() == "":
+            return None
+        return v
+    
+    def model_post_init(self, __context) -> None:
+        """模型初始化后验证:确保至少提供一个搜索条件"""
+        if self.image_url is None and self.matching_text is None:
+            raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
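
A minimal usage sketch (query strings invented for illustration): the validators normalize blank strings to None, so a request that provides neither field fails post-init validation.

from src.api.db.models import SearchRequest

# Text-only request; image_url stays None
req = SearchRequest(matching_text="how to reset a password", topn=5, score_threshold=0.2)

# Blank strings are normalized to None, so this raises a validation error
try:
    SearchRequest(image_url="   ", matching_text="")
except ValueError as e:
    print(e)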

+ 78 - 26
src/api/db/services/infinity_search_service.py

@@ -3,13 +3,13 @@ from src.conf.settings import vector_db_settings
 from src.utils.infinity import InfinityClient
 from src.utils.file.image_util import image_util
 from src.model.multimodal_embedding import get_embedding_model
-from src.utils.infinity.result_util import convert_to_basic_types
+from src.utils.infinity.result_util import convert_to_json
+from src.api.db.models import SearchRequest
+from langchain_core.documents import Document
 
 class InfinitySearchService:
-    def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
-        self.infinity_client = infinity_client
-        # Output fields
-        self.output_fields = [
+
+    default_output_fields = [
                     "file_name",
                     "page_number",
                     "content",
@@ -17,12 +17,33 @@ class InfinitySearchService:
                     "dataset_id",
                     "document_id"
                 ]
+
+    def __init__(self, infinity_client: InfinityClient, 
+                       vector_field: str = None, 
+                       match_field: str = None, 
+                       match_type: str = None, 
+                       table_name: str = None, 
+                       output_fields: List[str] = None):
+        """
+        初始化InfinitySearchService
+        
+        Args:
+            infinity_client: InfinityClient实例
+            vector_field: 向量字段名
+            match_field: 匹配字段名
+            match_type: 匹配类型
+            table_name: 表名
+            output_fields: 输出字段列表
+        """
+        self.infinity_client = infinity_client
+        # 输出字段
+        self.output_fields = output_fields or self.default_output_fields
         self.vector_field = vector_field or "dense_vector_1024"
         self.match_field = match_field or "content"
         self.match_type = match_type or "cosine"
         self.table_name = table_name or vector_db_settings.infinity_table_name
 
-    def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
+    def search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         Execute a search against the Infinity database
     
@@ -33,16 +54,18 @@ class InfinitySearchService:
             The search result, converted to basic types for serialization
         """
         try:
+            # Convert the request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Execute the search
             result = self.infinity_client.search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
         except Exception as e:
             raise Exception(f"Search failed: {str(e)}")
 
-    def vector_search(self, search_query: Dict[str, Any]):
+    def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         Execute vector retrieval against the Infinity database
     
@@ -53,23 +76,18 @@ class InfinitySearchService:
             The vector retrieval result, converted to basic types for serialization
         """
         try:
-            # 1. Convert image_url into an image: Image.Image
-            image = image_util._url_to_image(search_query["image_url"])
-            # 2. Embed the image into a vector
-            query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
-
-            search_query["vector_field"] = self.vector_field
-            search_query["query_vector"] = query_vector
+            # Convert the request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Execute vector retrieval
             result = self.infinity_client.vector_search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
         except Exception as e:
             raise Exception(f"Vector retrieval failed: {str(e)}")
 
-    def hybrid_search(self, search_query: Dict[str, Any]):
+    def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         Execute hybrid retrieval against the Infinity database
     
@@ -80,18 +98,52 @@ class InfinitySearchService:
             The hybrid retrieval result, converted to basic types for serialization
         """
         try:
-            # 1. Convert image_url into an image: Image.Image
-            image = image_util._url_to_image(search_query["image_url"])
-            # 2. Embed the image into a vector
-            query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
-            search_query["vector_field"] = self.vector_field
-            search_query["query_vector"] = query_vector
-            search_query["match_field"] = self.match_field
+            # Convert the request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Execute hybrid retrieval
             result = self.infinity_client.hybrid_search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
+        except Exception as e:
+            raise Exception(f"Hybrid retrieval failed: {str(e)}")
+
+    # Parameter conversion
+    def _convert_search_request_to_search_query(self, request: SearchRequest) -> dict:
+        """
+        Convert a SearchRequest into a search-query dict
+        
+        Args:
+            request: a SearchRequest instance
+            
+        Returns:
+            The search-query dict
+        """
+        try:
+            if request.score_threshold is not None:
+                # Decide between a multimodal embedding and a text embedding
+                if request.image_url is not None:
+                    # Convert image_url into an image: Image.Image
+                    image = image_util._url_to_image(request.image_url)
+                    # Embed the image into a vector
+                    query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image)
+                else:
+                    query_vector = get_embedding_model().get_text_embedding(request.matching_text)
+            else:
+                query_vector = None
+            
+            search_query = {
+                "match_field": self.match_field,
+                "matching_text": request.matching_text,
+                "vector_field": self.vector_field,
+                "query_vector": query_vector,
+                "topn": request.topn,
+                "knn_params": {
+                    "ef": str(request.topn * 10),
+                    "threshold": str(request.score_threshold)
+                }
+            }
+            return search_query
+        except Exception as e:
+            raise Exception(f"Parameter conversion failed: {str(e)}")

+ 25 - 12
src/api/sdk/search_infinity.py

@@ -1,12 +1,11 @@
 # Infinity search API service
 
 from fastapi import FastAPI, HTTPException
-from typing import List, Dict, Any, Optional
 from src.api.db.services.infinity_search_service import InfinitySearchService
 from src.utils.infinity import get_client
 from src.common.result import Result
 from src.utils.async_utils import run_in_threadpool
-
+from src.api.db.models import SearchRequest
 
# Create the FastAPI app
 
@@ -16,13 +15,6 @@ app = FastAPI(
     version="1.0.0"
 )
 
-# Request model
-from pydantic import BaseModel
-
-class SearchRequest(BaseModel):
-    """Search request model"""
-    search_query: Dict[str, Any]
-
# 1. Plain search endpoint
 @app.post("/text")
 async def search(request: SearchRequest):
@@ -36,7 +28,7 @@ async def search(request: SearchRequest):
     """
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.search, request.search_query)
+        result = await run_in_threadpool(search_service.search, request)
         return Result.success(data=result, message="搜索成功")
     except Exception as e:
         return Result.error(code=500, message=f"搜索失败: {str(e)}")
@@ -54,7 +46,7 @@ async def vector_search(request: SearchRequest):
     """
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.vector_search, request.search_query)
+        result = await run_in_threadpool(search_service.vector_search, request)
         return Result.success(data=result, message="向量搜索成功")
     except Exception as e:
         return Result.error(code=500, message=f"向量搜索失败: {str(e)}")
@@ -72,7 +64,28 @@ async def hybrid_search(request: SearchRequest):
     """
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.hybrid_search, request.search_query)
+        result = await run_in_threadpool(search_service.hybrid_search, request)
         return Result.success(data=result, message="混合搜索成功")
     except Exception as e:
         return Result.error(code=500, message=f"混合搜索失败: {str(e)}")
+
+# 4. QA pair retrieval
+@app.post("/question")
+async def question_search(request: SearchRequest):
+    """
+    QA pair retrieval endpoint
+
+    - **matching_text**: query text to match against the stored QA pairs
+    - **image_url**: optional image URL for multimodal retrieval
+    - **topn**: maximum number of results to return
+    - **score_threshold**: similarity score threshold
+    """
+    try:
+        output_fields = ["content"]
+        search_service = InfinitySearchService(infinity_client=get_client(database="ragflow_db"), 
+                                               table_name="ragflow_92162247e93e11f084830242ac1d0002_1c4f7a82f66c11f09c750242c0a8d002", 
+                                               vector_field="q_1024_vec",
+                                               output_fields=output_fields)
+        result = await run_in_threadpool(search_service.hybrid_search, request)
+        return Result.success(data=result, message="QA pair retrieval succeeded")
+    except Exception as e:
+        return Result.error(code=500, message=f"QA pair retrieval failed: {str(e)}")

+ 0 - 0
src/parser/__init__.py → src/datasets/__init__.py


+ 0 - 0
src/parser/image_parser/__init__.py → src/datasets/parser/__init__.py


+ 0 - 0
src/parser/pdf_parser/__init__.py → src/datasets/parser/image_parser/__init__.py


+ 0 - 0
src/parser/image_parser/image_parser_workflow.py → src/datasets/parser/image_parser/image_parser_workflow.py


+ 0 - 0
src/datasets/parser/pdf_parser/__init__.py


+ 1 - 1
src/parser/pdf_parser/pdf_parser_workflow.py → src/datasets/parser/pdf_parser/pdf_parser_workflow.py

@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
 from langgraph.graph import StateGraph, START, END
 from typing import List, Dict, Any
 from pydantic import BaseModel, Field, ConfigDict
-from src.parser.pdf_parser.pdf_splitter import PDFSplitter
+from src.datasets.parser.pdf_parser.pdf_splitter import PDFSplitter
 from src.model.qwen_vl import QWenVLParser
 from src.utils.ragflow.ragflow_service import RAGFlowService
 from src.utils.ragflow.chunk_record import get_chunk_record_service

+ 0 - 0
src/parser/pdf_parser/pdf_splitter.py → src/datasets/parser/pdf_parser/pdf_splitter.py


+ 0 - 0
src/parser/pdf_parser/test_service.py → src/datasets/parser/pdf_parser/test_service.py


+ 417 - 0
src/datasets/parser/question_answer_parser/question_answer_parser.py

@@ -0,0 +1,417 @@
+"""
+QA模式问答对解析器工作流
+
+使用LangGraph实现完整的PDF到QA知识库工作流:
+PDF OCR解析 → 文本分块 → 生成QA对 → 上传RAGFlow
+"""
+
+import os
+import json
+import csv
+import tempfile
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Dict, Any, Optional
+from pydantic import BaseModel, Field, ConfigDict
+from langgraph.graph import StateGraph, START, END
+from langchain.chat_models import init_chat_model
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from src.datasets.parser.pdf_parser.pdf_splitter import PDFSplitter
+from src.model.qwen_vl import QWenVLParser
+from src.utils.ragflow.ragflow_service import RAGFlowService
+from src.conf.settings import model_settings
+from langfuse.langchain import CallbackHandler
+
+
+class QAParsingState(BaseModel):
+    """QA parsing workflow state"""
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    
+    # Input parameters
+    pdf_path: str = Field(..., description="Path to the PDF file")
+    dataset_id: str = Field(..., description="RAGFlow dataset ID")
+    qa_count_per_chunk: int = Field(default=50, description="Number of QA pairs to generate per chunk")
+    chunk_size: int = Field(default=1000, description="Text chunk size")
+    chunk_overlap: int = Field(default=200, description="Chunk overlap size")
+    
+    # Intermediate state
+    extracted_text: str = Field(default="", description="Text extracted by OCR")
+    chunks: List[str] = Field(default_factory=list, description="List of text chunks")
+    qa_pairs: List[Dict[str, Any]] = Field(default_factory=list, description="List of generated QA pairs")
+    csv_path: str = Field(default="", description="Path to the temporary CSV file")
+    
+    # Output state
+    uploaded_document_id: str = Field(default="", description="Document ID after upload")
+    qa_count: int = Field(default=0, description="Number of generated QA pairs")
+    is_complete: bool = Field(default=False, description="Whether processing is complete")
+    error_message: Optional[str] = Field(default=None, description="Error message")
+
+
+class QuestionAnswerParserWorkflow:
+    """QA pair parsing workflow"""
+    
+    def __init__(self, model_name: str = None):
+        """
+        Initialize the QA parsing workflow
+        
+        Args:
+            model_name: VL model name, used for PDF OCR parsing
+        """
+        self.model_name = model_name or model_settings.model_name
+        self.ragflow_service = RAGFlowService()
+        self.langfuse_handler = CallbackHandler()
+        
+        # Initialize the LangChain chat model used for QA generation
+        self.chat_model = init_chat_model(
+            model_provider=model_settings.model_provider,
+            model=model_settings.chat_model_name,
+            api_key=model_settings.api_key,
+            base_url=model_settings.base_url,
+            temperature=0.7
+        )
+        
+        self.workflow = self._build_workflow()
+    
+    def _build_workflow(self):
+        """构建LangGraph工作流"""
+        graph = StateGraph(QAParsingState)
+        
+        # Add nodes
+        graph.add_node("parse_pdf", self._parse_pdf_node)
+        graph.add_node("split_text", self._split_text_node)
+        graph.add_node("generate_qa", self._generate_qa_node)
+        graph.add_node("export_csv", self._export_csv_node)
+        graph.add_node("upload_document", self._upload_document_node)
+        graph.add_node("parse_document", self._parse_document_node)
+        graph.add_node("complete", self._complete_node)
+        
+        # Define edges
+        graph.add_edge(START, "parse_pdf")
+        graph.add_edge("parse_pdf", "split_text")
+        graph.add_edge("split_text", "generate_qa")
+        graph.add_edge("generate_qa", "export_csv")
+        graph.add_edge("export_csv", "upload_document")
+        graph.add_edge("upload_document", "parse_document")
+        graph.add_edge("parse_document", "complete")
+        graph.add_edge("complete", END)
+        
+        return graph.compile()
+    
+    def _parse_pdf_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """PDF OCR parsing node: extract text content"""
+        print(f"Parsing PDF: {state.pdf_path}")
+        
+        try:
+            # Use PDFSplitter to split the PDF into page images
+            splitter = PDFSplitter()
+            pages = splitter.split_pdf(state.pdf_path)
+            
+            # Use the QWenVL model to extract the text of each page
+            extracted_texts = []
+            parser = QWenVLParser(self.model_name)
+            
+            for page in pages:
+                page_number = page["page_number"]
+                image = page["image"]
+                
+                # Prompt for OCR text extraction
+                prompt = """Extract all text content from the image, preserving the original paragraph structure.
+Output only the extracted text; do not add any extra explanation or formatting."""
+                
+                result = parser.parse_image(image, page_number, prompt)
+                text = result.get("content", "")
+                extracted_texts.append(text)
+                print(f"Page {page_number}: text extraction complete")
+            
+            full_text = "\n\n".join(extracted_texts)
+            print(f"PDF parsing complete; extracted {len(full_text)} characters")
+            
+            return {"extracted_text": full_text}
+            
+        except Exception as e:
+            print(f"PDF parsing failed: {str(e)}")
+            return {"error_message": f"PDF parsing failed: {str(e)}"}
+    
+    def _split_text_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Text chunking node"""
+        print("Chunking text...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.extracted_text:
+            return {"error_message": "No text content was extracted"}
+        
+        try:
+            # Use LangChain's RecursiveCharacterTextSplitter
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=state.chunk_size,
+                chunk_overlap=state.chunk_overlap,
+                length_function=len,
+                separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]
+            )
+            
+            chunks = text_splitter.split_text(state.extracted_text)
+            print(f"Text chunking complete: {len(chunks)} chunks")
+            
+            return {"chunks": chunks}
+            
+        except Exception as e:
+            print(f"Text chunking failed: {str(e)}")
+            return {"error_message": f"Text chunking failed: {str(e)}"}
+    
+    def _generate_qa_for_chunk(self, chunk: str, count: int, chunk_index: int) -> List[Dict[str, str]]:
+        """Generate QA pairs for a single chunk"""
+        print(f"Generating {count} QA pairs for chunk {chunk_index + 1}...")
+        
+        system_prompt = """You are an expert at generating question-answer pairs. Your task is to generate high-quality QA pairs from the given text content.
+
+            Requirements:
+            1. Questions should cover different aspects and details of the content
+            2. Answers should be accurate, concise, and come directly from the text
+            3. The QA pairs should aid knowledge retrieval and understanding
+            4. Avoid questions that are overly simple or repetitive
+            5. Make sure each question is self-contained and understandable without extra context"""
+
+        user_prompt = f"""Based on the following content, generate {count} high-quality question-answer pairs.
+
+            Content:
+                {chunk}
+
+            Output strictly in JSON, in the following format:
+            [
+                {{"question": "question 1", "answer": "answer 1"}},
+                {{"question": "question 2", "answer": "answer 2"}}
+            ]
+
+            Output only the JSON array; do not add anything else."""
+
+        try:
+            messages = [
+                SystemMessage(content=system_prompt),
+                HumanMessage(content=user_prompt)
+            ]
+            
+            response = self.chat_model.invoke(messages)
+            content = response.content if hasattr(response, 'content') else str(response)
+            
+            # Try to parse the JSON
+            # Strip possible markdown code-fence markers
+            content = content.strip()
+            if content.startswith("```json"):
+                content = content[7:]
+            if content.startswith("```"):
+                content = content[3:]
+            if content.endswith("```"):
+                content = content[:-3]
+            content = content.strip()
+            
+            qa_list = json.loads(content)
+            print(f"Chunk {chunk_index + 1} produced {len(qa_list)} QA pairs")
+            return qa_list
+            
+        except json.JSONDecodeError as e:
+            print(f"Chunk {chunk_index + 1}: JSON parsing failed: {str(e)}")
+            return []
+        except Exception as e:
+            print(f"Chunk {chunk_index + 1}: QA generation failed: {str(e)}")
+            return []
+    
+    def _generate_qa_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """QA pair generation node: process all chunks in parallel"""
+        print(f"Generating QA pairs for {len(state.chunks)} chunks...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.chunks:
+            return {"error_message": "No text chunks to process"}
+        
+        all_qa_pairs = []
+        
+        # Process chunks in parallel with a ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=5, thread_name_prefix="qa_gen_") as executor:
+            futures = {
+                executor.submit(
+                    self._generate_qa_for_chunk, 
+                    chunk, 
+                    state.qa_count_per_chunk,
+                    i
+                ): i for i, chunk in enumerate(state.chunks)
+            }
+            
+            for future in concurrent.futures.as_completed(futures):
+                chunk_index = futures[future]
+                try:
+                    qa_list = future.result()
+                    for qa in qa_list:
+                        qa["chunk_index"] = chunk_index
+                    all_qa_pairs.extend(qa_list)
+                except Exception as e:
+                    print(f"Chunk {chunk_index + 1} raised an exception: {str(e)}")
+        
+        # Sort by chunk_index
+        all_qa_pairs.sort(key=lambda x: x.get("chunk_index", 0))
+        
+        print(f"QA generation complete: {len(all_qa_pairs)} QA pairs in total")
+        return {"qa_pairs": all_qa_pairs}
+    
+    def _export_csv_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Node that exports the QA pairs to a temporary CSV file"""
+        print(f"Exporting {len(state.qa_pairs)} QA pairs to a CSV file...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.qa_pairs:
+            return {"error_message": "No QA pairs to export"}
+        
+        try:
+            # Use the PDF file name as the CSV file name prefix
+            pdf_basename = os.path.splitext(os.path.basename(state.pdf_path))[0]
+            
+            # Create a temporary CSV file
+            temp_file = tempfile.NamedTemporaryFile(
+                mode='w',
+                suffix='.csv',
+                prefix=f'{pdf_basename}_qa_',
+                delete=False,
+                encoding='utf-8',
+                newline=''
+            )
+            
+            # Write the CSV using TAB as the delimiter
+            writer = csv.writer(temp_file, delimiter='\t')
+            
+            # Write the rows (question and answer)
+            for qa in state.qa_pairs:
+                writer.writerow([qa['question'], qa['answer']])
+            
+            temp_file.close()
+            csv_path = temp_file.name
+            
+            print(f"CSV export complete: {csv_path}")
+            print(f"Exported {len(state.qa_pairs)} QA pairs")
+            
+            return {
+                "csv_path": csv_path,
+                "qa_count": len(state.qa_pairs)
+            }
+            
+        except Exception as e:
+            print(f"CSV export failed: {str(e)}")
+            return {"error_message": f"CSV export failed: {str(e)}"}
+    
+    def _upload_document_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Node that uploads the CSV document to RAGFlow"""
+        print(f"Uploading CSV file to RAGFlow: {state.csv_path}")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.csv_path:
+            return {"error_message": "No CSV file to upload"}
+        
+        try:
+            # Upload the document to RAGFlow
+            document_info_list = self.ragflow_service.upload_document(
+                dataset_id=state.dataset_id,
+                file_path=state.csv_path
+            )
+            
+            if document_info_list and len(document_info_list) > 0:
+                document_id = document_info_list[0]["id"]
+                print(f"CSV document uploaded; document ID: {document_id}")
+                
+                # Clean up the temporary file
+                try:
+                    os.remove(state.csv_path)
+                    print(f"Temporary file removed: {state.csv_path}")
+                except Exception as e:
+                    print(f"Failed to remove temporary file: {str(e)}")
+                
+                return {"uploaded_document_id": document_id}
+            else:
+                return {"error_message": "Document upload failed: no valid document info returned"}
+                
+        except Exception as e:
+            print(f"Document upload failed: {str(e)}")
+            return {"error_message": f"Document upload failed: {str(e)}"}
+
+    def _parse_document_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """RAGFlow document parsing node"""
+        print(f"Parsing document {state.dataset_id}: {state.uploaded_document_id}")
+        
+        try:
+            # Parse the document
+            parse_success = self.ragflow_service.parse_document(
+                dataset_id=state.dataset_id,
+                document_ids=[state.uploaded_document_id]
+            )
+            
+            # parse_success is expected to be a bool
+            if parse_success:
+                print(f"Document parsed successfully; document ID: {state.uploaded_document_id}")
+                # Return an empty list, since the parsed_results field expects a list
+                return {
+                    "parsed_results": []
+                }
+            else:
+                print("Document parsing failed: no valid parse result returned")
+                raise Exception("Document parsing failed: no valid parse result returned")
+        except Exception as e:
+            print(f"Error while parsing document: {str(e)}")
+            raise
+    
+    def _complete_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Completion node"""
+        if state.error_message:
+            print(f"Workflow finished (with errors): {state.error_message}")
+        else:
+            print("QA parsing workflow complete!")
+            print(f"  - Extracted text: {len(state.extracted_text)} characters")
+            print(f"  - Number of chunks: {len(state.chunks)}")
+            print(f"  - Generated QA pairs: {state.qa_count}")
+            print(f"  - Uploaded document ID: {state.uploaded_document_id}")
+        
+        return {"is_complete": True}
+    
+    def run(
+        self, 
+        pdf_path: str, 
+        dataset_id: str, 
+        qa_count_per_chunk: int = 50,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 200
+    ) -> Dict[str, Any]:
+        """
+        Run the QA parsing workflow
+        
+        Args:
+            pdf_path: path to the PDF file
+            dataset_id: RAGFlow dataset ID
+            qa_count_per_chunk: number of QA pairs per chunk, default 50
+            chunk_size: text chunk size, default 1000
+            chunk_overlap: chunk overlap size, default 200
+            
+        Returns:
+            Dict: a dictionary containing the final state
+        """
+        initial_state = QAParsingState(
+            pdf_path=pdf_path,
+            dataset_id=dataset_id,
+            qa_count_per_chunk=qa_count_per_chunk,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap
+        )
+        
+        result = self.workflow.invoke(
+            initial_state, 
+            config={"callbacks": [self.langfuse_handler]}
+        )
+        
+        if isinstance(result, dict):
+            return result
+        else:
+            return result.dict()
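
A minimal driver sketch for the workflow (the PDF path and dataset ID below are placeholders):

from src.datasets.parser.question_answer_parser.question_answer_parser import QuestionAnswerParserWorkflow

workflow = QuestionAnswerParserWorkflow()
final_state = workflow.run(
    pdf_path="/data/docs/user_guide.pdf",   # placeholder path
    dataset_id="your-ragflow-dataset-id",   # placeholder dataset ID
    qa_count_per_chunk=20,
)
print(final_state.get("qa_count"), final_state.get("uploaded_document_id"))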

+ 42 - 16
src/utils/infinity/client.py

@@ -260,8 +260,8 @@ class InfinityClient:
         """关闭客户端,释放所有连接"""
         self.pool.close()
 
-# Global client instance
-_global_client: Optional[InfinityClient] = None
+# Global client instance cache (keyed per database)
+_client_cache: Dict[str, "InfinityClient"] = {}
 _client_lock = threading.Lock()
 
 def get_client(
@@ -270,25 +270,36 @@ def get_client(
     database: str = vector_db_settings.infinity_database,
     min_connections: int = 5,
     max_connections: int = 10
-) -> InfinityClient:
+) -> "InfinityClient":
     """
-    Get the global client instance (singleton pattern)
+    Get an Infinity client instance (cached per database)
+    
+    Supports switching databases dynamically: each database gets its own
+    independent client instance, and a different database argument returns
+    the client for that database.
     
     Args:
         host: Infinity server address
         port: Infinity server port
-        database: database name
+        database: database name (different databases create different client instances)
         min_connections: minimum number of connections
         max_connections: maximum number of connections
        
     Returns:
-        The global Infinity client instance
+        The Infinity client instance for the given database
+        
+    Example:
+        # Get the client for the default database
+        client = get_client()
+        
+        # Get the client for a specific database
+        ragflow_client = get_client(database="ragflow_db")
     """
-    global _global_client
+    # Use (host:port:database) as the cache key
+    cache_key = f"{host}:{port}:{database}"
     
     with _client_lock:
-        if _global_client is None:
-            _global_client = InfinityClient(
+        if cache_key not in _client_cache:
+            _client_cache[cache_key] = InfinityClient(
                 host=host,
                 port=port,
                 database=database,
@@ -296,13 +307,28 @@ def get_client(
                 max_connections=max_connections
             )
     
-    return _global_client
+    return _client_cache[cache_key]
+
 
-def close_client():
-    """关闭全局客户端"""
-    global _global_client
+def close_client(database: Optional[str] = None):
+    """
+    关闭客户端
+    
+    Args:
+        database: 要关闭的数据库客户端名称。
+                  如果为None,则关闭所有缓存的客户端。
+    """
+    global _client_cache
     
     with _client_lock:
-        if _global_client:
-            _global_client.close()
-            _global_client = None
+        if database is None:
+            # Close all clients
+            for client in _client_cache.values():
+                client.close()
+            _client_cache.clear()
+        else:
+            # Close the client for the given database
+            keys_to_remove = [k for k in _client_cache if k.endswith(f":{database}")]
+            for key in keys_to_remove:
+                _client_cache[key].close()
+                del _client_cache[key]
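
A short sketch of the per-database cache behavior (database names are illustrative):

default_client = get_client()                        # cached under host:port:<default database>
ragflow_client = get_client(database="ragflow_db")   # separate instance with its own pool
assert get_client(database="ragflow_db") is ragflow_client  # second call is a cache hit

close_client(database="ragflow_db")  # closes and evicts only the ragflow_db client
close_client()                       # closes every client still in the cache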

+ 51 - 1
src/utils/infinity/result_util.py

@@ -63,4 +63,54 @@ def convert_to_langchain_docs(obj: Any) -> List[Document]:
                  metadata={k: v for k, v in item.items() if k != "content"})
         for item in res[0]
     ]
-    return candidate_docs
+    return candidate_docs
+
+
+def convert_to_json(obj: Any, content_field: str = "content") -> List[Dict[str, Any]]:
+    """
+    Convert an Infinity search result into a JSON-serializable list
+    
+    Args:
+        obj: the object to convert (an Infinity search result)
+        content_field: name of the content field, defaults to "content"
+        
+    Returns:
+        A JSON-serializable list whose elements each contain:
+        - content: the content field
+        - metadata: a dict of the remaining fields
+        
+    Example:
+        >>> result = infinity_client.vector_search(...)
+        >>> json_data = convert_to_json(result.to_result())
+        >>> # [{"content": "...", "metadata": {"field1": "...", ...}}, ...]
+    """
+    res = convert_to_basic_types(obj=obj)
+    
+    # Handle empty results
+    if not res or len(res) == 0 or len(res[0]) == 0:
+        return []
+    
+    # Convert to the JSON format
+    json_list = []
+    for item in res[0]:
+        # item may be a dict or a string
+        if isinstance(item, dict):
+            json_item = {
+                "content": item.get(content_field, ""),
+                "metadata": {k: v for k, v in item.items() if k != content_field}
+            }
+        elif isinstance(item, str):
+            # If item is a string, use it directly as the content
+            json_item = {
+                "content": item,
+                "metadata": {}
+            }
+        else:
+            # Convert any other type to a string
+            json_item = {
+                "content": str(item),
+                "metadata": {}
+            }
+        json_list.append(json_item)
+    
+    return json_list
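
A concrete input/output sketch (values invented; it assumes convert_to_basic_types passes already-basic values through unchanged):

rows = [[{"content": "Q: what does ef control?", "file_name": "faq.pdf"}, "a bare string row"]]
convert_to_json(rows)
# [{"content": "Q: what does ef control?", "metadata": {"file_name": "faq.pdf"}},
#  {"content": "a bare string row", "metadata": {}}]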