
QA-pair parser & QA-pair retriever

yingge committed 3 months ago
parent commit 434dc600f1

+ 1 - 1
src/api/dataset/services/dataset_manage_service.py

@@ -10,7 +10,7 @@
 import os
 import tempfile
 from typing import Dict, Any, Optional
-from src.parser.pdf_parser.pdf_parser_workflow import PDFParsingWorkflow
+from src.datasets.parser.pdf_parser.pdf_parser_workflow import PDFParsingWorkflow
 from src.conf.settings import vector_db_settings
 
 

+ 8 - 0
src/api/db/models/__init__.py

@@ -0,0 +1,8 @@
+from .search_request_models import SearchRequest
+# from .search_param_models import KnnParams, VectorSearchQuery
+
+__all__ = [
+    "SearchRequest",
+    # "KnnParams",
+    # "VectorSearchQuery",
+]

+ 79 - 0
src/api/db/models/search_param_models.py

@@ -0,0 +1,79 @@
+# """
+# 搜索相关的Pydantic模型定义
+
+# 该模块定义了用于图像和文本搜索的请求和响应模型。
+# """
+
+# from pydantic import BaseModel, Field, field_validator
+# from typing import List, Optional
+
+# class KnnParams(BaseModel):
+#     """
+#     KNN搜索参数模型
+    
+#     用于配置向量搜索的近邻搜索参数。
+    
+#     Attributes:
+#         ef: HNSW索引的搜索参数,影响搜索精度和速度
+#         threshold: 相似度分数阈值,低于该阈值的结果将被过滤
+#     """
+    
+#     ef: str = Field(
+#         default="100",
+#         description="HNSW索引的搜索参数 (ef值越大搜索越精确但速度越慢)"
+#     )
+#     threshold: str = Field(
+#         default="0.0",
+#         description="相似度分数阈值,范围 0.0-1.0"
+#     )
+    
+#     @classmethod
+#     def from_retrieval_setting(cls, top_k: int, score_threshold: float) -> "KnnParams":
+#         """
+#         从检索设置创建KNN参数
+        
+#         Args:
+#             top_k: 返回的结果数量
+#             score_threshold: 分数阈值
+            
+#         Returns:
+#             KnnParams实例
+#         """
+#         return cls(
+#             ef=str(top_k * 10),
+#             threshold=str(score_threshold)
+#         )
+
+
+# class VectorSearchQuery(BaseModel):
+#     """
+#     向量搜索查询参数模型
+    
+#     封装向量检索所需的所有参数。
+    
+#     Attributes:
+#         vector_field: 向量字段名称
+#         query_vector: 查询向量
+#         topn: 返回的最大结果数量
+#         knn_params: KNN搜索参数
+#     """
+    
+#     vector_field: str = Field(
+#         ...,
+#         description="向量字段名称",
+#         examples=["dense_vector_1024"]
+#     )
+#     query_vector: List[float] = Field(
+#         ...,
+#         description="查询向量"
+#     )
+#     topn: int = Field(
+#         default=10,
+#         ge=1,
+#         le=100,
+#         description="返回的最大结果数量"
+#     )
+#     knn_params: Optional[KnnParams] = Field(
+#         default=None,
+#         description="KNN搜索参数"
+#     )
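The module above is committed fully commented out; its ef/threshold mapping survives inline in InfinitySearchService._convert_search_request_to_search_query further down. If it were re-enabled, a sketch of the intended usage (values illustrative):

# Hypothetical usage of the commented-out KnnParams model.
params = KnnParams.from_retrieval_setting(top_k=5, score_threshold=0.2)
assert params.ef == "50" and params.threshold == "0.2"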

+ 65 - 0
src/api/db/models/search_request_models.py

@@ -0,0 +1,65 @@
+"""
+搜索相关的Pydantic模型定义
+
+该模块定义了用于图像和文本搜索的请求和响应模型。
+"""
+from pydantic import BaseModel, Field, field_validator
+from typing import List, Optional
+
+
+class SearchRequest(BaseModel):
+    """
+    搜索请求模型
+    
+    支持通过图像URL或文本进行相似性搜索。
+    至少需要提供 image_url 或 matching_text 其中之一。
+    
+    Attributes:
+        image_url: 用于图像搜索的图片URL (可选)
+        matching_text: 用于文本匹配搜索的查询文本 (可选)
+        topn: 返回的最大结果数量,默认为10,范围1-100
+    """
+    
+    image_url: Optional[str] = Field(
+        default=None,
+        description="用于图像搜索的图片URL",
+        examples=["https://example.com/image.jpg"]
+    )
+    matching_text: Optional[str] = Field(
+        default=None,
+        description="用于文本匹配搜索的查询文本",
+        examples=["搜索关键词"]
+    )
+    topn: int = Field(
+        default=10,
+        ge=1,
+        le=100,
+        description="返回的最大结果数量"
+    ),
+    score_threshold: float = Field(
+        default=0.0,
+        ge=0.0,
+        le=1.0,
+        description="相似度分数阈值"
+    )
+    
+    @field_validator('image_url')
+    @classmethod
+    def validate_image_url(cls, v: Optional[str]) -> Optional[str]:
+        """验证图片URL格式"""
+        if v is not None and v.strip() == "":
+            return None
+        return v
+    
+    @field_validator('matching_text')
+    @classmethod
+    def validate_matching_text(cls, v: Optional[str]) -> Optional[str]:
+        """验证匹配文本"""
+        if v is not None and v.strip() == "":
+            return None
+        return v
+    
+    def model_post_init(self, __context) -> None:
+        """模型初始化后验证:确保至少提供一个搜索条件"""
+        if self.image_url is None and self.matching_text is None:
+            raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
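A quick usage sketch of the model above (values are hypothetical; the assertions follow from the defaults and validators as defined):

from src.api.db.models import SearchRequest

# Text-only request; topn and score_threshold fall back to their defaults.
req = SearchRequest(matching_text="pump seal replacement", topn=5)
assert req.image_url is None and req.score_threshold == 0.0

# An all-whitespace image_url is normalized to None, so a request with
# neither criterion fails post-init validation.
try:
    SearchRequest(image_url="   ")
except ValueError as e:
    print(e)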

+ 78 - 26
src/api/db/services/infinity_search_service.py

@@ -3,13 +3,13 @@ from src.conf.settings import vector_db_settings
 from src.utils.infinity import InfinityClient
 from src.utils.file.image_util import image_util
 from src.model.multimodal_embedding import get_embedding_model
-from src.utils.infinity.result_util import convert_to_basic_types
+from src.utils.infinity.result_util import convert_to_json
+from src.api.db.models import SearchRequest
+from langchain_core.documents import Document
 
 class InfinitySearchService:
-    def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
-        self.infinity_client = infinity_client
-        # Output fields
-        self.output_fields = [
+
+    default_output_fields = [
                     "file_name",
                     "page_number",
                     "content",
@@ -17,12 +17,33 @@ class InfinitySearchService:
                     "dataset_id",
                     "dataset_id",
                     "document_id"
                     "document_id"
                 ]
                 ]
+
+    def __init__(self, infinity_client: InfinityClient, 
+                       vector_field: str = None, 
+                       match_field: str = None, 
+                       match_type: str = None, 
+                       table_name: str = None, 
+                       output_fields: List[str] = None):
+        """
+        初始化InfinitySearchService
+        
+        Args:
+            infinity_client: InfinityClient实例
+            vector_field: 向量字段名
+            match_field: 匹配字段名
+            match_type: 匹配类型
+            table_name: 表名
+            output_fields: 输出字段列表
+        """
+        self.infinity_client = infinity_client
+        # 输出字段
+        self.output_fields = output_fields or self.default_output_fields
         self.vector_field = vector_field or "dense_vector_1024"
         self.vector_field = vector_field or "dense_vector_1024"
         self.match_field = match_field or "content"
         self.match_field = match_field or "content"
         self.match_type = match_type or "cosine"
         self.match_type = match_type or "cosine"
         self.table_name = table_name or vector_db_settings.infinity_table_name
         self.table_name = table_name or vector_db_settings.infinity_table_name
 
 
-    def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
+    def search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         """
         执行Infinity数据库搜索
         执行Infinity数据库搜索
     
     
@@ -33,16 +54,18 @@ class InfinitySearchService:
             search results, converted to basic types for serialization
         """
         try:
+            # Convert request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Run the search
             result = self.infinity_client.search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
         except Exception as e:
             raise Exception(f"Search failed: {str(e)}")
 
-    def vector_search(self, search_query: Dict[str, Any]):
+    def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         Run a vector search against the Infinity database
     
@@ -53,23 +76,18 @@ class InfinitySearchService:
             vector search results, converted to basic types for serialization
         """
         try:
-            # 1. Convert image_url into a PIL image
-            image = image_util._url_to_image(search_query["image_url"])
-            # 2. Embed the image
-            query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
-
-            search_query["vector_field"] = self.vector_field
-            search_query["query_vector"] = query_vector
+            # Convert request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Run the vector search
             result = self.infinity_client.vector_search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
         except Exception as e:
             raise Exception(f"Vector search failed: {str(e)}")
 
-    def hybrid_search(self, search_query: Dict[str, Any]):
+    def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
         """
         Run a hybrid search against the Infinity database
     
@@ -80,18 +98,52 @@ class InfinitySearchService:
             hybrid search results, converted to basic types for serialization
         """
         try:
-            # 1. Convert image_url into a PIL image
-            image = image_util._url_to_image(search_query["image_url"])
-            # 2. Embed the image
-            query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
-            search_query["vector_field"] = self.vector_field
-            search_query["query_vector"] = query_vector
-            search_query["match_field"] = self.match_field
+            # Convert request parameters
+            search_query = self._convert_search_request_to_search_query(request)
             # Run the hybrid search
             result = self.infinity_client.hybrid_search(self.table_name, self.output_fields, search_query)
             # Convert the result to basic types, handling possible complex types
             result_dict = result.to_result()
             # Recursively convert all complex types to basic types
-            return convert_to_basic_types(result_dict)
+            return convert_to_json(result_dict)
+        except Exception as e:
+            raise Exception(f"Hybrid search failed: {str(e)}")
+
+    # Request-to-query conversion
+    def _convert_search_request_to_search_query(self, request: SearchRequest) -> dict:
+        """
+        Convert a SearchRequest into a search-query dict
+        
+        Args:
+            request: SearchRequest instance
+            
+        Returns:
+            search-query dict
+        """
+        try:
+            # Choose between a multimodal and a text-only embedding
+            if request.image_url is not None:
+                # Convert image_url into a PIL image
+                image = image_util._url_to_image(request.image_url)
+                # Embed the text and image together
+                query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image)
+            elif request.matching_text is not None:
+                query_vector = get_embedding_model().get_text_embedding(request.matching_text)
+            else:
+                query_vector = None
+            
+            search_query = {
+                "match_field": self.match_field,
+                "matching_text": request.matching_text,
+                "vector_field": self.vector_field,
+                "query_vector": query_vector,
+                "topn": request.topn,
+                "knn_params": {
+                    "ef": str(request.topn * 10),
+                    "threshold": str(request.score_threshold)
+                }
+            }
+            return search_query
         except Exception as e:
-            raise Exception(f"Hybrid search failed: {str(e)}")
+            raise Exception(f"Parameter conversion failed: {str(e)}")

+ 25 - 12
src/api/sdk/search_infinity.py

@@ -1,12 +1,11 @@
 # Infinity search API service
 
 from fastapi import FastAPI, HTTPException
-from typing import List, Dict, Any, Optional
 from src.api.db.services.infinity_search_service import InfinitySearchService
 from src.utils.infinity import get_client
 from src.common.result import Result
 from src.utils.async_utils import run_in_threadpool
-
+from src.api.db.models import SearchRequest
 
 # Create the FastAPI app
 
@@ -16,13 +15,6 @@ app = FastAPI(
     version="1.0.0"
     version="1.0.0"
 )
 )
 
 
-# 请求模型
-from pydantic import BaseModel
-
-class SearchRequest(BaseModel):
-    """搜索请求模型"""
-    search_query: Dict[str, Any]
-
 # 1. 普通搜索接口
 # 1. 普通搜索接口
 @app.post("/text")
 @app.post("/text")
 async def search(request: SearchRequest):
 async def search(request: SearchRequest):
@@ -36,7 +28,7 @@ async def search(request: SearchRequest):
     """
     """
     try:
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.search, request.search_query)
+        result = await run_in_threadpool(search_service.search, request)
         return Result.success(data=result, message="搜索成功")
         return Result.success(data=result, message="搜索成功")
     except Exception as e:
     except Exception as e:
         return Result.error(code=500, message=f"搜索失败: {str(e)}")
         return Result.error(code=500, message=f"搜索失败: {str(e)}")
@@ -54,7 +46,7 @@ async def vector_search(request: SearchRequest):
     """
     """
     try:
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.vector_search, request.search_query)
+        result = await run_in_threadpool(search_service.vector_search, request)
         return Result.success(data=result, message="向量搜索成功")
         return Result.success(data=result, message="向量搜索成功")
     except Exception as e:
     except Exception as e:
         return Result.error(code=500, message=f"向量搜索失败: {str(e)}")
         return Result.error(code=500, message=f"向量搜索失败: {str(e)}")
@@ -72,7 +64,28 @@ async def hybrid_search(request: SearchRequest):
     """
     """
     try:
     try:
         search_service = InfinitySearchService(infinity_client=get_client())
         search_service = InfinitySearchService(infinity_client=get_client())
-        result = await run_in_threadpool(search_service.hybrid_search, request.search_query)
+        result = await run_in_threadpool(search_service.hybrid_search, request)
         return Result.success(data=result, message="混合搜索成功")
         return Result.success(data=result, message="混合搜索成功")
     except Exception as e:
     except Exception as e:
         return Result.error(code=500, message=f"混合搜索失败: {str(e)}")
         return Result.error(code=500, message=f"混合搜索失败: {str(e)}")
+
+# 4. 问答对检索
+@app.post("/question")
+async def question_search(request: SearchRequest):
+    """
+    问答对检索接口
+
+    - **output_fields**: 要返回的字段列表
+    - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
+    - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
+    """
+    try:
+        output_fields = ["content"]
+        search_service = InfinitySearchService(infinity_client=get_client(database="ragflow_db"), 
+                                               table_name="ragflow_92162247e93e11f084830242ac1d0002_1c4f7a82f66c11f09c750242c0a8d002", 
+                                               vector_field="q_1024_vec",
+                                               output_fields=output_fields)
+        result = await run_in_threadpool(search_service.hybrid_search, request)
+        return Result.success(data=result, message="问答对检索成功")
+    except Exception as e:
+        return Result.error(code=500, message=f"问答对检索失败: {str(e)}")

+ 0 - 0
src/parser/__init__.py → src/datasets/__init__.py


+ 0 - 0
src/parser/image_parser/__init__.py → src/datasets/parser/__init__.py


+ 0 - 0
src/parser/pdf_parser/__init__.py → src/datasets/parser/image_parser/__init__.py


+ 0 - 0
src/parser/image_parser/image_parser_workflow.py → src/datasets/parser/image_parser/image_parser_workflow.py


+ 0 - 0
src/datasets/parser/pdf_parser/__init__.py


+ 1 - 1
src/parser/pdf_parser/pdf_parser_workflow.py → src/datasets/parser/pdf_parser/pdf_parser_workflow.py

@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
 from langgraph.graph import StateGraph, START, END
 from typing import List, Dict, Any
 from pydantic import BaseModel, Field, ConfigDict
-from src.parser.pdf_parser.pdf_splitter import PDFSplitter
+from src.datasets.parser.pdf_parser.pdf_splitter import PDFSplitter
 from src.model.qwen_vl import QWenVLParser
 from src.utils.ragflow.ragflow_service import RAGFlowService
 from src.utils.ragflow.chunk_record import get_chunk_record_service

+ 0 - 0
src/parser/pdf_parser/pdf_splitter.py → src/datasets/parser/pdf_parser/pdf_splitter.py


+ 0 - 0
src/parser/pdf_parser/test_service.py → src/datasets/parser/pdf_parser/test_service.py


+ 417 - 0
src/datasets/parser/question_answer_parser/question_answer_parser.py

@@ -0,0 +1,417 @@
+"""
+QA模式问答对解析器工作流
+
+使用LangGraph实现完整的PDF到QA知识库工作流:
+PDF OCR解析 → 文本分块 → 生成QA对 → 上传RAGFlow
+"""
+
+import os
+import json
+import csv
+import tempfile
+import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Dict, Any, Optional
+from pydantic import BaseModel, Field, ConfigDict
+from langgraph.graph import StateGraph, START, END
+from langchain.chat_models import init_chat_model
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from src.datasets.parser.pdf_parser.pdf_splitter import PDFSplitter
+from src.model.qwen_vl import QWenVLParser
+from src.utils.ragflow.ragflow_service import RAGFlowService
+from src.conf.settings import model_settings
+from langfuse.langchain import CallbackHandler
+
+
+class QAParsingState(BaseModel):
+    """QA解析工作流状态"""
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    
+    # 输入参数
+    pdf_path: str = Field(..., description="PDF文件路径")
+    dataset_id: str = Field(..., description="RAGFlow数据集ID")
+    qa_count_per_chunk: int = Field(default=50, description="每块生成的QA数量")
+    chunk_size: int = Field(default=1000, description="文本分块大小")
+    chunk_overlap: int = Field(default=200, description="分块重叠大小")
+    
+    # 中间状态
+    extracted_text: str = Field(default="", description="OCR提取的文本")
+    chunks: List[str] = Field(default_factory=list, description="分块后的文本列表")
+    qa_pairs: List[Dict[str, Any]] = Field(default_factory=list, description="生成的QA对列表")
+    csv_path: str = Field(default="", description="临时CSV文件路径")
+    
+    # 输出状态
+    uploaded_document_id: str = Field(default="", description="上传后的文档ID")
+    qa_count: int = Field(default=0, description="生成的QA对数量")
+    is_complete: bool = Field(default=False, description="是否处理完成")
+    error_message: Optional[str] = Field(default=None, description="错误信息")
+
+
+class QuestionAnswerParserWorkflow:
+    """QA-pair parsing workflow"""
+    
+    def __init__(self, model_name: str = None):
+        """
+        Initialize the QA parsing workflow
+        
+        Args:
+            model_name: VL model name, used for PDF OCR parsing
+        """
+        self.model_name = model_name or model_settings.model_name
+        self.ragflow_service = RAGFlowService()
+        self.langfuse_handler = CallbackHandler()
+        
+        # Initialize a LangChain chat model for QA generation
+        self.chat_model = init_chat_model(
+            model_provider=model_settings.model_provider,
+            model=model_settings.chat_model_name,
+            api_key=model_settings.api_key,
+            base_url=model_settings.base_url,
+            temperature=0.7
+        )
+        
+        self.workflow = self._build_workflow()
+    
+    def _build_workflow(self):
+        """Build the LangGraph workflow"""
+        graph = StateGraph(QAParsingState)
+        
+        # Add nodes
+        graph.add_node("parse_pdf", self._parse_pdf_node)
+        graph.add_node("split_text", self._split_text_node)
+        graph.add_node("generate_qa", self._generate_qa_node)
+        graph.add_node("export_csv", self._export_csv_node)
+        graph.add_node("upload_document", self._upload_document_node)
+        graph.add_node("parse_document", self._parse_document_node)
+        graph.add_node("complete", self._complete_node)
+        
+        # Define edges
+        graph.add_edge(START, "parse_pdf")
+        graph.add_edge("parse_pdf", "split_text")
+        graph.add_edge("split_text", "generate_qa")
+        graph.add_edge("generate_qa", "export_csv")
+        graph.add_edge("export_csv", "upload_document")
+        graph.add_edge("upload_document", "parse_document")
+        graph.add_edge("parse_document", "complete")
+        graph.add_edge("complete", END)
+        
+        return graph.compile()
+    
+    def _parse_pdf_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """PDF OCR parsing node - extracts the text content"""
+        print(f"Parsing PDF: {state.pdf_path}")
+        
+        try:
+            # Split the PDF into page images with PDFSplitter
+            splitter = PDFSplitter()
+            pages = splitter.split_pdf(state.pdf_path)
+            
+            # Extract the text of each page with the QWenVL model
+            extracted_texts = []
+            parser = QWenVLParser(self.model_name)
+            
+            for page in pages:
+                page_number = page["page_number"]
+                image = page["image"]
+                
+                # OCR prompt
+                prompt = """Extract all text in the image, preserving the original paragraph structure.
+Output only the extracted text; do not add any extra explanation or formatting."""
+                
+                result = parser.parse_image(image, page_number, prompt)
+                text = result.get("content", "")
+                extracted_texts.append(text)
+                print(f"Text extraction for page {page_number} complete")
+            
+            full_text = "\n\n".join(extracted_texts)
+            print(f"PDF parsing complete; extracted {len(full_text)} characters")
+            
+            return {"extracted_text": full_text}
+            
+        except Exception as e:
+            print(f"PDF parsing failed: {str(e)}")
+            return {"error_message": f"PDF parsing failed: {str(e)}"}
+    
+    def _split_text_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Text chunking node"""
+        print("Chunking text...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.extracted_text:
+            return {"error_message": "No text content was extracted"}
+        
+        try:
+            # Use LangChain's RecursiveCharacterTextSplitter
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=state.chunk_size,
+                chunk_overlap=state.chunk_overlap,
+                length_function=len,
+                separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]
+            )
+            
+            chunks = text_splitter.split_text(state.extracted_text)
+            print(f"Text chunking complete: {len(chunks)} chunks")
+            
+            return {"chunks": chunks}
+            
+        except Exception as e:
+            print(f"Text chunking failed: {str(e)}")
+            return {"error_message": f"Text chunking failed: {str(e)}"}
+    
+    def _generate_qa_for_chunk(self, chunk: str, count: int, chunk_index: int) -> List[Dict[str, str]]:
+        """Generate QA pairs for a single chunk"""
+        print(f"Generating {count} QA pairs for chunk {chunk_index + 1}...")
+        
+        system_prompt = """You are an expert at generating question-answer pairs. Your task is to generate high-quality QA pairs from the given text.
+
+            Requirements:
+            1. Questions should cover different aspects and details of the content
+            2. Answers should be accurate, concise, and taken directly from the text
+            3. The QA pairs should support knowledge retrieval and understanding
+            4. Avoid overly simple or repetitive questions
+            5. Make each question self-contained and understandable without extra context"""
+
+        user_prompt = f"""Based on the following content, generate {count} high-quality QA pairs.
+
+            Content:
+                {chunk}
+
+            Output strictly as JSON, in this format:
+            [
+                {{"question": "question 1", "answer": "answer 1"}},
+                {{"question": "question 2", "answer": "answer 2"}}
+            ]
+
+            Output only the JSON array; do not add anything else."""
+
+        try:
+            messages = [
+                SystemMessage(content=system_prompt),
+                HumanMessage(content=user_prompt)
+            ]
+            
+            response = self.chat_model.invoke(messages)
+            content = response.content if hasattr(response, 'content') else str(response)
+            
+            # Try to parse the JSON;
+            # strip any markdown code-fence markers first
+            content = content.strip()
+            if content.startswith("```json"):
+                content = content[7:]
+            if content.startswith("```"):
+                content = content[3:]
+            if content.endswith("```"):
+                content = content[:-3]
+            content = content.strip()
+            
+            qa_list = json.loads(content)
+            print(f"Chunk {chunk_index + 1} produced {len(qa_list)} QA pairs")
+            return qa_list
+            
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing failed for chunk {chunk_index + 1}: {str(e)}")
+            return []
+        except Exception as e:
+            print(f"QA generation failed for chunk {chunk_index + 1}: {str(e)}")
+            return []
+    
+    def _generate_qa_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """QA generation node - processes all chunks in parallel"""
+        print(f"Generating QA pairs for {len(state.chunks)} chunks...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.chunks:
+            return {"error_message": "No text chunks to process"}
+        
+        all_qa_pairs = []
+        
+        # Process chunks in parallel with a ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=5, thread_name_prefix="qa_gen_") as executor:
+            futures = {
+                executor.submit(
+                    self._generate_qa_for_chunk, 
+                    chunk, 
+                    state.qa_count_per_chunk,
+                    i
+                ): i for i, chunk in enumerate(state.chunks)
+            }
+            
+            for future in concurrent.futures.as_completed(futures):
+                chunk_index = futures[future]
+                try:
+                    qa_list = future.result()
+                    for qa in qa_list:
+                        qa["chunk_index"] = chunk_index
+                    all_qa_pairs.extend(qa_list)
+                except Exception as e:
+                    print(f"Chunk {chunk_index + 1} raised an exception: {str(e)}")
+        
+        # Sort by chunk_index
+        all_qa_pairs.sort(key=lambda x: x.get("chunk_index", 0))
+        
+        print(f"QA generation complete: {len(all_qa_pairs)} QA pairs in total")
+        return {"qa_pairs": all_qa_pairs}
+    
+    def _export_csv_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Node that exports the QA pairs to a temporary CSV file"""
+        print(f"Exporting {len(state.qa_pairs)} QA pairs to CSV...")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.qa_pairs:
+            return {"error_message": "No QA pairs to export"}
+        
+        try:
+            # Use the PDF file name as the CSV file-name prefix
+            pdf_basename = os.path.splitext(os.path.basename(state.pdf_path))[0]
+            
+            # Create a temporary CSV file
+            temp_file = tempfile.NamedTemporaryFile(
+                mode='w',
+                suffix='.csv',
+                prefix=f'{pdf_basename}_qa_',
+                delete=False,
+                encoding='utf-8',
+                newline=''
+            )
+            
+            # Write the CSV with TAB as the delimiter
+            writer = csv.writer(temp_file, delimiter='\t')
+            
+            # Write the rows (question and answer)
+            for qa in state.qa_pairs:
+                writer.writerow([qa['question'], qa['answer']])
+            
+            temp_file.close()
+            csv_path = temp_file.name
+            
+            print(f"CSV export complete: {csv_path}")
+            print(f"Exported {len(state.qa_pairs)} QA pairs")
+            
+            return {
+                "csv_path": csv_path,
+                "qa_count": len(state.qa_pairs)
+            }
+            
+        except Exception as e:
+            print(f"CSV export failed: {str(e)}")
+            return {"error_message": f"CSV export failed: {str(e)}"}
+    
+    def _upload_document_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Node that uploads the CSV document to RAGFlow"""
+        print(f"Uploading CSV file to RAGFlow: {state.csv_path}")
+        
+        if state.error_message:
+            return {}
+        
+        if not state.csv_path:
+            return {"error_message": "No CSV file to upload"}
+        
+        try:
+            # Upload the document to RAGFlow
+            document_info_list = self.ragflow_service.upload_document(
+                dataset_id=state.dataset_id,
+                file_path=state.csv_path
+            )
+            
+            if document_info_list and len(document_info_list) > 0:
+                document_id = document_info_list[0]["id"]
+                print(f"CSV document uploaded; document ID: {document_id}")
+                
+                # Clean up the temporary file
+                try:
+                    os.remove(state.csv_path)
+                    print(f"Temporary file removed: {state.csv_path}")
+                except Exception as e:
+                    print(f"Failed to remove temporary file: {str(e)}")
+                
+                return {"uploaded_document_id": document_id}
+            else:
+                return {"error_message": "Document upload failed: no valid document info returned"}
+                
+        except Exception as e:
+            print(f"Document upload failed: {str(e)}")
+            return {"error_message": f"Document upload failed: {str(e)}"}
+
+    def _parse_document_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """RAGFlow document parsing node"""
+        print(f"Parsing document {state.dataset_id}: {state.uploaded_document_id}")
+        
+        try:        
+            # Trigger parsing
+            parse_success = self.ragflow_service.parse_document(
+                dataset_id=state.dataset_id,
+                document_ids=[state.uploaded_document_id]
+            )
+            
+            # parse_success is a bool
+            if parse_success:
+                print(f"Document parsed successfully; document ID: {state.uploaded_document_id}")
+                # Nothing to record here: QAParsingState has no parsed_results
+                # field, so return an empty state update.
+                return {}
+            else:
+                print("Document parsing failed: no valid parse result returned")
+                raise Exception("Document parsing failed: no valid parse result returned")
+        except Exception as e:
+            print(f"Error while parsing document: {str(e)}")
+            raise
+    
+    def _complete_node(self, state: QAParsingState) -> Dict[str, Any]:
+        """Completion node"""
+        if state.error_message:
+            print(f"Workflow finished (with errors): {state.error_message}")
+        else:
+            print(f"QA parsing workflow complete!")
+            print(f"  - Extracted text: {len(state.extracted_text)} characters")
+            print(f"  - Chunks: {len(state.chunks)}")
+            print(f"  - QA pairs generated: {state.qa_count}")
+            print(f"  - Uploaded document ID: {state.uploaded_document_id}")
+        
+        return {"is_complete": True}
+    
+    def run(
+        self, 
+        pdf_path: str, 
+        dataset_id: str, 
+        qa_count_per_chunk: int = 50,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 200
+    ) -> Dict[str, Any]:
+        """
+        Run the QA parsing workflow
+        
+        Args:
+            pdf_path: path to the PDF file
+            dataset_id: RAGFlow dataset ID
+            qa_count_per_chunk: QA pairs to generate per chunk, default 50
+            chunk_size: text chunk size, default 1000
+            chunk_overlap: overlap between chunks, default 200
+            
+        Returns:
+            Dict: the final workflow state
+        """
+        initial_state = QAParsingState(
+            pdf_path=pdf_path,
+            dataset_id=dataset_id,
+            qa_count_per_chunk=qa_count_per_chunk,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap
+        )
+        
+        result = self.workflow.invoke(
+            initial_state, 
+            config={"callbacks": [self.langfuse_handler]}
+        )
+        
+        if isinstance(result, dict):
+            return result
+        else:
+            return result.model_dump()
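A minimal usage sketch of the workflow above (the PDF path and dataset ID are placeholders):

from src.datasets.parser.question_answer_parser.question_answer_parser import QuestionAnswerParserWorkflow

workflow = QuestionAnswerParserWorkflow()
final_state = workflow.run(
    pdf_path="/data/manuals/pump_manual.pdf",  # placeholder path
    dataset_id="d0e1f2a3b4c5",                 # placeholder RAGFlow dataset ID
    qa_count_per_chunk=20,
)
print(final_state["qa_count"], final_state["uploaded_document_id"])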

+ 42 - 16
src/utils/infinity/client.py

@@ -260,8 +260,8 @@ class InfinityClient:
         """关闭客户端,释放所有连接"""
         """关闭客户端,释放所有连接"""
         self.pool.close()
         self.pool.close()
 
 
-# 全局客户端实例
-_global_client: Optional[InfinityClient] = None
+# 全局客户端实例缓存(按database进行缓存)
+_client_cache: Dict[str, "InfinityClient"] = {}
 _client_lock = threading.Lock()
 _client_lock = threading.Lock()
 
 
 def get_client(
 def get_client(
@@ -270,25 +270,36 @@ def get_client(
     database: str = vector_db_settings.infinity_database,
     min_connections: int = 5,
     max_connections: int = 10
-) -> InfinityClient:
+) -> "InfinityClient":
     """
-    Get the global client instance (singleton)
+    Get an Infinity client instance (cached per database)
+    
+    Supports switching databases on the fly: each database gets its own
+    independent client instance, selected via the database parameter.
     
     Args:
         host: Infinity server address
         port: Infinity server port
-        database: database name
+        database: database name (each database gets its own client instance)
         min_connections: minimum number of connections
         max_connections: maximum number of connections
         
     Returns:
-        the global Infinity client instance
+        the Infinity client instance for that database
+        
+    Example:
+        # Client for the default database
+        client = get_client()
+        
+        # Client for a specific database
+        ragflow_client = get_client(database="ragflow_db")
     """
-    global _global_client
+    # Use (host:port:database) as the cache key
+    cache_key = f"{host}:{port}:{database}"
     
     with _client_lock:
-        if _global_client is None:
-            _global_client = InfinityClient(
+        if cache_key not in _client_cache:
+            _client_cache[cache_key] = InfinityClient(
                 host=host,
                 port=port,
                 database=database,
@@ -296,13 +307,28 @@ def get_client(
                 max_connections=max_connections
             )
     
-    return _global_client
+    return _client_cache[cache_key]
+
 
-def close_client():
-    """Close the global client"""
-    global _global_client
+def close_client(database: Optional[str] = None):
+    """
+    Close client instances
+    
+    Args:
+        database: name of the database whose clients should be closed.
+                  If None, all cached clients are closed.
+    """
+    global _client_cache
     
     with _client_lock:
-        if _global_client:
-            _global_client.close()
-            _global_client = None
+        if database is None:
+            # Close all clients
+            for client in _client_cache.values():
+                client.close()
+            _client_cache.clear()
+        else:
+            # Close the clients for the given database
+            keys_to_remove = [k for k in _client_cache if k.endswith(f":{database}")]
+            for key in keys_to_remove:
+                _client_cache[key].close()
+                del _client_cache[key]
+ 51 - 1
src/utils/infinity/result_util.py

@@ -63,4 +63,54 @@ def convert_to_langchain_docs(obj: Any) -> List[Document]:
                  metadata={k: v for k, v in item.items() if k != "content"})
         for item in res[0]
     ]
-    return candidate_docs
+    return candidate_docs
+
+
+def convert_to_json(obj: Any, content_field: str = "content") -> List[Dict[str, Any]]:
+    """
+    Convert an Infinity search result into a JSON-serializable list
+    
+    Args:
+        obj: the object to convert (an Infinity search result)
+        content_field: name of the content field, defaults to "content"
+        
+    Returns:
+        a JSON-serializable list whose elements contain:
+        - content: the content field
+        - metadata: a dict of the remaining fields
+        
+    Example:
+        >>> result = infinity_client.vector_search(...)
+        >>> json_data = convert_to_json(result.to_result())
+        >>> # [{"content": "...", "metadata": {"field1": "...", ...}}, ...]
+    """
+    res = convert_to_basic_types(obj=obj)
+    
+    # Handle empty results
+    if not res or len(res) == 0 or len(res[0]) == 0:
+        return []
+    
+    # Convert to the JSON shape
+    json_list = []
+    for item in res[0]:
+        # item may be a dict or a plain string
+        if isinstance(item, dict):
+            json_item = {
+                "content": item.get(content_field, ""),
+                "metadata": {k: v for k, v in item.items() if k != content_field}
+            }
+        elif isinstance(item, str):
+            # A bare string becomes the content directly
+            json_item = {
+                "content": item,
+                "metadata": {}
+            }
+        else:
+            # Anything else is stringified
+            json_item = {
+                "content": str(item),
+                "metadata": {}
+            }
+        json_list.append(json_item)
+    
+    return json_list
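To make the output shape concrete, a tiny sketch with hand-made input (this assumes convert_to_basic_types passes already-basic types through unchanged):

sample = [[
    {"content": "Replace the seal every 6 months.", "file_name": "manual.pdf", "page_number": 3},
    "bare strings are wrapped too",
]]
print(convert_to_json(sample))
# [{'content': 'Replace the seal every 6 months.',
#   'metadata': {'file_name': 'manual.pdf', 'page_number': 3}},
#  {'content': 'bare strings are wrapped too', 'metadata': {}}]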