Преглед изворни кода

feat: 新增RAGFlow SDK工具类、MySQL连接池和标签管理API

refactor: 重构配置管理,从config.py迁移到settings模块

feat(utils): 添加非对称加密工具类和Langfuse追踪装饰器

feat(api): 实现标签管理API及相关服务

feat(model): 扩展多模态嵌入模型功能,集成Langfuse追踪

fix(infinity): 修复Infinity客户端更新操作和结果转换问题

docs: 更新.env文件配置说明和测试用例
yingge пре 3 месеци
родитељ
комит
eafbea0aec
50 измењених фајлова са 2359 додато и 250 уклоњено
  1. 25 1
      .env
  2. 9 9
      agent/test_image_agent.py
  3. 3 55
      api/db/services/infinity_search_service.py
  4. 207 0
      api/db/services/tag_service.py
  5. 7 5
      api/hybrid_search_mcp.py
  6. 95 0
      api/tag_manage.py
  7. 11 0
      common/models/pagination.py
  8. 10 0
      conf/age_level.json
  9. 31 0
      conf/config.py
  10. 123 0
      conf/settings.py
  11. 0 0
      ee.json
  12. 18 3
      main.py
  13. 133 0
      model/jina_rerank.py
  14. 30 20
      model/multimodal_embedding.py
  15. 8 7
      model/qwen_vl.py
  16. 21 0
      model/tracked_multi_embedding.py
  17. 12 0
      model/tracked_openai_embeddings.py
  18. 0 0
      parser/image_parser/__init__.py
  19. 2 2
      parser/image_parser/image_parser_workflow.py
  20. 82 14
      parser/pdf_parser/pdf_parser_workflow.py
  21. 27 0
      private_key.pem
  22. 0 0
      prompt/parser/page_parse_prompt.py
  23. 9 0
      public_key.pem
  24. 82 0
      test/test_asymmetric_encryption.py
  25. 39 0
      test/test_excel_util.py
  26. 70 0
      test/test_mysql_config.py
  27. 95 0
      test/test_mysql_conn.py
  28. 14 0
      test/test_ragflow_sdk.py
  29. 85 0
      test/test_search_tag.py
  30. 25 0
      test_upload_tag.py
  31. 151 0
      utils/asymmetric_encryption.py
  32. 45 0
      utils/decorators/langfuse_trace_embedding.py
  33. 0 0
      utils/decorators/singleton.py
  34. 221 0
      utils/excel_util.py
  35. 5 5
      utils/file/minio/minio_util.py
  36. 21 7
      utils/infinity/client.py
  37. 4 4
      utils/infinity/pool.py
  38. 78 0
      utils/infinity/result_util.py
  39. 52 0
      utils/mysql/__init__.py
  40. 50 114
      utils/mysql/mysql_conn.py
  41. 138 0
      utils/mysql/mysql_pool.py
  42. 3 1
      utils/ragflow/dataset_service.py
  43. 3 3
      utils/ragflow/ragflow_service.py
  44. 9 0
      utils/ragflow_sdk/__init__.py
  45. 17 0
      utils/ragflow_sdk/base_util.py
  46. 80 0
      utils/ragflow_sdk/chunk_util.py
  47. 45 0
      utils/ragflow_sdk/conf/rag_parser_config.py
  48. 71 0
      utils/ragflow_sdk/dataset_util.py
  49. 93 0
      utils/ragflow_sdk/document_util.py
  50. 0 0
      utils/ragflow_sdk/param.json

+ 25 - 1
.env

@@ -7,10 +7,15 @@ DASHSCOPE=sk-bc0f1026a41c4c92beb014be8973e4e2
 # embedding模型配置
 EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-0.6B
 MULTIMODAL_EMBEDDING_MODEL_NAME=qwen2.5-vl-embedding
+# rank模型配置
+RANK_MODEL_NAME=Qwen/Qwen3-Reranker-0.6B
+
 # RAGFLOW配置
-RAGFLOW_API_URL=http://192.168.16.134/
+RAGFLOW_API_URL=http://192.168.16.134:9380/
 RAGFLOW_API_KEY=ragflow-sPJ06xiUdRrcfDRlOD-GN2gl-U2DLB-PbgNGckUu0KM
 DATASET_ID=a0f1aa03ed2c11f08b8f0242c0a85002
+RAGFLOW_USER_NAME=O75u85uh+PwmwmJvNebYUCNKpD812xhfnQOvB+Mwy+cHQtGBV2dy0tMQKQwGHiW7MiLJkHPqSLn7ULTzav0c2w3yIze71PAcIfxUScautg6xMMgtjHd4ex8peVyXTQcWc0bmD+GxADaZoOMnDz/XNUtB8mggx/VZ1RBdhrZMylTEGQUcUE8ylbLAVgDVOR6iqJEjGNZYoWUFiuYZAB6bIfPdaPG8Kr0KQrqHj3Y0zZHXl92AloXnw5RsaBOc3p01PKc+xYujrb+mdSGs8flzNHMusWhB/0bJ+t0XxqgePymCCk6+7nmm7M6iZ6pRlYEeERsjiNd/1lcMaWxda5jv4A==
+RAGFLOW_PASSWD=ggR2p1L7DRp5v+VcfikmfMYWh99QgBjtQV1fexSFP5C18HVXhkTD2MeMrcnJft8ciGeRUPDStaYZft3CbHOBY3Bzm2+/WX8iVuZ+5kMkmiAL9iUu0RxK74g1x11zy2CAoASy344ZiDy1p1co7hE0ksPdL62U38dxPGPaK8ngqfTsMvpb0XztWBUxaK5Izl3fqPKaYp+eqC75vvj6PeeaajIc7I7dqSai3a6jipruZeA8VxF/cWjVSKubQrm/YLs8mge2mcSM0pyUU2t9LeNS+hOiUyxRAqE2s+yGmRpeCMKIcfeGaurlVesy8v7zjIjZZBbEPV8rSZOEskGRSAG38Q==
 
 # 应用配置
 LOG_LEVEL=INFO
@@ -28,6 +33,15 @@ INFINITY_USER=admin
 INFINITY_PASSWORD=admin
 INFINITY_TABLE_NAME=book_page_image
 
+# MySQL配置
+MYSQL_HOST=192.168.16.134
+MYSQL_PORT=5455
+MYSQL_USER=root
+MYSQL_PASSWORD=ronshin@1234
+MYSQL_DATABASE=book_page
+MYSQL_CHARSET=utf8mb4
+MYSQL_POOL_SIZE=5
+
 # MinIO配置
 MINIO_ENDPOINT=192.168.16.134:9000
 MINIO_ACCESS_KEY=ck7I8Esssx6rzZrXQ5uP
@@ -36,4 +50,14 @@ MINIO_BUCKET_NAME=bookpage
 # 本地测试设为false;生产环境设为true
 MINIO_SECURE=False
 
+# TAG_SEARCH
+TAG_DB_NAME=default_db
+TAG_DATASET_ID=18caf531f04d11f095670242c0a85002
+TAG_DOCUMENT_ID=3dda0a90f1e211f0a3b80242c0a85002
+#TAG_TABLE_NAME=ragflow_92162247e93e11f084830242ac1d0002_52275b36f03611f0a5340242c0a85002
+TAG_TABLE_NAME=ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002
 
+# LANGFUSE
+LANGFUSE_PUBLIC_KEY=pk-lf-6918a148-be72-4211-a22d-183a23e6643e
+LANGFUSE_SECRET_KEY=sk-lf-9c64d7ed-1618-4da9-a775-33e39b05448e
+LANGFUSE_HOST=http://192.168.16.134:3000

+ 9 - 9
agent/test_image_agent.py

@@ -9,7 +9,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from utils.infinity_util import InfinityVectorDB
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, ragflow_settings, vector_db_settings
 
 system_prompt = """
 【角色设定】 你是一位拥有丰富经验的儿童绘本编辑与阅读推广人,擅长从视觉美学、儿童心理学和文学创作三个维度深度解读绘本。
@@ -29,15 +29,15 @@ vector_db = InfinityVectorDB()
 
 # 初始化多模态嵌入模型
 embedding_model = Embedding(
-    model_name=ModelConfig.get_multimodal_embedding_model_name(),
-    api_key=ModelConfig.get_dashscope_api_key()
+    model_name=model_settings.multimodal_embedding_model_name,
+    api_key=model_settings.dashscope_api_key
 )
 
 llm = init_chat_model(
-    model_provider=ModelConfig.get_model_provider(),
-    model=ModelConfig.get_model_name(),
-    base_url=ModelConfig.get_base_url(),
-    api_key=ModelConfig.get_api_key(),
+    model_provider=model_settings.model_provider,
+    model=model_settings.model_name,
+    base_url=model_settings.base_url,
+    api_key=model_settings.api_key,
     temperature=0.3
 )
 
@@ -52,10 +52,10 @@ def vector_search(text_query: str, image_path: str, topn: int = 2):
     :return: 检索结果
     """
         
-    print(f"开始混合检索,数据库: {VectorDBConfig.get_infinity_database}, 知识库id: {ModelConfig.get_dataset_id()}, 文本查询: {text_query}, 返回数量: {topn}")
+    print(f"开始混合检索,数据库: {vector_db_settings.infinity_database}, 知识库id: {ragflow_settings.dataset_id}, 文本查询: {text_query}, 返回数量: {topn}")
         
     # 构建索引名称
-    index_name = f"{VectorDBConfig.get_infinity_table_name()}" 
+    index_name = f"{vector_db_settings.infinity_table_name}" 
     print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
         
     # 处理image_path为image: Image.Image

+ 3 - 55
api/db/services/infinity_search_service.py

@@ -1,61 +1,9 @@
 from typing import Dict, Any, List
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 from utils.infinity import InfinityClient
 from utils.file.image_util import image_util
 from model.multimodal_embedding import get_embedding_model
-
-
-def convert_to_basic_types(obj: Any) -> Any:
-    """
-    递归将对象转换为基本类型,以便Pydantic能够序列化
-    
-    特殊处理:当字典中的子项包含相同长度的数组时,将其转换为数组对象结构
-    例如:{"a": [1,2], "b": [3,4]} -> [{"a":1, "b":3}, {"a":2, "b":4}]
-    
-    Args:
-        obj: 要转换的对象
-    
-    Returns:
-        转换后的基本类型对象
-    """
-    if obj is None:
-        return None
-    elif isinstance(obj, (str, int, float, bool)):
-        return obj
-    elif isinstance(obj, dict):
-        # 先递归转换所有值
-        converted = {k: convert_to_basic_types(v) for k, v in obj.items()}
-        
-        # 检查是否需要转换为数组对象结构
-        # 条件:所有值都是列表,且长度一致,且长度大于0
-        values = list(converted.values())
-        if all(isinstance(v, list) for v in values):
-            lengths = [len(v) for v in values]
-            if len(set(lengths)) == 1 and lengths[0] > 0:
-                # 转换为数组对象结构
-                result = []
-                keys = list(converted.keys())
-                for i in range(lengths[0]):
-                    item = {}
-                    for key in keys:
-                        # 处理数组中可能存在的None值
-                        if i < len(converted[key]):
-                            item[key] = converted[key][i]
-                        else:
-                            item[key] = None
-                    result.append(item)
-                return result
-        
-        return converted
-    elif isinstance(obj, (list, tuple)):
-        return [convert_to_basic_types(item) for item in obj]
-    else:
-        # 对于其他类型,尝试将其转换为字符串或字典
-        try:
-            return dict(obj)
-        except:
-            return str(obj)
-
+from utils.infinity.result_util import convert_to_basic_types
 
 class InfinitySearchService:
     def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
@@ -72,7 +20,7 @@ class InfinitySearchService:
         self.vector_field = vector_field or "dense_vector_1024"
         self.match_field = match_field or "content"
         self.match_type = match_type or "cosine"
-        self.table_name = table_name or VectorDBConfig.get_infinity_table_name()
+        self.table_name = table_name or vector_db_settings.infinity_table_name
 
     def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
         """

+ 207 - 0
api/db/services/tag_service.py

@@ -0,0 +1,207 @@
+from typing import List, Dict, Any, Optional
+from abc import ABC, abstractmethod
+from utils.ragflow_sdk.chunk_util import ChunkUtil
+from utils.infinity import InfinityClient
+from conf.settings import tag_search_settings
+
+
+
class TagService(ABC):
    """Abstract interface for tag-management services.

    Concrete implementations persist tags (e.g. in RAGFlow + Infinity) and
    must provide single-tag creation, bulk upload and deletion.
    """

    @abstractmethod
    def create_tag(self, tag_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a single tag.

        Args:
            tag_data: Tag payload (name, description, and related fields).

        Returns:
            Dict[str, Any]: Information about the tag that was created.
        """
        ...

    @abstractmethod
    def upload_tags(self, tags_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Upload a batch of tags.

        Args:
            tags_data: List of tag payloads, each with name, description, etc.

        Returns:
            Dict[str, Any]: Upload summary (success count, failure count, ...).
        """
        ...

    @abstractmethod
    def delete_tag(self, tag_id: str) -> Dict[str, Any]:
        """Delete a tag by its identifier.

        Args:
            tag_id: Identifier of the tag to remove.

        Returns:
            Dict[str, Any]: Result of the delete operation.
        """
        ...
+
+
class TagServiceImpl(TagService):
    """Tag-management service backed by RAGFlow (chunk storage) and Infinity.

    A tag is stored as a RAGFlow chunk (description as content, age range as
    keyword) and then the tag name is written onto the matching Infinity row
    in the ``tag_kwd`` column.
    """

    def __init__(self, infinity_client: InfinityClient):
        """
        Initialize the tag service.

        Args:
            infinity_client: Connected Infinity client used to update the
                ``tag_kwd`` field after a chunk is created in RAGFlow.
        """
        self.tag_dataset_id = tag_search_settings.tag_dataset_id
        self.tag_document_id = tag_search_settings.tag_document_id
        self.infinity_client = infinity_client
        self.chunk_util = ChunkUtil(self.tag_dataset_id, self.tag_document_id)

    def create_tag(self, tag_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a single tag.

        Args:
            tag_data: Tag payload with ``name``, ``description`` and
                ``age_range`` keys.

        Returns:
            Dict[str, Any]: The created tag, including the RAGFlow chunk and
            the Infinity update result.

        Raises:
            KeyError: If a required key is missing from ``tag_data``.
            Exception: If the Infinity update reports a non-zero code.
        """
        tag_name = tag_data["name"]
        tag_desc = tag_data["description"]
        age_range = tag_data["age_range"]
        # Step 1: store the tag description as a chunk in RAGFlow.
        chunk = self.chunk_util.add_chunk(document_id=self.tag_document_id, content=tag_desc, important_keywords=[age_range])
        chunk_id = chunk["id"]
        # Step 2: write the tag name onto the chunk row in Infinity.
        # FIX: the id is a string column, so it must be quoted in the filter
        # condition (upload_tags below already quotes it).
        res = self.infinity_client.update(f"id = '{chunk_id}'", {"tag_kwd": tag_name})
        # NOTE(review): this call checks res["code"] while upload_tags checks
        # res.error_code — confirm which result shape InfinityClient.update
        # returns for this overload.
        if res["code"] != 0:
            raise Exception(f"更新标签到infinity失败: {res}")

        # Return the full creation record for the caller.
        return {
            "name": tag_name,
            "description": tag_desc,
            "age_range": age_range,
            "chunk_id": chunk_id,
            "ragflow_chunk": chunk,
            "infinity_update_result": res
        }

    def upload_tags(self, tags_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Upload a batch of tags parsed from the Excel import.

        Args:
            tags_data: Raw rows with Chinese column keys (年龄分级 / 一级标签 /
                二级标签 / 三级标签 / 标签描述).

        Returns:
            Dict[str, Any]: Summary with total, success_count, fail_count and
            the list of failed tags.
        """
        success_count = 0
        fail_count = 0
        failed_tags = []

        tag_data_list = self.get_tag_data(tags_data)
        # Persist each tag; a failing row is recorded instead of aborting
        # the whole batch (FIX: previously one exception stopped the upload).
        for tag_data in tag_data_list:
            tag_name = tag_data["name"]
            tag_desc = tag_data["description"]
            age_range = tag_data["age_range"]
            try:
                # Step 1: store the tag description as a chunk in RAGFlow.
                chunk = self.chunk_util.add_chunk(dataset_name="标签", document_id=self.tag_document_id, content=tag_desc, important_keywords=[age_range])
                print(f"分块数据写入成功, chunk_id: {chunk.id}")
                # Step 2: write the tag name onto the chunk row in Infinity.
                # FIX: table and database names now come from configuration
                # (TAG_TABLE_NAME / TAG_DB_NAME) instead of being hard-coded.
                res = self.infinity_client.update(
                    table_name=tag_search_settings.tag_table_name,
                    cond=f"id = '{chunk.id}'",
                    data={"tag_kwd": tag_name},
                    database_name=tag_search_settings.tag_db_name,
                )
                print(f"更新标签 {tag_name} 到 infinity 结果: {res}")
                if res.error_code == 0:
                    print(f"标签 {tag_name} 更新到 infinity 成功")
                    success_count += 1
                else:
                    fail_count += 1
                    failed_tags.append({"age_range": age_range, "tag_name": tag_name, "tag_desc": tag_desc, "error": res})
            except Exception as e:
                fail_count += 1
                failed_tags.append({"age_range": age_range, "tag_name": tag_name, "tag_desc": tag_desc, "error": str(e)})
        return {
            "success": True,
            "total": len(tags_data),
            "success_count": success_count,
            "fail_count": fail_count,
            "failed_tags": failed_tags
        }

    def delete_tag(self, tag_id: str) -> Dict[str, Any]:
        """
        Delete a tag by its identifier.

        Args:
            tag_id: Identifier of the tag to remove.

        Returns:
            Dict[str, Any]: Result of the delete operation.
        """
        # TODO: not implemented yet — currently returns None.
        pass

    def get_tag_data(self, tags_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize raw Excel rows into {name, description, age_range} dicts.

        The tag name is assembled as 一级标签_二级标签_三级标签 and the age
        level code (L1-L8) is mapped to an age range via conf/age_level.json.
        """
        import json
        import os

        # Resolve conf/age_level.json relative to the project root
        # (four levels up from api/db/services/tag_service.py).
        age_level_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "conf", "age_level.json")
        with open(age_level_path, "r", encoding="utf-8") as f:
            age_level_map = json.load(f)

        formatted_tags = []
        for tag_data in tags_data:
            # Raw Excel columns (Chinese keys); missing cells default to "".
            age_level_code = tag_data.get("年龄分级", "")
            one_tag = tag_data.get("一级标签", "")
            two_tag = tag_data.get("二级标签", "")
            three_tag = tag_data.get("三级标签", "")
            tag_desc = tag_data.get("标签描述", "")

            # Tag name format: level1_level2_level3.
            tag_name = f"{one_tag}_{two_tag}_{three_tag}"

            # Map the level code (e.g. "L3") to an age range (e.g. "3-4");
            # unknown codes map to "".
            age_range = age_level_map.get(age_level_code, "")

            formatted_tags.append({
                "name": tag_name,
                "description": tag_desc,
                "age_range": age_range
            })

        return formatted_tags
+
class TagServiceFactory:
    """Factory that hides the concrete TagService implementation."""

    @staticmethod
    def create_tag_service(db_client) -> TagService:
        """
        Build the default tag-service instance.

        Args:
            db_client: Database (Infinity) client handed to the implementation.

        Returns:
            TagService: A ready-to-use tag service.
        """
        service = TagServiceImpl(db_client)
        return service

+ 7 - 5
api/hybrid_search_mcp.py

@@ -15,7 +15,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from PIL import Image
 from utils.infinity_util import InfinityVectorDB
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, ragflow_settings, vector_db_settings
+
+
 
 # 初始化fastmcp应用
 mcp = FastMCP("Multi_Vector_Search")
@@ -25,8 +27,8 @@ vector_db = InfinityVectorDB()
 
 # 初始化多模态嵌入模型
 embedding_model = Embedding(
-    model_name=ModelConfig.get_multimodal_embedding_model_name(),
-    api_key=ModelConfig.get_dashscope_api_key()
+    model_name=model_settings.multimodal_embedding_model_name,
+    api_key=model_settings.dashscope_api_key
 )
 
 @mcp.tool(name="hybrid_search")
@@ -41,10 +43,10 @@ def hybrid_search(request: Dict[str, Any]) -> Dict[str, Any]:
         image_url = request["image"]
         topn = request.get("topn", 2)
         
-        print(f"开始混合检索,数据库: {VectorDBConfig.get_infinity_database}, 知识库id: {ModelConfig.get_dataset_id()}, 文本查询: {text_query}, 返回数量: {topn}")
+        print(f"开始混合检索,数据库: {vector_db_settings.infinity_database}, 知识库id: {ragflow_settings.dataset_id}, 文本查询: {text_query}, 返回数量: {topn}")
         
         # 构建索引名称
-        index_name = f"pdf_documents_{ModelConfig.get_dataset_id()}"        
+        index_name = f"pdf_documents_{ragflow_settings.dataset_id}"        
         print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
         
         # 处理image_url为image: Image.Image

+ 95 - 0
api/tag_manage.py

@@ -0,0 +1,95 @@
# Tag-management API service (mounted under /tag by main.py).

from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import List, Dict, Any, Optional
from api.db.services.tag_service import TagServiceFactory
from utils.infinity import get_client
from utils.excel_util import excel_util
from pydantic import BaseModel
import os
import tempfile


# FastAPI sub-application exposing the tag endpoints.
app = FastAPI(
    title="标签管理 API",
    description="标签管理服务,提供标签上传、查询、删除等功能",
    version="1.0.0"
)


# Request model
class TagCreateRequest(BaseModel):
    """Request body for creating a single tag."""
    name: str  # tag name, format: level1_level2_level3
    description: str  # tag description
    age_range: str  # age range, e.g. 0-2, 2-3


# 1. Bulk tag upload endpoint
@app.post("/upload", response_model=Dict[str, Any])
async def upload_tags(file: UploadFile = File(...)):
    """
    Bulk-upload tags from an Excel file.

    - **file**: Excel file with the columns 年龄分级 / 一级标签 / 二级标签 /
      三级标签 / 标签描述.

    Excel format requirements:
    - .xlsx and .xls are supported
    - the first row is the header with the five columns above
    - the 年龄分级 column holds a level code between L1 and L8
    """
    try:
        # FIX: UploadFile.filename may be None when the client omits it;
        # reject early with a 400 instead of crashing into a 500 below.
        if not file.filename:
            raise HTTPException(status_code=400, detail="缺少文件名,无法识别文件格式")

        # Validate the file extension.
        file_ext = os.path.splitext(file.filename)[1].lower()
        if file_ext not in [".xlsx", ".xls"]:
            raise HTTPException(status_code=400, detail=f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")

        # Spill the upload to a temporary file so excel_util can read a path.
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        try:
            # Parse the Excel file into raw tag rows.
            tags_data = excel_util.parse_excel(file_path=temp_file_path)

            # Build the tag service on top of the shared Infinity client.
            tag_service = TagServiceFactory.create_tag_service(get_client())

            # Persist the batch and return the summary.
            result = tag_service.upload_tags(tags_data)

            return {"success": True, "result": result}
        finally:
            # Always remove the temporary file, even on failure.
            os.unlink(temp_file_path)
    except HTTPException as e:
        raise e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"标签上传失败: {str(e)}")


# 2. Single-tag creation endpoint
@app.post("/create", response_model=Dict[str, Any])
async def create_tag(request: TagCreateRequest):
    """
    Create a single tag.

    - **name**: tag name in the format level1_level2_level3
    - **description**: detailed description of the tag's meaning and usage
    - **age_range**: age range such as 0-2, 2-3, 3-4
    """
    try:
        # Build the tag service on top of the shared Infinity client.
        tag_service = TagServiceFactory.create_tag_service(get_client())

        # Create the tag from the validated request body.
        tag_data = request.model_dump()
        result = tag_service.create_tag(tag_data)

        return {"success": True, "tag": result}
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"缺少必填字段: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"创建标签失败: {str(e)}")

+ 11 - 0
common/models/pagination.py

@@ -0,0 +1,11 @@
+from pydantic import BaseModel, Field
+from typing import Optional
+
class Pagination(BaseModel):
    """Generic pagination / filter parameters shared by list endpoints."""
    # 1-based page index.
    page: int = Field(default=1, ge=1, description="当前页码")
    # Items per page, capped at 100.
    page_size: int = Field(default=30, ge=1, le=100, description="每页条数")

    def to_dict(self):
        """Serialize to a dict in the shape the API expects, dropping None values."""
        payload = self.model_dump()
        return {key: value for key, value in payload.items() if value is not None}

+ 10 - 0
conf/age_level.json

@@ -0,0 +1,10 @@
+{
+    "L1": "0-2",
+    "L2": "2-3",
+    "L3": "3-4",
+    "L4": "4-5",
+    "L5": "5-6",
+    "L6": "6-10",
+    "L7": "10-14",
+    "L8": "14-100"
+}

+ 31 - 0
conf/config.py

@@ -33,6 +33,11 @@ class ModelConfig:
     def get_api_key() -> str:
         """获取模型API密钥"""
         return os.getenv("API_KEY", "")
+
+    @staticmethod
+    def get_rank_model_name() -> str:
+        """获取模型名称"""
+        return os.getenv("RANK_MODEL_NAME", "Qwen/Qwen3-Reranker-0.6B")
     
     @staticmethod
     def get_model_config() -> dict:
@@ -64,6 +69,9 @@ class ModelConfig:
         """获取DASHSCOPE API密钥"""
         return os.getenv("DASHSCOPE", "")
 
+class RagflowConfig:
+    """RAGFLOW配置类"""
+
     # RAGFLOW配置
     @staticmethod
     def get_ragflow_api_url() -> str:
@@ -80,6 +88,16 @@ class ModelConfig:
         """获取数据集ID"""
         return os.getenv("DATASET_ID", "")
 
+    @staticmethod
+    def get_ragflow_user_name() -> str:
+        """获取RAGFLOW用户名"""
+        return os.getenv("RAGFLOW_USER_NAME", "")
+    
+    @staticmethod
+    def get_ragflow_passwd() -> str:
+        """获取RAGFLOW密码"""
+        return os.getenv("RAGFLOW_PASSWD", "")
+
 class AppConfig:
     """应用配置类"""
     
@@ -159,6 +177,19 @@ class VectorDBConfig:
         """获取Infinity向量数据库表名"""
         return os.getenv("INFINITY_TABLE_NAME", "test")
 
+class TagSearchConfig:
+    """标签搜索配置类"""
+
+    @staticmethod
+    def get_tag_db_name() -> str:
+        """获取标签数据库名称"""
+        return os.getenv("TAG_DB_NAME", "tag_db")
+    
+    @staticmethod
+    def get_tag_table_name() -> str:
+        """获取标签数据库表名"""
+        return os.getenv("TAG_TABLE_NAME", "tag_table")
+
 # 导出配置实例
 model_config = ModelConfig.get_model_config()
 app_config = {

+ 123 - 0
conf/settings.py

@@ -0,0 +1,123 @@
+"""配置管理模块,使用 pydantic-settings 从.env文件加载环境变量"""
+
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import Field
+
class _EnvSettings(BaseSettings):
    """Shared base for all settings groups.

    FIX: the identical SettingsConfigDict was duplicated in every class;
    it is hoisted here once and inherited (pydantic-settings supports
    model_config inheritance).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )


class AppSettings(_EnvSettings):
    """Application-level settings."""
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")


class ModelSettings(_EnvSettings):
    """LLM / embedding model settings."""
    model_provider: str = Field(default="openai", alias="MODEL_PROVIDER")
    model_name: str = Field(default="Qwen/Qwen3-VL-8B-Instruct", alias="MODEL_NAME")
    embedding_model_name: str = Field(default="Qwen/Qwen3-Embedding-0.6B", alias="EMBEDDING_MODEL_NAME")
    base_url: str = Field(default="https://api.openai.com/v1", alias="BASE_URL")
    api_key: str = Field(default="", alias="API_KEY")
    rank_model_name: str = Field(default="Qwen/Qwen3-Reranker-0.6B", alias="RANK_MODEL_NAME")
    multimodal_embedding_model_name: str = Field(default="qwen2.5-vl-embedding", alias="MULTIMODAL_EMBEDDING_MODEL_NAME")
    # DashScope API key (env var is named just "DASHSCOPE").
    dashscope: str = Field(default="", alias="DASHSCOPE")


class RagflowSettings(_EnvSettings):
    """RAGFlow connection settings."""
    ragflow_api_url: str = Field(default="http://192.168.16.134/", alias="RAGFLOW_API_URL")
    ragflow_api_key: str = Field(default="", alias="RAGFLOW_API_KEY")
    dataset_id: str = Field(default="", alias="DATASET_ID")
    # NOTE(review): per the .env, user name and password are stored encrypted —
    # they must be decrypted before use; verify against the caller.
    ragflow_user_name: str = Field(default="", alias="RAGFLOW_USER_NAME")
    ragflow_passwd: str = Field(default="", alias="RAGFLOW_PASSWD")


class MinioSettings(_EnvSettings):
    """MinIO object-storage settings."""
    minio_endpoint: str = Field(default="http://localhost:9000", alias="MINIO_ENDPOINT")
    minio_access_key: str = Field(default="minioadmin", alias="MINIO_ACCESS_KEY")
    minio_secret_key: str = Field(default="minioadmin", alias="MINIO_SECRET_KEY")
    minio_bucket_name: str = Field(default="ragflow", alias="MINIO_BUCKET_NAME")
    minio_secure: bool = Field(default=False, alias="MINIO_SECURE")


class VectorDBSettings(_EnvSettings):
    """Vector-database (Infinity) settings."""
    vector_db_type: str = Field(default="es", alias="VECTOR_DB_TYPE")
    infinity_host: str = Field(default="192.168.16.134", alias="INFINITY_HOST")
    infinity_port: int = Field(default=23820, alias="INFINITY_PORT")
    infinity_sdk_port: int = Field(default=23817, alias="INFINITY_SDK_PORT")
    infinity_user: str = Field(default="admin", alias="INFINITY_USER")
    infinity_password: str = Field(default="admin", alias="INFINITY_PASSWORD")
    infinity_database: str = Field(default="test", alias="INFINITY_DATABASE")
    infinity_table_name: str = Field(default="test", alias="INFINITY_TABLE_NAME")


class MysqlSettings(_EnvSettings):
    """MySQL connection-pool settings."""
    mysql_host: str = Field(default="localhost", alias="MYSQL_HOST")
    mysql_port: int = Field(default=3306, alias="MYSQL_PORT")
    mysql_user: str = Field(default="root", alias="MYSQL_USER")
    mysql_password: str = Field(default="", alias="MYSQL_PASSWORD")
    mysql_database: str = Field(default="", alias="MYSQL_DATABASE")
    mysql_charset: str = Field(default="utf8mb4", alias="MYSQL_CHARSET")
    mysql_pool_size: int = Field(default=5, alias="MYSQL_POOL_SIZE")


class TagSearchSettings(_EnvSettings):
    """Tag-search settings (target dataset/document/table for tags)."""
    tag_db_name: str = Field(default="tag_db", alias="TAG_DB_NAME")
    tag_table_name: str = Field(default="tag_table", alias="TAG_TABLE_NAME")
    tag_document_id: str = Field(default="", alias="TAG_DOCUMENT_ID")
    tag_dataset_id: str = Field(default="", alias="TAG_DATASET_ID")


# Module-level singleton instances used throughout the project.
model_settings = ModelSettings()
ragflow_settings = RagflowSettings()
app_settings = AppSettings()
minio_settings = MinioSettings()
vector_db_settings = VectorDBSettings()
mysql_settings = MysqlSettings()
tag_search_settings = TagSearchSettings()


+ 18 - 3
main.py

@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
 
 # 导入所有子应用
 from api.search_infinity import app as search_app
+from api.tag_manage import app as tag_app
 
 # 定义主应用的生命周期管理
 @asynccontextmanager
@@ -12,15 +13,26 @@ async def main_lifespan(app: FastAPI):
     """主应用生命周期管理"""
     from utils.infinity import get_client, close_client
     print("=== Infinity API Gateway 启动 ===")
-    # 1. 初始化全局客户端(在服务启动时)
+    # 1. 初始化Infinity全局客户端(在服务启动时)
     get_client(database="book_image_db")
     print("✅ Infinity客户端已初始化")
+    
+    # 2. 初始化MySQL全局客户端
+    from utils.mysql import init_global_mysql_client
+    init_global_mysql_client()
+    print("✅ MySQL客户端已初始化")
+    
     yield
 
     print("=== Infinity API Gateway 关闭 ===")
-     # 2. 关闭全局客户端(在服务关闭时)
+    # 1. 关闭Infinity全局客户端(在服务关闭时)
     close_client()
     print("✅ Infinity客户端已关闭")
+    
+    # 2. 关闭MySQL全局客户端
+    from utils.mysql import close_global_mysql_client
+    close_global_mysql_client()
+    print("✅ MySQL客户端已关闭")
 
 # 创建主应用
 main_app = FastAPI(
@@ -33,6 +45,8 @@ main_app = FastAPI(
 # 挂载子应用
 # 1. 搜索 API - 访问路径: /search/*
 main_app.mount("/search", search_app, name="search_api")
+# 2. 标签管理 API - 访问路径: /tag/*
+main_app.mount("/tag", tag_app, name="tag_api")
 
 # 主应用根路径
 @main_app.get("/")
@@ -42,7 +56,8 @@ async def root():
         "message": "Welcome to GRAPH_RAG API Gateway",
         "available_apps": {
             "search_api": "访问路径: /search, 文档: /search/docs",
-            "hybrid_http_api": "访问路径: /hybrid, 文档: /hybrid/docs"
+            "hybrid_http_api": "访问路径: /hybrid, 文档: /hybrid/docs",
+            "tag_api": "访问路径: /tag, 文档: /tag/docs"
         }
     }
 

+ 133 - 0
model/jina_rerank.py

@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import requests
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.utils import get_from_dict_or_env
+from pydantic import ConfigDict, model_validator
+
+JINA_API_URL: str = "https://api.jina.ai/v1/rerank"
+
+
class JinaRerank(BaseDocumentCompressor):
    """Document compressor that uses `Jina Rerank API` with support for custom base_url."""

    session: Any = None
    """Requests session to communicate with API."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model: str = "jina-reranker-v1-base-en"
    """Model to use for reranking."""
    jina_api_key: Optional[str] = None
    """Jina API key. Must be specified directly or via environment variable 
        JINA_API_KEY."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""
    base_url: str = JINA_API_URL
    """Base URL for the Jina API. Defaults to JINA_API_URL."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key exists in environment."""
        jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
        user_agent = values.get("user_agent", "langchain")
        # Pre-authenticated session reused by every rerank call.
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {jina_api_key}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
                "user-agent": user_agent,
            }
        )
        values["session"] = session
        return values

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        model: Optional[str] = None,
        top_n: Optional[int] = -1,
        max_chunks_per_doc: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Returns an ordered list of documents ordered by their relevance to the provided query.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.
            model: The model to use for re-ranking. Default to self.model.
            top_n : The number of results to return. If None returns all results.
                Defaults to self.top_n.
            max_chunks_per_doc : The maximum number of chunks derived from a document.
        """  # noqa: E501
        if len(documents) == 0:  # to avoid empty api call
            return []
        docs = [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]
        model = model or self.model
        # top_n semantics: None -> all results; <= 0 -> fall back to self.top_n.
        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
        data = {
            "query": query,
            "documents": docs,
            "model": model,
            "top_n": top_n,
        }

        # Build the full endpoint URL.
        # FIX: the default base_url (JINA_API_URL) already ends in "/rerank";
        # blindly appending produced ".../rerank/rerank". Only append when
        # the path is missing.
        base = self.base_url.rstrip('/')
        api_url = base if base.endswith('/rerank') else base + '/rerank'
        resp = self.session.post(
            api_url,
            json=data,
        ).json()

        if "results" not in resp:
            # FIX: use .get() so a response without a "detail" key surfaces
            # the whole payload instead of raising an unrelated KeyError.
            raise RuntimeError(resp.get("detail", resp))

        results = resp["results"]
        result_dicts = []
        for res in results:
            result_dicts.append(
                {
                    "index": res["index"],
                    "relevance_score": res["relevance_score"],
                }
            )
        return result_dicts

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        top_n: Optional[int] = None,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using Jina's Rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            top_n: Number of documents to keep; None asks the API for all
                results (it does NOT fall back to self.top_n — NOTE(review):
                confirm this is intended, as it diverges from the class default).
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        compressed = []
        for res in self.rerank(documents=documents, query=query, top_n=top_n):
            doc = documents[res["index"]]
            # Copy so the caller's documents are not mutated.
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed

+ 30 - 20
model/multimodal_embedding.py

@@ -4,8 +4,9 @@ import base64
 import io
 from langchain_openai import OpenAIEmbeddings
 from dashscope import MultiModalEmbedding
-from conf.config import ModelConfig
-from utils.file.image_util import image_util as ImageUtil
+from conf.settings import model_settings
+from langfuse import observe
+from utils.file.image_util import image_util
 
 class Embedding:
     """Embedding模型工具"""
@@ -19,18 +20,12 @@ class Embedding:
             api_key: API密钥,若为None则使用配置文件中的值
         """
         # 获取模型配置
-        self.model_provider = ModelConfig.get_model_provider()
-        self.model_name = model_name or ModelConfig.get_model_name()
-        self.base_url = ModelConfig.get_base_url()
-        self.api_key = api_key or ModelConfig.get_api_key()
-        
-        # 使用langchain_openai初始化OpenAI Embeddings模型
-        self.embeddings = OpenAIEmbeddings(
-            model=self.model_name,
-            base_url=self.base_url,
-            api_key=self.api_key
-        )
+        self.model_provider = model_settings.model_provider
+        self.model_name = model_name or model_settings.model_name
+        self.base_url = model_settings.base_url
+        self.api_key = api_key or model_settings.api_key
     
+    @observe(name="text_embedding", as_type="embedding")
     def get_text_embedding(self, text: str) -> List[float]:
         """
         获取文本的embedding
@@ -42,11 +37,18 @@ class Embedding:
             List[float]: 文本的embedding向量
         """
         try:
-            embedding = self.embeddings.embed_query(text)
+            # 使用langchain_openai初始化OpenAI Embeddings模型
+            embeddings = OpenAIEmbeddings(
+                model=self.model_name,
+                base_url=self.base_url,
+                api_key=self.api_key
+            )
+            embedding = embeddings.embed_query(text)
             return embedding
         except Exception as e:
             raise Exception(f"文本embedding生成失败: {str(e)}")
     
+    @observe(name="texts_embedding", as_type="embedding")
     def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
         """
         获取多个文本的embedding
@@ -58,11 +60,18 @@ class Embedding:
             List[List[float]]: 文本列表的embedding向量列表
         """
         try:
-            embeddings = self.embeddings.embed_documents(texts)
+            # 使用langchain_openai初始化OpenAI Embeddings模型
+            embeddings = OpenAIEmbeddings(
+                model=self.model_name,
+                base_url=self.base_url,
+                api_key=self.api_key
+            )
+            embeddings = embeddings.embed_documents(texts)
             return embeddings
         except Exception as e:
             raise Exception(f"多个文本embedding生成失败: {str(e)}")
     
+    @observe(name="image_embedding", as_type="embedding")
     def get_image_embedding(self, image: Image.Image) -> List[float]:
         """
         获取图像的embedding
@@ -80,7 +89,7 @@ class Embedding:
             buffer.seek(0)
             
             # 压缩图片字节流
-            compressed_bytes = ImageUtil._compress_image_to_bytes(buffer)
+            compressed_bytes = image_util._compress_image_to_bytes(buffer)
             
             image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
             # 构建输入项
@@ -98,7 +107,8 @@ class Embedding:
                 return response.output["embeddings"][0]["embedding"]
         except Exception as e:
             raise Exception(f"图像embedding生成失败: {str(e)}")
-    
+
+    @observe(name="multimodal_embedding", as_type="embedding")
     def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
         """
         获取多模态(文本+图像)的embedding
@@ -118,7 +128,7 @@ class Embedding:
                 buffer.seek(0)
             
                 # 压缩图片字节流
-                compressed_bytes = ImageUtil()._compress_image_to_bytes(buffer)
+                compressed_bytes = image_util._compress_image_to_bytes(buffer)
             
                 image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
                 item.append({'image': f"data:image/png;base64,{image_base64}"})
@@ -148,6 +158,6 @@ def get_embedding_model() -> Embedding:
         Embedding: 全局单例的Embedding模型
     """
     return Embedding(
-        model_name=ModelConfig.get_multimodal_embedding_model_name(),
-        api_key=ModelConfig.get_dashscope_api_key()
+        model_name=model_settings.embedding_model_name,
+        api_key=model_settings.api_key
     )

+ 8 - 7
model/qwen_vl.py

@@ -3,7 +3,8 @@ from PIL import Image
 import base64
 import io
 from langchain.chat_models import init_chat_model
-from conf.config import ModelConfig
+from conf.settings import model_settings
+from langfuse.callback import CallbackHandler
 
 class QWenVLParser:
     """QWEN VL模型图像解析工具"""
@@ -16,11 +17,11 @@ class QWenVLParser:
             model_name: 模型名称,若为None则使用配置文件中的值
         """
         # 获取模型配置
-        self.model_provider = ModelConfig.get_model_provider()
-        self.model_name = model_name or ModelConfig.get_model_name()
-        self.base_url = ModelConfig.get_base_url()
-        self.api_key = ModelConfig.get_api_key()
-        
+        self.model_provider = model_settings.model_provider
+        self.model_name = model_name or model_settings.model_name
+        self.base_url = model_settings.base_url
+        self.api_key = model_settings.api_key
+        self.langfuse_handler = CallbackHandler()
         # 使用langchain的init_chat_model初始化模型
         self.model = init_chat_model(
             model_provider=self.model_provider,
@@ -82,7 +83,7 @@ class QWenVLParser:
             ]
             
             # 使用langchain模型调用OpenAI API
-            response = self.model.invoke(messages)
+            response = self.model.invoke(input=messages, config={"callbacks": [self.langfuse_handler]})
             
             # 提取解析结果
             content = response.content

+ 21 - 0
model/tracked_multi_embedding.py

@@ -0,0 +1,21 @@
+from model.multimodal_embedding import Embedding
+from PIL import Image
+from utils.decorators.langfuse_trace_embedding import langfuse_trace_embedding
+
class TrackedMultiEmbeddings(Embedding):
    """Multimodal embedding wrapper whose calls are traced with Langfuse.

    Each method delegates unchanged to the corresponding ``Embedding``
    implementation and only adds a ``langfuse_trace_embedding`` span.
    """

    @langfuse_trace_embedding(name="multi-embed-query")
    def embed_query(self, text: str) -> list[float]:
        """Embed a single text query."""
        return super().get_text_embedding(text)

    @langfuse_trace_embedding(name="multi-embed-documents")
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts."""
        return super().get_texts_embedding(texts)

    @langfuse_trace_embedding(name="multi-embed-image")
    def embed_image(self, image: Image.Image) -> list[float]:
        # Annotation fixed: bare ``Image`` (imported from PIL) is a module;
        # the image type is ``Image.Image``.
        """Embed a PIL image."""
        return super().get_image_embedding(image)

    @langfuse_trace_embedding(name="multi-embed-multimodal")
    def embed_multimodal(self, text: str, image: Image.Image) -> list[float]:
        """Embed a text/image pair jointly."""
        return super().get_multimodal_embedding(text, image)

+ 12 - 0
model/tracked_openai_embeddings.py

@@ -0,0 +1,12 @@
+from langchain_openai import OpenAIEmbeddings
+from utils.decorators.langfuse_trace_embedding import langfuse_trace_embedding
+
class TrackedOpenAIEmbeddings(OpenAIEmbeddings):
    """OpenAI embeddings subclass whose calls are traced with Langfuse.

    Behaviour is identical to ``OpenAIEmbeddings``; each override merely
    wraps the parent implementation in a ``langfuse_trace_embedding`` span.
    """

    @langfuse_trace_embedding(name="openai-embed-query")
    def embed_query(self, text: str):
        """Embed a single query string (traced as ``openai-embed-query``)."""
        return super().embed_query(text)

    @langfuse_trace_embedding(name="openai-embed-documents")
    def embed_documents(self, texts: list[str]):
        """Embed a batch of texts (traced as ``openai-embed-documents``)."""
        return super().embed_documents(texts)

+ 0 - 0
parser/image_parser/__init__.py


+ 2 - 2
workflow/image_parsing_workflow.py → parser/image_parser/image_parser_workflow.py

@@ -19,7 +19,7 @@ from model.qwen_vl import QWenVLParser
 from utils.ragflow.ragflow_service import RAGFlowService
 from model.multimodal_embedding import Embedding
 from utils.minio.image_util import image_util
-from conf.config import ModelConfig
+from conf.settings import model_settings
 from utils.infinity import get_client
 
 # 定义工作流状态类
@@ -298,7 +298,7 @@ class ImageParsingWorkflow:
             zip_file_path=zip_file_path,
             book_name=book_name,
             dataset_id=dataset_id,
-            embedding_model=Embedding(model_name=ModelConfig.get_multimodal_embedding_model_name(), api_key=ModelConfig.get_dashscope_api_key()),
+            embedding_model=Embedding(model_name=model_settings.multimodal_embedding_model_name, api_key=model_settings.dashscope_api_key),
             ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
         )
         result = self.workflow.invoke(initial_state)

+ 82 - 14
workflow/workflow.py → parser/pdf_parser/pdf_parser_workflow.py

@@ -1,19 +1,15 @@
-import sys
 import os
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-# 添加项目根目录到Python路径
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 from langgraph.graph import StateGraph, START, END
-from langgraph.graph.message import add_messages
 from typing import List, Dict, Any
 from pydantic import BaseModel, Field, ConfigDict
 from parser.pdf_parser.pdf_splitter import PDFSplitter
 from model.qwen_vl import QWenVLParser
+from utils.ragflow_sdk import DataSetUtil, DocumentUtil, ChunkUtil
 from utils.ragflow.ragflow_service import RAGFlowService
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, vector_db_settings
 from utils.infinity import get_client
 
 # 定义工作流状态类
@@ -22,7 +18,11 @@ class PDFParsingState(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     pdf_path: str = Field(..., description="PDF文件路径")
     dataset_id: str = Field(..., description="数据集ID")
-    ragflow_service: RAGFlowService = Field(default_factory=RAGFlowService, description="RAGFLOW服务")
+    page_dataset_id: str = Field(..., description="页面数据集ID")
+    ragflow_service: RAGFlowService = Field(default_factory=RAGFlowService, description="RAGFlow服务实例")
+    dataset_util: DataSetUtil = Field(default_factory=DataSetUtil, description="数据集工具类实例")
+    document_util: DocumentUtil = Field(default_factory=DocumentUtil, description="文档工具类实例")
+    chunk_util: ChunkUtil = Field(default_factory=ChunkUtil, description="文档工具类实例")
     embedding_model: Embedding = Field(default_factory=Embedding, description="多模态嵌入模型实例")
     document_id: str = Field(default="", description="上传后的文档ID")
     split_pages: List[Dict[str, Any]] = Field(default_factory=list, description="拆分后的页面列表")
@@ -46,6 +46,7 @@ class PDFParsingWorkflow:
         """
         self.model_name = model_name
         self.workflow = self._build_workflow()
+        
     
     def _build_workflow(self):
         """构建langgraph工作流,实现基于条件路由的并行处理"""
@@ -97,6 +98,37 @@ class PDFParsingWorkflow:
         # 编译工作流
         return graph.compile()
     
    def get_ragflow_dataset(self, dataset_name: str) -> str:
        """Look up the RAGFlow dataset identified by *dataset_name*.

        Raises:
            Exception: re-raised from the underlying lookup on failure.

        NOTE(review): ``dataset_util`` is declared as a field on
        PDFParsingState, not on the workflow instance — ``self.dataset_util``
        will raise AttributeError unless it is also assigned in ``__init__``;
        confirm. Also verify DataSetUtil.get_dataset really returns an ID
        string, not a dataset object.
        """
        try:
            dataset_id = self.dataset_util.get_dataset(name=dataset_name)
            print(f"数据集 {dataset_name} 的ID为: {dataset_id}")
            return dataset_id
        except Exception as e:
            print(f"获取数据集ID时出错: {str(e)}")
            raise
+
    def create_ragflow_dataset(self, state: PDFParsingState, dataset_name: str) -> str:
        """Create a RAGFlow dataset named *dataset_name*.

        Returns the existing ``state.dataset_id`` unchanged when it is already
        set; otherwise creates a new dataset and returns its ID.

        NOTE(review): ``self.dataset_util`` — the util is a field of
        PDFParsingState, not of the workflow; confirm it exists on ``self``.
        """
        if state.dataset_id:
            print(f"数据集 {dataset_name} 已存在,数据集ID: {state.dataset_id}")
            return state.dataset_id
        
        print(f"开始创建数据集: {dataset_name}")
        
        try: 
            # Create the dataset with the default "naive" chunking method.
            dataset_id = self.dataset_util.create_dataset(
                chunk_method="naive",
                dataset_name=dataset_name,
                dataset_desc="",
            )
            print(f"数据集创建成功,数据集ID: {dataset_id}")
            return dataset_id
        except Exception as e:
            print(f"创建数据集时出错: {str(e)}")
            raise
+            raise
+
     def _upload_document_node(self, state: PDFParsingState) -> Dict[str, Any]:
         """RAGFLOW上传文档节点"""
         print(f"开始上传文档到数据集 {state.dataset_id}: {state.pdf_path}")
@@ -107,13 +139,20 @@ class PDFParsingWorkflow:
                 dataset_id=state.dataset_id,
                 file_path=state.pdf_path
             )
+            # 上传文档
+            document_info_list2 = state.ragflow_service.upload_document(
+                dataset_id=state.page_dataset_id,
+                file_path=state.pdf_path
+            )
             
             # 检查响应
             if document_info_list and len(document_info_list) > 0:
                 document_id = document_info_list[0]["id"]
+                page_document_id = document_info_list2[0]["id"]
                 print(f"文档上传成功,文档ID: {document_id}")
                 return {
-                    "document_id": document_id
+                    "document_id": document_id,
+                    "page_document_id": page_document_id
                 }
             else:
                 print("文档上传失败: 未返回有效的文档信息")
@@ -278,12 +317,38 @@ class PDFParsingWorkflow:
         # 所以这里总是返回"complete"
         return "complete"
     
    def create_ragflow_chunk(self, state: PDFParsingState):
        """Upload each parsed page as an individual RAGFlow chunk.

        NOTE(review): this method cannot run as written — several names it
        uses are never defined; see the inline BUG notes.
        """
        print(f"开始单页上传,共 {len(state.parsed_results)} 页")
        
        # Walk every parsed page result and push it as a chunk.
        for parsed_result in state.parsed_results:
            page_number = parsed_result.get("page_number")
            text = parsed_result.get("content", "")
            # NOTE(review): `image` is extracted but never used below.
            image = state.split_pages[page_number - 1].get("image")
            
            # Add the page text as a chunk of the page-level document.
            # NOTE(review): PDFParsingState does not declare `dataset_name`
            # or `page_document_id` fields in this change — confirm they exist.
            chunk = state.chunk_util.add_chunk(
                dataset_name=state.dataset_name,
                document_id=state.page_document_id,
                content=text,
            )

            infinity_client = get_client()
            # BUG(review): `chunk_id` and `tag_name` are undefined — this line
            # raises NameError. chunk_id presumably should come from the
            # `chunk` just created; tag_name's source is unclear. Also
            # table_name="" looks like an unfinished placeholder.
            infinity_client.update(database_name=state.dataset_name, table_name="", cond=f"id = {chunk_id}", data={"tag_kwd": tag_name})
            
            # BUG(review): `document_info` is undefined (NameError); this
            # check probably meant to inspect `chunk` instead.
            if document_info and document_info.get("id"):
                print(f"第 {page_number} 页上传成功,文档ID: {document_info['id']}")
            else:
                print(f"第 {page_number} 页上传失败")
+
     def _vectorize_store_node(self, state: PDFParsingState) -> Dict[str, Any]:
         """向量化入库节点"""
         print(f"开始向量化入库,共 {len(state.parsed_results)} 页")
         
         # 创建索引(如果不存在)
-        index_name = f"{VectorDBConfig.get_infinity_table_name()}"
+        index_name = f"{vector_db_settings.infinity_table_name}"
         state.vector_db.create_index(index_name)
         
         # 准备要入库的文档列表
@@ -349,13 +414,13 @@ class PDFParsingWorkflow:
             "is_complete": True
         }
     
-    def run(self, pdf_path: str, dataset_id: str, ragflow_api_url: str, rag_flow_api_key: str) -> Dict[str, Any]:
+    def run(self, pdf_path: str, page_dataset_id: str, ragflow_api_url: str, rag_flow_api_key: str) -> Dict[str, Any]:
         """
         运行PDF解析工作流
         
         Args:
             pdf_path: PDF文件路径
-            dataset_id: 数据集ID
+            page_dataset_id: 数据集ID
             ragflow_api_url: RAGFLOW API URL
             rag_flow_api_key: RAGFLOW API密钥
             
@@ -364,9 +429,12 @@ class PDFParsingWorkflow:
         """
         initial_state = PDFParsingState(
             pdf_path=pdf_path,
-            dataset_id=dataset_id,
-            embedding_model=Embedding(model_name=ModelConfig.get_multimodal_embedding_model_name(), api_key=ModelConfig.get_dashscope_api_key()),
-            ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
+            page_dataset_id=page_dataset_id,
+            embedding_model=Embedding(model_name=model_settings.multimodal_embedding_model_name, api_key=model_settings.dashscope_api_key),
+            dataset_util=DataSetUtil(),
+            document_util=DocumentUtil(),
+            chunk_util=ChunkUtil(),
+            ragflow_service=RAGFlowService(api_url=ragflow_api_url, api_key=rag_flow_api_key)
         )
         result = self.workflow.invoke(initial_state)
         

+ 27 - 0
private_key.pem

@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAk737JTysaX2ZnsMtvU8uOB8sSTqoL4/z7hcO01IwYDFTZPUO
+UeBcTbwsIxoyh3IQReW/SL9kay3ZeZzg3xRZYaiKkMVtbnu1ijt3/t4/K+rmIBs/
+8Lg9vpJGVyY43rPlaJP0IlrttnGAOiurDSnSM3iJN3BiZsdBz5fMgEShp6PQpWpJ
+ggOjK+2wuMlI1GP2nrJs/9/FN5yIcoqbJtxOkXiHTRLvn/6V0iwO0KR8n5WWUseA
+ziVtIEDTZxIGkYyfK1DrRP3WZYmXS3gDbV5wqtNDLhnrbYvarZA6ntC1aAa2S8YA
+Ey3dgGee66547mF8nML7qWzxcVrLpRGYdcz0wQIDAQABAoIBACknHyC1MNVFjhfi
+CAiWWKp7TGutchGliG3GcHO+jgkNVExn0jBOTYe9pZG1h2OhCxVdNJX/8enE4LH2
+V9DrqyX/u3Et/RhsTxoNDJSqtW/3Q1ioHnxf5ZsCxaMa23Z1J66kn7k+GL7cCLqr
+VT8aahm6NoaBwHceg59eijEkbN7z3vEdpcoM7Y/T9vsIiRxof7XbIkoMAR14ZZsF
+51ht/Uw6lcY+ZBil4m+PSGkTF00WrTYYySVpLNeOg7gHOZhCvkegu5ZjAgJf4468
++eG6OW57UP+Fvx1cy1YY+2rO1FiNclmcNKEj6m8kvbcsGDI84NzXx2Z5UNMehTZm
+TsaJBG0CgYEAzgpzsPK/NyZmLwmoty9h+7cIGsVXGDbJBSbHDKLbt1B/QEtUGvVZ
+MdilQBWmYPw74F+dyl/pOiaXppEwwCej/5aIQVc2LruekOj2nlC6aAE+vI1fFJwn
+MPWqHgQKv3Q0oC/0jOA4WfDAuPMAlWgCmPOHX+dpn60Ehd2Gg9jzdGsCgYEAt5DC
+hx5ag7X3YAqm16o63ztDDLmyA9SWuD41wmXqxiVANV4aLOb5xOKhTGePN+vL4uYF
+3XD+SBWeaksIlp8jKgQ6lo+E7ODvFCWp7anUE4joqRYDLX76YfYaV55ftn2QG13g
+FlSJrcHW/LDW2K4VljZ/RMe9EjvwZpwXTLHxpoMCgYBljPc6dlpSZ0Wj12Hlb6Ae
+fTasYykrEQ2A0Xz3uGkMb8tYOAhrBRQCImJINKN9V7R3QQk3TpBYDd9exF6/xonp
+FhtwAPj+jZB4oamzm9xUGJEAQ7liIhia4+oFwz484Jxze09QA88iXSvPvE1EVjA5
+jIXjed4wSglZpTaAU3SiowKBgQCq6eH1pNFqKRc3GTPXQeymxKrdQlj5MAtVaLbt
+Va+O5kguaHu+qyCLrOe6GAJFcwrXLIzxba+UtmxVU5+ffnRjPqS+YsrFM3nnLFpX
+qvSHa8YVCSXUSpIee+JzCkbbx3VRxcpLgFLgscbWaGDiRtRalEtTYxYr0Bs5rsQS
+6OuyuQKBgBtf5/D46K+gnzEHKxvrRNSWvzkWa9t89v/N6ndOhfTPanWfvwCeqEzQ
+nnWfYaQkB3vG54Rzv0pW7/UJnPxRLg6OuCaYBjYYXR7rzyBXuHotAH7MNf8nLiC1
+0ymLnAopyhJK0K17jzPh4vIMNFcjYMnVg7/NB825Ufj88VRow5TE
+-----END RSA PRIVATE KEY-----

+ 0 - 0
prompt/parser/page_parse_prompt.py


+ 9 - 0
public_key.pem

@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAk737JTysaX2ZnsMtvU8u
+OB8sSTqoL4/z7hcO01IwYDFTZPUOUeBcTbwsIxoyh3IQReW/SL9kay3ZeZzg3xRZ
+YaiKkMVtbnu1ijt3/t4/K+rmIBs/8Lg9vpJGVyY43rPlaJP0IlrttnGAOiurDSnS
+M3iJN3BiZsdBz5fMgEShp6PQpWpJggOjK+2wuMlI1GP2nrJs/9/FN5yIcoqbJtxO
+kXiHTRLvn/6V0iwO0KR8n5WWUseAziVtIEDTZxIGkYyfK1DrRP3WZYmXS3gDbV5w
+qtNDLhnrbYvarZA6ntC1aAa2S8YAEy3dgGee66547mF8nML7qWzxcVrLpRGYdcz0
+wQIDAQAB
+-----END PUBLIC KEY-----

+ 82 - 0
test/test_asymmetric_encryption.py

@@ -0,0 +1,82 @@
+from utils.asymmetric_encryption import AsymmetricEncryption
+import os
+
+
def test_asymmetric_encryption():
    """End-to-end check of the RSA helper: keygen, encrypt/decrypt, file I/O.

    Keys are written to a temporary directory that is deleted afterwards —
    never into the project root. (The previous version wrote
    ``private_key.pem``/``public_key.pem`` into the working directory and
    skipped cleanup, which is how a real private key ended up committed to
    the repository; rotate and remove those files.)
    """
    import tempfile  # local import: only this test needs it

    print("开始测试非对称加密功能...")

    # 1. Key-pair generation.
    print("\n1. 生成密钥对测试...")
    private_pem, public_pem = AsymmetricEncryption.generate_key_pair()
    print(f"私钥长度: {len(private_pem)} 字节")
    print(f"公钥长度: {len(public_pem)} 字节")
    print(f"私钥前100字节: {private_pem[:100]!r}")
    print(f"公钥前100字节: {public_pem[:100]!r}")

    # 2. Encrypt/decrypt round trip.
    print("\n2. 加密解密测试...")
    test_message = "这是一个测试消息,用于测试非对称加密功能!"
    print(f"原始消息: {test_message}")
    encrypted = AsymmetricEncryption.encrypt(test_message, public_pem)
    print(f"加密后: {encrypted}")
    decrypted = AsymmetricEncryption.decrypt(encrypted, private_pem)
    print(f"解密后: {decrypted}")
    assert decrypted == test_message, "解密失败,结果与原始消息不符!"
    print("✓ 加密解密测试通过!")

    # 3. Save/load keys via the filesystem, inside a throwaway directory.
    print("\n3. 密钥文件保存和加载测试...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        private_key_path = os.path.join(tmp_dir, "private_key.pem")
        public_key_path = os.path.join(tmp_dir, "public_key.pem")

        AsymmetricEncryption.save_key_to_file(private_pem, private_key_path)
        AsymmetricEncryption.save_key_to_file(public_pem, public_key_path)
        print(f"✓ 密钥已保存到文件: {private_key_path}, {public_key_path}")

        # Reload and verify the persisted keys still round-trip correctly.
        loaded_private_pem = AsymmetricEncryption.load_key_from_file(private_key_path)
        loaded_public_pem = AsymmetricEncryption.load_key_from_file(public_key_path)
        round_trip = AsymmetricEncryption.decrypt(
            AsymmetricEncryption.encrypt(test_message, loaded_public_pem),
            loaded_private_pem,
        )
        assert round_trip == test_message, "使用加载的密钥解密失败!"
        print("✓ 使用加载的密钥加密解密测试通过!")

    print("\n所有测试通过!非对称加密功能正常工作。")


if __name__ == "__main__":
    test_asymmetric_encryption()

+ 39 - 0
test/test_excel_util.py

@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+"""
+测试 ExcelUtil 类的功能
+"""
+
+from utils.excel_util import ExcelUtil
+
def test_excel_util():
    """Smoke-test the ExcelUtil helpers.

    The calls that need a real workbook on disk are left commented out, so
    this currently only reports that the methods exist.
    """
    print("测试 ExcelUtil 类的功能...")

    try:
        # Step 1: sheet-name listing (requires a real workbook path to run).
        print("\n1. 测试获取工作表名称方法...")
        # sheet_names = ExcelUtil.get_sheet_names("test.xlsx")
        # print(f"工作表名称: {sheet_names}")
        print("✅ 工作表名称方法已实现")

        # Step 2: full-sheet parsing (requires a real workbook path to run).
        print("\n2. 测试解析 Excel 方法...")
        # data = ExcelUtil.parse_excel("test.xlsx")
        # print(f"解析结果数量: {len(data)}")
        # if data:
        #     print(f"第一条数据: {data[0]}")
        print("✅ 解析 Excel 方法已实现")
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n🎉 所有测试通过!ExcelUtil 类功能正常。")
    return True


if __name__ == "__main__":
    test_excel_util()

+ 70 - 0
test/test_mysql_config.py

@@ -0,0 +1,70 @@
+"""测试MySQL配置从.env文件读取"""
+
+from utils.mysql.mysql_pool import get_mysql_pool
+from utils.mysql.mysql_conn import get_mysql_conn
+
+
def test_mysql_pool_from_env():
    """Print the pool's resolved settings to verify they come from .env.

    NOTE(review): prints only — there are no assertions, so this cannot
    fail even when the .env values are wrong.
    """
    print("=== 测试从.env文件读取MySQL连接池配置 ===")
    
    # No-arg construction: every value printed below must come from .env.
    pool = get_mysql_pool()
    print(f"连接池配置 - 主机: {pool.host}")
    print(f"连接池配置 - 端口: {pool.port}")
    print(f"连接池配置 - 用户名: {pool.user}")
    print(f"连接池配置 - 数据库: {pool.database}")
    print(f"连接池配置 - 字符集: {pool.charset}")
    print(f"连接池配置 - 连接池大小: {pool.pool_size}")
    
    print("\n✓ 从.env文件读取MySQL连接池配置成功!")
+
+
def test_mysql_conn_from_env():
    """Open a connection with .env-derived settings and run ``SELECT 1``."""
    print("\n=== 测试从.env文件读取MySQL连接配置 ===")
    
    # No-arg construction: configuration should be loaded from .env.
    conn = get_mysql_conn()
    
    # Run a trivial query to prove the connection actually works.
    try:
        # Cursor is a context manager; it is released on exit.
        with conn.get_cursor() as cursor:
            cursor.execute("SELECT 1 AS test")
            result = cursor.fetchone()
            print(f"执行简单查询结果: {result}")
        print("\n✓ 从.env文件读取MySQL连接配置成功!")
    except Exception as e:
        # NOTE(review): failures are only printed, never raised — the script
        # still "passes" when the database is unreachable.
        print(f"\n✗ 执行查询失败: {e}")
    finally:
        conn.close()
+
+
def test_mysql_conn_with_custom_params():
    """Override one setting (``database``) and verify the rest still resolve.

    Only ``database`` is passed explicitly; host/port/user/password are
    still expected to come from .env.
    """
    print("\n=== 测试自定义参数覆盖.env配置 ===")
    
    # Explicit database name overrides whatever .env specifies.
    conn = get_mysql_conn(database="test")
    
    try:
        # Cursor is a context manager; it is released on exit.
        with conn.get_cursor() as cursor:
            cursor.execute("SELECT 1 AS test")
            result = cursor.fetchone()
            print(f"执行简单查询结果: {result}")
        print("\n✓ 自定义参数覆盖.env配置成功!")
    except Exception as e:
        # NOTE(review): failures are only printed, not raised.
        print(f"\n✗ 执行查询失败: {e}")
    finally:
        conn.close()


if __name__ == "__main__":
    # Run the three checks in order.
    test_mysql_pool_from_env()
    test_mysql_conn_from_env()
    test_mysql_conn_with_custom_params()
    print("\n🎉 所有测试完成!")

+ 95 - 0
test/test_mysql_conn.py

@@ -0,0 +1,95 @@
+from utils.mysql_conn import get_mysql_conn
+
+
def test_mysql_connection():
    """Exercise the MySQL connection manager end to end: DDL, single and
    bulk inserts, the three fetch helpers, transactions, then teardown.

    NOTE(review): this module imports from ``utils.mysql_conn`` while
    test_mysql_config.py imports ``utils.mysql.mysql_conn`` — one of the two
    paths is stale; confirm which package layout is current.
    """
    print("开始测试MySQL连接池...")
    
    # Build a connection manager with explicit local credentials.
    conn_manager = get_mysql_conn(
        host="localhost",
        port=3306,
        user="root",
        password="password",
        database="test"
    )
    
    print("✓ MySQL连接管理器实例创建成功")
    
    # Exercise the execute / fetch / bulk / transaction helpers.
    try:
        # Create a scratch table.
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS test_users (
            id INT AUTO_INCREMENT PRIMARY KEY,
            name VARCHAR(50) NOT NULL,
            email VARCHAR(100) NOT NULL,
            create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
        rowcount = conn_manager.execute(create_table_sql)
        print(f"✓ 创建测试表成功,受影响行数: {rowcount}")
        
        # Insert one row with parameterized values.
        insert_sql = "INSERT INTO test_users (name, email) VALUES (%s, %s)"
        params = ("测试用户", "test@example.com")
        rowcount = conn_manager.execute(insert_sql, params)
        print(f"✓ 插入测试数据成功,受影响行数: {rowcount}")
        
        # fetch_one: latest row.
        select_one_sql = "SELECT * FROM test_users ORDER BY id DESC LIMIT 1"
        result = conn_manager.fetch_one(select_one_sql)
        print(f"✓ 查询单条数据成功,结果: {result}")
        
        # fetch_all: every row.
        select_all_sql = "SELECT * FROM test_users"
        results = conn_manager.fetch_all(select_all_sql)
        print(f"✓ 查询所有数据成功,结果数量: {len(results)}")
        
        # fetch_many: bounded batch.
        select_many_sql = "SELECT * FROM test_users"
        results = conn_manager.fetch_many(select_many_sql, size=2)
        print(f"✓ 查询多条数据成功,结果数量: {len(results)}")
        
        # bulk_insert: several rows in one call.
        bulk_insert_sql = "INSERT INTO test_users (name, email) VALUES (%s, %s)"
        bulk_params = [
            ("批量用户1", "batch1@example.com"),
            ("批量用户2", "batch2@example.com"),
            ("批量用户3", "batch3@example.com")
        ]
        rowcount = conn_manager.bulk_insert(bulk_insert_sql, bulk_params)
        print(f"✓ 批量插入数据成功,受影响行数: {rowcount}")
        
        # Transaction commit / rollback.
        try:
            conn, cursor = conn_manager.begin_transaction()
            cursor.execute("INSERT INTO test_users (name, email) VALUES (%s, %s)", ("事务用户", "transaction@example.com"))
            conn_manager.commit_transaction(conn, cursor)
            print("✓ 事务提交成功")
        except Exception as e:
            # BUG(review): if begin_transaction() itself raises, `conn` and
            # `cursor` are unbound here and this rollback raises NameError.
            conn_manager.rollback_transaction(conn, cursor)
            print(f"✗ 事务回滚: {e}")
        
        # Remove the scratch rows.
        delete_sql = "DELETE FROM test_users"
        rowcount = conn_manager.execute(delete_sql)
        print(f"✓ 清理测试数据成功,受影响行数: {rowcount}")
        
        # Drop the scratch table.
        drop_table_sql = "DROP TABLE IF EXISTS test_users"
        rowcount = conn_manager.execute(drop_table_sql)
        print(f"✓ 删除测试表成功,受影响行数: {rowcount}")
        
        print("\n🎉 所有测试通过!MySQL连接池工作正常。")
        
    except Exception as e:
        # NOTE(review): failures only print — the script never exits non-zero.
        print(f"\n✗ 测试失败: {e}")
    
    # Shut the pool down regardless of outcome.
    conn_manager.close()
    print("✓ 连接池关闭成功")


if __name__ == "__main__":
    test_mysql_connection()

+ 14 - 0
test/test_ragflow_sdk.py

@@ -0,0 +1,14 @@
+from api.db.services.tag_service import TagService
+
+if __name__ == "__main__":
+    tag_service = TagService()
+    tags_data = tag_service.get_tag_data()
+    print(tags_data)
+    
+    # multi_embeddings = Embedding(
+    #     model_name=ModelConfig.get_multimodal_embedding_model_name(),
+    #     api_key=ModelConfig.get_dashscope_api_key()
+    # )
+
+    # res = multi_embeddings.get_multimodal_embedding("这是一个测试文本", Image.open(r"C:\Users\EDY\Desktop\粉宝.png"))
+    # print(res)

+ 85 - 0
test/test_search_tag.py

@@ -0,0 +1,85 @@
+from langchain.chat_models import init_chat_model
+from model.multimodal_embedding import Embedding
+from conf.settings import model_settings, ragflow_settings, tag_search_settings
+from utils.infinity import get_client
+from model.jina_rerank import JinaRerank
+from langchain_core.documents import Document
+from utils.infinity.result_util import convert_to_langchain_docs
+
# Multimodal embedding model used to vectorise the query text.
embedding_model = Embedding(
    model_name=model_settings.embedding_model_name,
    api_key=model_settings.api_key
)

# Chat model (not used by main() below; kept for interactive experiments).
llm = init_chat_model(
    model_provider=model_settings.model_provider,
    model=model_settings.model_name,
    base_url=model_settings.base_url,
    api_key=model_settings.api_key,
    temperature=0.3
)

# Reranker backed by the same endpoint/key as the chat model.
compressor = JinaRerank(
    base_url=model_settings.base_url,
    jina_api_key=model_settings.api_key,
    model=model_settings.rank_model_name, # TODO: confirm this model ID exists on the provider's console
    top_n=5
)

# Columns to pull back from Infinity for each hit.
output_fields = [
                    "docnm",
                    "tag_kwd",
                    "content",
                    "kb_id",
                    "doc_id"
                ]

# Hybrid-search request: text match on `content` plus vector match on
# q_1024_vec; query_vector is filled in at runtime by main().
search_query = {
    "matching_text": """
    桑尼最初因专注画画,对朋友求助表现出不耐烦并大声吼叫,后来意识到是小米生日,主动平复情绪转变态度帮助朋友,完成情绪调节
    """,
    "query_vector": [],
    "vector_field": "q_1024_vec",
    "match_field": "content",
    "topn": 10
}
+
def main():
    """Run one tag search: embed the query, vector-search Infinity, rerank."""
    infinity_client = get_client(database=tag_search_settings.tag_db_name)
    # Vectorise the query text and inject it into the request.
    embedding = embedding_model.get_text_embedding(search_query["matching_text"])
    search_query["query_vector"] = embedding
    # NOTE(review): hard-coded tenant/KB table name — environment specific.
    # TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_52275b36f03611f0a5340242c0a85002"
    TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002"
    results = infinity_client.vector_search(TAG_TABLE_NAME, output_fields, search_query).to_result()
    candidate_docs = convert_to_langchain_docs(results)
    
    # Rerank the candidates and keep the top 5.
    reranked_docs = compressor.compress_documents(
        documents=candidate_docs, 
        query=search_query["matching_text"],
        top_n=5
    )
    print(reranked_docs)
    # Dump the ranked results.
    for i, doc in enumerate(reranked_docs):
        print(f"排名 {i+1}: 分数 {doc.metadata['relevance_score']:.4f}")
        print(f"内容: {doc.page_content}")
        print(f"标签: {doc.metadata['tag_kwd']}\n")
+
+from utils.asymmetric_encryption import AsymmetricEncryption
+
def main2():
    """Decrypt the RSA-encrypted RAGFlow username from settings and print it.

    NOTE(review): the private-key path is an absolute Windows path — this
    only runs on the original author's machine; parameterise it.
    """
    # passwd = "zhangqi@lelequ.net"
    # loaded_public_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\public_key.pem")
    loaded_private_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\private_key.pem")
    # encrypted = AsymmetricEncryption.encrypt(passwd, loaded_public_pem)
    # print(encrypted)
    decrypted = AsymmetricEncryption.decrypt(ragflow_settings.ragflow_user_name, loaded_private_pem)
    print(decrypted)

if __name__ == "__main__":
    main2()

+ 25 - 0
test_upload_tag.py

@@ -0,0 +1,25 @@
+from api.db.services.tag_service import TagServiceFactory
+from utils.excel_util import excel_util
+from utils.infinity import get_client
+
def main():
    """Parse a tag spreadsheet and bulk-upload the tags via the tag service.

    NOTE(review): the Excel path is an absolute local Windows path —
    parameterise it before running anywhere else.
    """
    infinity_client = get_client()
    # Build the tag service through its factory.
    tag_service = TagServiceFactory.create_tag_service(infinity_client)
    # Parse the workbook with the shared excel_util singleton.
    tags_data = excel_util.parse_excel(file_path=r"D:\project\work\data_deal\测试标签.xlsx")
    # Bulk-upload the parsed tags.
    result = tag_service.upload_tags(tags_data)
    # Report the outcome.
    print(f"上传结果: {result}")

# def get_chunk():
#     infinity_client = get_client()
#     table_object = infinity_client.get_table("ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002", "default_db")
#     res = table_object.output(["*"]).to_pl()
#     print(res)



if __name__ == "__main__":
    main()

+ 151 - 0
utils/asymmetric_encryption.py

@@ -0,0 +1,151 @@
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import rsa, padding
+from cryptography.hazmat.backends import default_backend
+from typing import Tuple, Optional
+import base64
+
+
class AsymmetricEncryption:
    """RSA asymmetric-encryption helper (OAEP/SHA-256 padding, PEM key format)."""

    @staticmethod
    def generate_key_pair(key_size: int = 2048) -> Tuple[bytes, bytes]:
        """Generate an RSA key pair.

        Args:
            key_size: RSA modulus size in bits (default 2048).

        Returns:
            Tuple[bytes, bytes]: (private key PEM, public key PEM).
        """
        key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        # Serialize both halves to PEM; private key is stored unencrypted.
        private_pem = key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        public_pem = key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return private_pem, public_pem

    @staticmethod
    def encrypt(message: str, public_key_pem: bytes) -> str:
        """Encrypt *message* with a PEM public key.

        Args:
            message: Plaintext to encrypt.
            public_key_pem: Public key in PEM format.

        Returns:
            str: Ciphertext as a base64-encoded string.
        """
        public_key = serialization.load_pem_public_key(
            public_key_pem, backend=default_backend()
        )
        oaep = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )
        ciphertext = public_key.encrypt(message.encode('utf-8'), oaep)
        return base64.b64encode(ciphertext).decode('utf-8')

    @staticmethod
    def decrypt(encrypted_message: str, private_key_pem: bytes) -> str:
        """Decrypt a base64-encoded ciphertext with a PEM private key.

        Args:
            encrypted_message: Base64-encoded RSA-OAEP ciphertext.
            private_key_pem: Private key in PEM format.

        Returns:
            str: Decrypted plaintext.
        """
        private_key = serialization.load_pem_private_key(
            private_key_pem, password=None, backend=default_backend()
        )
        oaep = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )
        plaintext = private_key.decrypt(base64.b64decode(encrypted_message), oaep)
        return plaintext.decode('utf-8')

    @staticmethod
    def save_key_to_file(key_pem: bytes, file_path: str) -> None:
        """Write a PEM-encoded key to *file_path*.

        Args:
            key_pem: Key in PEM format.
            file_path: Destination file path.
        """
        with open(file_path, 'wb') as fh:
            fh.write(key_pem)

    @staticmethod
    def load_key_from_file(file_path: str) -> bytes:
        """Read a PEM-encoded key from *file_path*.

        Args:
            file_path: Source file path.

        Returns:
            bytes: Key in PEM format.
        """
        with open(file_path, 'rb') as fh:
            return fh.read()

    @staticmethod
    def generate_key_pair_and_save(private_key_path: str, public_key_path: str, key_size: int = 2048) -> None:
        """Generate a key pair and persist both halves to disk.

        Args:
            private_key_path: Where to write the private key.
            public_key_path: Where to write the public key.
            key_size: RSA modulus size in bits (default 2048).
        """
        private_pem, public_pem = AsymmetricEncryption.generate_key_pair(key_size)
        AsymmetricEncryption.save_key_to_file(private_pem, private_key_path)
        AsymmetricEncryption.save_key_to_file(public_pem, public_key_path)

+ 45 - 0
utils/decorators/langfuse_trace_embedding.py

@@ -0,0 +1,45 @@
+import functools
+import time
+from langfuse import Langfuse
+
+# 初始化 Langfuse 客户端
+langfuse = Langfuse()
+
def langfuse_trace_embedding(name="embedding-operation"):
    """Decorator that records an embedding call as a Langfuse span.

    Captures the input text(s) and the output shape (vector count and
    dimension) and reports failures with an ERROR-level span.

    NOTE(review): assumes the wrapped callable is a bound method, so the text
    payload is args[1] (after self) or the `text`/`texts` kwarg — confirm if
    applied to free functions.

    Args:
        name: Span name shown in Langfuse.

    Returns:
        The decorator wrapping the embedding function.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # args[0] is `self` for methods; the payload follows it.
            input_data = args[1] if len(args) > 1 else kwargs.get("text") or kwargs.get("texts")

            # Open a span; it joins the ambient trace or starts a new one.
            span = langfuse.span(
                name=name,
                input=input_data,
                metadata={"function": func.__name__},
            )

            try:
                result = func(*args, **kwargs)

                # Derive output shape without assuming list-of-lists:
                # embed_documents -> list[list[float]] (N vectors),
                # embed_query -> list[float] (one vector). The previous
                # version called len(result[0]) on a float and crashed here.
                if isinstance(result, list) and result and isinstance(result[0], (list, tuple)):
                    vector_count, dimension = len(result), len(result[0])
                elif isinstance(result, list):
                    vector_count, dimension = 1, len(result)
                else:
                    vector_count, dimension = 1, 0

                span.end(output={"vector_count": vector_count, "dimension": dimension})
                return result
            except Exception as e:
                # Record the failure on the span before propagating.
                span.end(level="ERROR", status_message=str(e))
                raise
        return wrapper
    return decorator

+ 0 - 0
utils/decorators.py → utils/decorators/singleton.py


+ 221 - 0
utils/excel_util.py

@@ -0,0 +1,221 @@
+"""
+Excel 工具类
+
+该文件提供 Excel 解析功能,支持:
+- 解析 .xlsx 和 .xls 格式的 Excel 文件
+- 将 Excel 数据转换为 List[Dict[str, Any]] 格式
+- 支持指定工作表
+- 支持自定义表头行
+"""
+
+from typing import List, Dict, Any, Optional
+import os
+
+
+try:
+    from openpyxl import load_workbook
+    from openpyxl.worksheet.worksheet import Worksheet
+except ImportError:
+    raise ImportError("openpyxl 库未安装,请使用 pip install openpyxl 安装")
+
+try:
+    import xlrd
+    from xlrd.sheet import Sheet
+except ImportError:
+    raise ImportError("xlrd 库未安装,请使用 pip install xlrd 安装")
+
+
class ExcelUtil:
    """Excel parsing utility supporting .xlsx (openpyxl) and .xls (xlrd).

    Converts sheet rows into List[Dict[str, Any]] keyed by a chosen header row.
    """

    @staticmethod
    def _rows_to_dicts(rows: List[Any], header_row: int) -> List[Dict[str, Any]]:
        """Convert raw sheet rows into dicts keyed by the header row.

        Blank rows (every cell None or '') are skipped; rows shorter than the
        header are padded with None. This unifies the previously divergent
        blank-row rules of the .xlsx and .xls parsers.

        Args:
            rows: All rows of the sheet, each a sequence of cell values.
            header_row: Index of the header row within *rows*.

        Returns:
            List[Dict[str, Any]]: One dict per non-blank data row.

        Raises:
            ValueError: If *header_row* is outside *rows*.
        """
        if not rows:
            return []
        if header_row >= len(rows):
            raise ValueError(f"表头行索引 {header_row} 超出数据总行数 {len(rows)}")

        headers = rows[header_row]
        # NOTE(review): a header row of all-None cells is truthy and passes
        # this check, yielding None keys — confirm source sheets always have
        # real headers.
        if not headers:
            return []

        data: List[Dict[str, Any]] = []
        for row in rows[header_row + 1:]:
            # Skip fully blank rows ('' appears in .xls, None in .xlsx).
            if not row or all(cell is None or cell == '' for cell in row):
                continue
            data.append({
                header: (row[i] if i < len(row) else None)
                for i, header in enumerate(headers)
            })
        return data

    @staticmethod
    def _parse_xlsx(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """Parse a .xlsx workbook into row dicts.

        Args:
            file_path: Path to the workbook.
            sheet_name: Sheet to read; None means the active sheet.
            header_row: Header row index (default 0).

        Returns:
            List[Dict[str, Any]]: Parsed rows.
        """
        workbook = load_workbook(filename=file_path, read_only=True)
        try:
            sheet = workbook[sheet_name] if sheet_name else workbook.active
            rows = list(sheet.iter_rows(values_only=True))
            return ExcelUtil._rows_to_dicts(rows, header_row)
        finally:
            # read_only mode keeps the file handle open; close on all paths
            # (previously leaked when header_row was invalid and raised).
            workbook.close()

    @staticmethod
    def _parse_xls(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """Parse a legacy .xls workbook into row dicts.

        Args:
            file_path: Path to the workbook.
            sheet_name: Sheet to read; None means the first sheet.
            header_row: Header row index (default 0).

        Returns:
            List[Dict[str, Any]]: Parsed rows.
        """
        workbook = xlrd.open_workbook(file_path)
        sheet = workbook.sheet_by_name(sheet_name) if sheet_name else workbook.sheet_by_index(0)
        rows = [sheet.row_values(i) for i in range(sheet.nrows)]
        return ExcelUtil._rows_to_dicts(rows, header_row)

    @staticmethod
    def parse_excel(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """Parse an Excel file (.xlsx or .xls) into a list of row dicts.

        Args:
            file_path: Path to the Excel file.
            sheet_name: Sheet to read; None means the first/active sheet.
            header_row: Header row index (default 0).

        Returns:
            List[Dict[str, Any]]: Parsed rows.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            ValueError: For unsupported file extensions.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件 {file_path} 不存在")

        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.xlsx':
            return ExcelUtil._parse_xlsx(file_path, sheet_name, header_row)
        if file_ext == '.xls':
            return ExcelUtil._parse_xls(file_path, sheet_name, header_row)
        raise ValueError(f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")

    @staticmethod
    def get_sheet_names(file_path: str) -> List[str]:
        """Return all sheet names of an Excel file.

        Args:
            file_path: Path to the Excel file.

        Returns:
            List[str]: Sheet names.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            ValueError: For unsupported file extensions.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件 {file_path} 不存在")

        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.xlsx':
            workbook = load_workbook(filename=file_path, read_only=True)
            try:
                return workbook.sheetnames
            finally:
                workbook.close()
        if file_ext == '.xls':
            return xlrd.open_workbook(file_path).sheet_names()
        raise ValueError(f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")


# Module-level singleton instance for convenient shared use.
excel_util = ExcelUtil()

+ 5 - 5
utils/file/minio/minio_util.py

@@ -1,18 +1,18 @@
 from minio import Minio
 from typing import BinaryIO
 from datetime import timedelta
-from conf.config import MinioConfig
+from conf.settings import minio_settings
 from utils.file.file_utils import generate_unique_filename
 
 class MinIOUtil:
     def __init__(self):
         self.client = Minio(
-            endpoint=MinioConfig.get_minio_endpoint(),
-            access_key=MinioConfig.get_minio_access_key(),
-            secret_key=MinioConfig.get_minio_secret_key(),
+            endpoint=minio_settings.minio_endpoint,
+            access_key=minio_settings.minio_access_key,
+            secret_key=minio_settings.minio_secret_key,
             secure=False
         )
-        self.bucket_name = MinioConfig.get_minio_bucket_name()
+        self.bucket_name = minio_settings.minio_bucket_name
         self._ensure_bucket_exists()
 
     def _ensure_bucket_exists(self):

+ 21 - 7
utils/infinity/client.py

@@ -3,7 +3,7 @@ from infinity.common import ConflictType
 from typing import Dict, Any, List, Optional
 import threading
 from contextlib import contextmanager
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 from .pool import InfinityConnectionPool
 
 class InfinityClient:
@@ -19,9 +19,9 @@ class InfinityClient:
     
     def __init__(
         self,
-        host: str = VectorDBConfig.get_infinity_host(),
-        port: str = VectorDBConfig.get_infinity_sdk_port(),
-        database: str = VectorDBConfig.get_infinity_database(),
+        host: str = vector_db_settings.infinity_host,
+        port: str = vector_db_settings.infinity_sdk_port,
+        database: str = vector_db_settings.infinity_database,
         min_connections: int = 2,
         max_connections: int = 10
     ):
@@ -179,6 +179,20 @@ class InfinityClient:
             table = self._get_table(conn, table_name, database_name)
             return table.insert(documents)
     
+    def update(
+        self, 
+        table_name: str, 
+        cond: str, 
+        data: Dict[str, Any], 
+        database_name: Optional[str] = None
+    ):
+        """更新文档"""
+        with self.pool.get_connection() as conn:
+            # 获取Table对象
+            table = self._get_table(conn, table_name, database_name)
+            # 更新文档
+            return table.update(cond, data)
+
     def search(
         self,
         table_name: str,
@@ -238,9 +252,9 @@ _global_client: Optional[InfinityClient] = None
 _client_lock = threading.Lock()
 
 def get_client(
-    host: str = VectorDBConfig.get_infinity_host(),
-    port: str = VectorDBConfig.get_infinity_sdk_port(),
-    database: str = VectorDBConfig.get_infinity_database(),
+    host: str = vector_db_settings.infinity_host,
+    port: str = vector_db_settings.infinity_sdk_port,
+    database: str = vector_db_settings.infinity_database,
     min_connections: int = 2,
     max_connections: int = 10
 ) -> InfinityClient:

+ 4 - 4
utils/infinity/pool.py

@@ -4,7 +4,7 @@ from typing import Dict, Any, List, Optional
 import threading
 import time
 from contextlib import contextmanager
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 
 class InfinityConnectionPool:
     """
@@ -20,9 +20,9 @@ class InfinityConnectionPool:
     
     def __init__(
         self,
-        host: str = VectorDBConfig.get_infinity_host(),
-        port: str = VectorDBConfig.get_infinity_sdk_port(),
-        database: str = VectorDBConfig.get_infinity_database(),
+        host: str = vector_db_settings.infinity_host,
+        port: str = vector_db_settings.infinity_sdk_port,
+        database: str = vector_db_settings.infinity_database,
         min_connections: int = 2,
         max_connections: int = 10,
         connection_timeout: int = 30,

+ 78 - 0
utils/infinity/result_util.py

@@ -0,0 +1,78 @@
+from typing import Dict, Any, List
+from langchain_core.documents import Document
+
+
def convert_to_basic_types(obj: Any) -> Any:
    """Recursively convert *obj* into plain Python builtins for serialization.

    Special case: a dict whose values are all lists of the same non-zero
    length is pivoted into a list of per-index records, e.g.
    {"a": [1, 2], "b": [3, 4]} -> [{"a": 1, "b": 3}, {"a": 2, "b": 4}]
    (matches Infinity's column-oriented result layout).

    Args:
        obj: Object to convert.

    Returns:
        The converted basic-type object.
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        # Normalize nested values first.
        converted = {k: convert_to_basic_types(v) for k, v in obj.items()}

        # Pivot column-dict -> row-list only when every value is a list
        # sharing one common, non-zero length.
        values = list(converted.values())
        if values and all(isinstance(v, list) for v in values):
            lengths = {len(v) for v in values}
            if len(lengths) == 1 and lengths.pop() > 0:
                keys = list(converted.keys())
                return [
                    {key: converted[key][i] for key in keys}
                    for i in range(len(values[0]))
                ]
        return converted
    if isinstance(obj, (list, tuple)):
        return [convert_to_basic_types(item) for item in obj]

    # Fallback for arbitrary objects: try mapping conversion, else stringify.
    # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
    try:
        return dict(obj)
    except Exception:
        return str(obj)
+
def convert_to_langchain_docs(obj: Any) -> List[Document]:
    """Convert an Infinity search result into LangChain Document objects.

    NOTE(review): assumes convert_to_basic_types yields a list whose first
    element is the row list, and that every row carries the keys
    content/docnm/tag_kwd/kb_id/doc_id — confirm against the search schema.

    Args:
        obj: Raw Infinity result object.

    Returns:
        List[Document]: Documents with content plus doc/kb metadata.
    """
    rows = convert_to_basic_types(obj=obj)
    docs: List[Document] = []
    for row in rows[0]:
        metadata = {
            "docnm": row["docnm"],
            "tag_kwd": row["tag_kwd"],
            "kb_id": row["kb_id"],
            "doc_id": row["doc_id"],
        }
        docs.append(Document(page_content=row["content"], metadata=metadata))
    return docs

+ 52 - 0
utils/mysql/__init__.py

@@ -0,0 +1,52 @@
+from .mysql_conn import MySQLConnection
+
+
# Process-wide MySQL client singleton.
_global_mysql_client = None

def init_global_mysql_client(host: str = None, port: int = None, 
                           user: str = None, password: str = None, 
                           database: str = None, charset: str = None,
                           pool_size: int = None, **kwargs) -> None:
    """Create the global MySQL client if it does not exist yet.

    Subsequent calls are no-ops once the client is initialized.

    Args:
        host: MySQL host address.
        port: MySQL port.
        user: MySQL user name.
        password: MySQL password.
        database: Database name.
        charset: Character set.
        pool_size: Connection pool size.
        **kwargs: Extra MySQL connection options forwarded to MySQLConnection.
    """
    global _global_mysql_client
    if _global_mysql_client is None:
        _global_mysql_client = MySQLConnection(
            host=host, port=port, user=user, password=password,
            database=database, charset=charset, pool_size=pool_size, **kwargs
        )


def get_global_mysql_client() -> MySQLConnection:
    """Return the global MySQL client.

    Returns:
        The shared MySQLConnection executor.

    Raises:
        RuntimeError: If init_global_mysql_client() was never called.
    """
    if _global_mysql_client is None:
        raise RuntimeError("Global MySQL client has not been initialized. Call init_global_mysql_client() first.")
    return _global_mysql_client


def close_global_mysql_client() -> None:
    """Close and discard the global MySQL client, if any."""
    global _global_mysql_client
    if _global_mysql_client is not None:
        _global_mysql_client.close()
        _global_mysql_client = None

+ 50 - 114
utils/mysql_conn.py → utils/mysql/mysql_conn.py

@@ -1,49 +1,41 @@
 """
-MySQL 连接管理
+MySQL SQL 执行
 
-该文件提供 MySQL 数据库连接管理功能,支持:
+该文件提供 MySQL 数据库 SQL 执行功能,支持:
 - 单例模式
-- 连接池管理
 - 基本 CRUD 操作
 - 事务支持
 - 连接错误处理
+- 全局客户端管理
 """
 import pymysql
 from pymysql.cursors import DictCursor
 from typing import Any, List, Dict, Optional, Union
 from contextlib import contextmanager
-
-# 单例装饰器
-class singleton:
-    def __init__(self, cls):
-        self.cls = cls
-        self._instance = None
-    
-    def __call__(self, *args, **kwargs):
-        if self._instance is None:
-            self._instance = self.cls(*args, **kwargs)
-        return self._instance
+from .mysql_pool import get_mysql_pool, MySQLPool
+from utils.decorators.singleton import singleton
 
 @singleton
 class MySQLConnection:
     """
-    MySQL 连接管理
+    MySQL SQL 执行器
     支持:
     - 单例模式
-    - 连接池管理
     - 基本 CRUD 操作
     - 事务支持
     - 连接错误处理
     """
     
-    def __init__(self, host: str = "localhost", port: int = 3306, 
+    def __init__(self, mysql_pool: Optional[MySQLPool] = None,
+                 host: str = None, port: int = None, 
                  user: str = None, password: str = None, 
-                 database: str = None, charset: str = "utf8mb4",
-                 pool_size: int = 5, **kwargs):
+                 database: str = None, charset: str = None,
+                 pool_size: int = None, **kwargs):
         """
-        初始化 MySQL 连接池
+        初始化 MySQL SQL 执行器
         
         Args:
+            mysql_pool: 可选的 MySQL 连接池实例,如果提供则使用该实例,否则创建新实例
             host: MySQL 主机地址
             port: MySQL 端口号
             user: MySQL 用户名
@@ -53,76 +45,20 @@ class MySQLConnection:
             pool_size: 连接池大小
             **kwargs: 其他 MySQL 连接参数
         """
-        self.host = host
-        self.port = port
-        self.user = user
-        self.password = password
-        self.database = database
-        self.charset = charset
-        self.pool_size = pool_size
-        self.kwargs = kwargs
-        
-        # 初始化连接池
-        self._connection_pool = []
-        self._init_connection_pool()
-    
-    def _init_connection_pool(self):
-        """
-        初始化连接池
-        """
-        for _ in range(self.pool_size):
-            conn = self._create_connection()
-            if conn:
-                self._connection_pool.append(conn)
-    
-    def _create_connection(self) -> Optional[pymysql.connections.Connection]:
-        """
-        创建新的 MySQL 连接
-        
-        Returns:
-            MySQL 连接对象,失败返回 None
-        """
-        try:
-            conn = pymysql.connect(
-                host=self.host,
-                port=self.port,
-                user=self.user,
-                password=self.password,
-                database=self.database,
-                charset=self.charset,
-                cursorclass=DictCursor,
-                **self.kwargs
-            )
-            return conn
-        except pymysql.Error as e:
-            print(f"创建 MySQL 连接失败: {e}")
-            return None
+        # 如果提供了连接池实例,则使用该实例,否则创建新实例
+        if mysql_pool:
+            self._pool = mysql_pool
+        else:
+            self._pool = get_mysql_pool(host, port, user, password, database, charset, pool_size, **kwargs)
     
-    def _get_connection(self) -> Optional[pymysql.connections.Connection]:
+    def _get_connection(self) -> pymysql.connections.Connection:
         """
         从连接池获取连接
         
         Returns:
-            MySQL 连接对象,失败返回 None
+            MySQL 连接对象
         """
-        if self._connection_pool:
-            return self._connection_pool.pop()
-        else:
-            # 连接池为空,创建新连接
-            return self._create_connection()
-    
-    def _return_connection(self, conn: pymysql.connections.Connection):
-        """
-        将连接返回连接池
-        
-        Args:
-            conn: MySQL 连接对象
-        """
-        if len(self._connection_pool) < self.pool_size:
-            self._connection_pool.append(conn)
-        else:
-            # 连接池已满,关闭连接
-            conn.close()
+        return self._pool.get_connection()
     
     @contextmanager
     def get_cursor(self, cursorclass=DictCursor):
@@ -136,11 +72,9 @@ class MySQLConnection:
             MySQL 游标对象
         """
         conn = self._get_connection()
-        if not conn:
-            raise Exception("无法获取 MySQL 连接")
+        cursor = conn.cursor(cursorclass=cursorclass)
         
         try:
-            cursor = conn.cursor(cursorclass=cursorclass)
             yield cursor
             conn.commit()
         except Exception as e:
@@ -148,7 +82,7 @@ class MySQLConnection:
             raise e
         finally:
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     def execute(self, sql: str, params: Union[List, Dict] = None) -> int:
         """
@@ -234,16 +168,9 @@ class MySQLConnection:
             连接对象和游标对象
         """
         conn = self._get_connection()
-        if not conn:
-            raise Exception("无法获取 MySQL 连接")
-        
-        try:
-            conn.begin()
-            cursor = conn.cursor()
-            return conn, cursor
-        except Exception as e:
-            self._return_connection(conn)
-            raise e
+        conn.begin()
+        cursor = conn.cursor()
+        return conn, cursor
     
     def commit_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
         """
@@ -257,7 +184,7 @@ class MySQLConnection:
             conn.commit()
         finally:
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     def rollback_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
         """
@@ -271,26 +198,22 @@ class MySQLConnection:
             conn.rollback()
         finally:
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     def close(self):
         """
-        关闭连接池中的所有连接
+        关闭 SQL 执行器
         """
-        for conn in self._connection_pool:
-            try:
-                conn.close()
-            except Exception as e:
-                print(f"关闭 MySQL 连接失败: {e}")
-        self._connection_pool.clear()
+        # 关闭连接池
+        self._pool.close()
 
 # 简化的接口函数,便于快速使用
-def get_mysql_conn(host: str = "localhost", port: int = 3306, 
-                   user: str = None, password: str = None, 
-                   database: str = None, charset: str = "utf8mb4",
-                   pool_size: int = 5, **kwargs) -> MySQLConnection:
+def get_mysql_conn(host: str = None, port: int = None, 
+                  user: str = None, password: str = None, 
+                  database: str = None, charset: str = None,
+                  pool_size: int = None, **kwargs) -> MySQLConnection:
     """
-    获取 MySQL 连接管理器实例
+    获取 MySQL SQL 执行器实例
     
     Args:
         host: MySQL 主机地址
@@ -303,6 +226,19 @@ def get_mysql_conn(host: str = "localhost", port: int = 3306,
         **kwargs: 其他 MySQL 连接参数
     
     Returns:
-        MySQL 连接管理器实例
+        MySQL SQL 执行器实例
+    """
+    return MySQLConnection(host=host, port=port, user=user, password=password, 
+                          database=database, charset=charset, pool_size=pool_size, **kwargs)
+
+def get_mysql_conn_with_pool(mysql_pool: MySQLPool) -> MySQLConnection:
+    """
+    使用指定的连接池获取 MySQL SQL 执行器实例
+    
+    Args:
+        mysql_pool: MySQL 连接池实例
+    
+    Returns:
+        MySQL SQL 执行器实例
     """
-    return MySQLConnection(host, port, user, password, database, charset, pool_size, **kwargs)
+    return MySQLConnection(mysql_pool=mysql_pool)

+ 138 - 0
utils/mysql/mysql_pool.py

@@ -0,0 +1,138 @@
+"""
+MySQL 连接池配置
+
+该文件提供 MySQL 数据库连接池配置功能,支持:
+- 单例模式
+- DBUtils 连接池管理
+- 连接错误处理
+- 从.env文件读取配置
+"""
+import pymysql
+from pymysql.cursors import DictCursor
+from dbutils.pooled_db import PooledDB
+from conf.settings import mysql_settings
+
+# 单例装饰器
class singleton:
    """Class decorator caching a single instance; first-call arguments win.

    NOTE(review): duplicates utils.decorators.singleton — consider importing
    that instead. Not thread-safe: two racing first calls may both construct.
    """

    def __init__(self, cls):
        self.cls = cls
        self._instance = None

    def __call__(self, *args, **kwargs):
        # Return the cached instance when present; construct lazily otherwise.
        if self._instance is not None:
            return self._instance
        self._instance = self.cls(*args, **kwargs)
        return self._instance
+
@singleton
class MySQLPool:
    """DBUtils-backed MySQL connection pool (process-wide singleton).

    Configuration precedence: explicit argument > .env-backed mysql_settings.
    """

    def __init__(self, host: str = None, port: int = None, 
                 user: str = None, password: str = None, 
                 database: str = None, charset: str = None,
                 pool_size: int = None, **kwargs):
        """Build the pooled connection factory.

        Args:
            host: MySQL host address.
            port: MySQL port.
            user: MySQL user name.
            password: MySQL password.
            database: Database name.
            charset: Character set.
            pool_size: Maximum simultaneous connections.
            **kwargs: Extra pymysql connect() options.
        """
        # Resolve each setting: explicit argument first, then .env value.
        self.host = host or mysql_settings.mysql_host
        self.port = int(port or mysql_settings.mysql_port)
        self.user = user or mysql_settings.mysql_user
        self.password = password or mysql_settings.mysql_password
        self.database = database or mysql_settings.mysql_database
        self.charset = charset or mysql_settings.mysql_charset
        self.pool_size = int(pool_size or mysql_settings.mysql_pool_size)
        self.kwargs = kwargs

        # DBUtils pool; connections are returned on their .close().
        self._pool = PooledDB(
            creator=pymysql,
            maxconnections=self.pool_size,
            mincached=2,
            maxcached=5,
            maxshared=3,
            blocking=True,   # block instead of raising when the pool is exhausted
            maxusage=None,
            setsession=[],
            ping=0,          # never ping — assumes stable connections; TODO confirm
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            database=self.database,
            charset=self.charset,
            cursorclass=DictCursor,
            **kwargs
        )

    def get_connection(self) -> pymysql.connections.Connection:
        """Check a connection out of the pool.

        Returns:
            A pooled MySQL connection (returned to the pool on close()).
        """
        return self._pool.connection()

    def close(self):
        """No-op: DBUtils manages the pooled connections' lifecycle itself."""
        pass
+
+# 简化的接口函数,便于快速使用
def get_mysql_pool(host: str = None, port: int = None, 
                  user: str = None, password: str = None, 
                  database: str = None, charset: str = None,
                  pool_size: int = None, **kwargs) -> MySQLPool:
    """Return the (singleton) MySQL connection pool.

    Arguments left as None fall back to the .env-backed mysql_settings
    defaults inside MySQLPool.__init__; resolving them here as well was
    redundant and inconsistently skipped the password fallback. Because
    MySQLPool is a singleton, arguments only take effect on the first call.

    Args:
        host: MySQL host address.
        port: MySQL port.
        user: MySQL user name.
        password: MySQL password.
        database: Database name.
        charset: Character set.
        pool_size: Connection pool size.
        **kwargs: Extra pymysql connect() options.

    Returns:
        The shared MySQLPool instance.
    """
    return MySQLPool(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset=charset,
        pool_size=pool_size,
        **kwargs
    )

+ 3 - 1
utils/ragflow/dataset_service.py

@@ -6,7 +6,7 @@ class DatasetService:
     
     def create_dataset(self, name: str, description: str = None, 
                       embedding_model: str = None, permission: str = None,
-                      chunk_method: str = None) -> Dict[str, Any]:
+                      chunk_method: str = None, parser_config: dict = None) -> Dict[str, Any]:
         endpoint = "/api/v1/datasets"
         
         data = {"name": name}
@@ -18,6 +18,8 @@ class DatasetService:
             data["permission"] = permission
         if chunk_method is not None:
             data["chunk_method"] = chunk_method
+        if parser_config is not None:
+            data["parser_config"] = parser_config
         
         response = self.http_client.post(endpoint, json_data=data)
         

+ 3 - 3
utils/ragflow/ragflow_service.py

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 from utils.http_client import HTTPClient
-from conf.config import ModelConfig
+from conf.settings import ragflow_settings
 from utils.ragflow.dataset_service import DatasetService
 from utils.ragflow.document_service import DocumentService
 from utils.ragflow.chunk_service import ChunkService
@@ -78,8 +78,8 @@ class FileInfo:
 
 class RAGFlowService:
     def __init__(self, base_url: str = None, api_key: str = None):
-        base_url = base_url or ModelConfig.get_ragflow_api_url()
-        api_key = api_key or ModelConfig.get_ragflow_api_key()
+        base_url = base_url or ragflow_settings.ragflow_api_url
+        api_key = api_key or ragflow_settings.ragflow_api_key
         self.http_client = HTTPClient(base_url=base_url, api_key=api_key)
         
         self.dataset_service = DatasetService(self.http_client)

+ 9 - 0
utils/ragflow_sdk/__init__.py

@@ -0,0 +1,9 @@
"""Public surface of the RAGFlow SDK helper package."""
from .dataset_util import DataSetUtil
from .document_util import DocumentUtil
from .chunk_util import ChunkUtil

# Explicit export list: these three utility classes are the package API.
__all__ = [
    "DataSetUtil",
    "DocumentUtil",
    "ChunkUtil"
]

+ 17 - 0
utils/ragflow_sdk/base_util.py

@@ -0,0 +1,17 @@
+from ragflow_sdk import RAGFlow
+from conf.settings import ragflow_settings
+
+
class RAGFlowBaseUtil:
    """Common base for RAGFlow helper classes.

    Centralises construction of the shared :class:`RAGFlow` client so that
    subclasses do not each repeat the API-key / base-URL wiring.
    """

    def __init__(self):
        """Create the RAGFlow client from application settings."""
        self.ragflow_client = RAGFlow(
            base_url=ragflow_settings.ragflow_api_url,
            api_key=ragflow_settings.ragflow_api_key,
        )

+ 80 - 0
utils/ragflow_sdk/chunk_util.py

@@ -0,0 +1,80 @@
+from typing import Optional
+from ragflow_sdk import Chunk
+from utils.ragflow_sdk.document_util import DocumentUtil
+from .base_util import RAGFlowBaseUtil
+
+
+
class ChunkUtil(RAGFlowBaseUtil):
    """Helpers for chunk-level operations on RAGFlow documents.

    NOTE(review): ``dataset_id``/``document_id`` stored by ``__init__`` are
    currently unused by the methods (which resolve documents by dataset
    *name*); kept for interface compatibility — confirm intended usage.
    """

    def __init__(self, dataset_id: str, document_id: str):
        super().__init__()
        self.dataset_id = dataset_id
        self.document_id = document_id

    def add_chunk(self, dataset_name: Optional[str] = None, document_id: Optional[str] = None,
                  content: str = None, important_keywords: Optional[list[str]] = None):
        """Append a chunk to a document.

        Args:
            dataset_name: name of the dataset holding the document.
            document_id: id of the target document.
            content: chunk text (required).
            important_keywords: optional keywords attached to the chunk.

        Fix: the default for ``important_keywords`` was a shared mutable
        ``[]``; it is now a per-call ``None`` sentinel.
        """
        if important_keywords is None:
            important_keywords = []
        try:
            if not dataset_name or not document_id or not content:
                raise ValueError("数据集名称、文档ID和内容不能为空")
            chunk = DocumentUtil().get_document(
                dataset_name=dataset_name, document_id=document_id
            ).add_chunk(content=content, important_keywords=important_keywords)
            return chunk
        except Exception as e:
            print(f"添加文档分块失败: {e}")
            raise e

    def delete_chunks(self, chunk_ids: list[str], dataset_name: Optional[str] = None,
                      document_id: Optional[str] = None):
        """Delete the given chunks from a document.

        Fix: the original body referenced the undefined free names
        ``dataset_name`` / ``document_id`` (guaranteed NameError); they are
        now explicit parameters, mirroring :meth:`add_chunk`.
        """
        try:
            if not dataset_name or not document_id:
                raise ValueError("数据集名称和文档ID不能为空")
            DocumentUtil().get_document(
                dataset_name=dataset_name, document_id=document_id
            ).delete_chunks(chunk_ids)
        except Exception as e:
            print(f"删除文档分块失败: {e}")
            raise e

    def list_chunks(self, keywords: str = None, page: int = 1, page_size: int = 30,
                    id: str = None, dataset_name: Optional[str] = None,
                    document_id: Optional[str] = None) -> list[Chunk]:
        """List chunks of a document, optionally filtered by keywords or id.

        Fix: ``dataset_name`` / ``document_id`` were previously undefined
        free names (NameError); they are now parameters.
        """
        try:
            if not dataset_name or not document_id:
                raise ValueError("数据集名称和文档ID不能为空")
            chunks = DocumentUtil().get_document(
                dataset_name=dataset_name, document_id=document_id
            ).list_chunks(keywords=keywords, page=page, page_size=page_size, id=id)
            return chunks
        except Exception as e:
            print(f"列出文档分块失败: {e}")
            raise e

    def get_chunk(self, chunk_id: str, dataset_name: Optional[str] = None,
                  document_id: Optional[str] = None) -> Chunk:
        """Fetch a single chunk by id; raise if it does not exist."""
        _list = self.list_chunks(id=chunk_id, dataset_name=dataset_name,
                                 document_id=document_id)
        if len(_list) > 0:
            return _list[0]
        raise Exception("Chunk %s not found" % chunk_id)

    def update_chunk(self, chunk_id: str, update_message: dict,
                     dataset_name: Optional[str] = None,
                     document_id: Optional[str] = None):
        """Apply ``update_message`` to the chunk identified by ``chunk_id``."""
        try:
            self.get_chunk(chunk_id=chunk_id, dataset_name=dataset_name,
                           document_id=document_id).update(update_message)
        except Exception as e:
            print(f"更新文档分块失败: {e}")
            raise e

    def retrieve(self, question: str, dataset_ids: list[str] = None,
                 document_ids: list[str] = None, page: int = 1, page_size: int = 30,
                 similarity_threshold: float = 0.2, vector_similarity_weight: float = 0.3,
                 top_k: int = 1024, rerank_id: str = None, keyword: bool = False,
                 cross_languages: list[str] = None,
                 metadata_condition: dict = None) -> list[Chunk]:
        """Run a retrieval query against RAGFlow and return matching chunks."""
        try:
            chunks = self.ragflow_client.retrieve(
                question=question, dataset_ids=dataset_ids, document_ids=document_ids,
                page=page, page_size=page_size,
                similarity_threshold=similarity_threshold,
                vector_similarity_weight=vector_similarity_weight,
                top_k=top_k, rerank_id=rerank_id, keyword=keyword,
                cross_languages=cross_languages,
                metadata_condition=metadata_condition,
            )
            return chunks
        except Exception as e:
            print(f"检索文档分块失败: {e}")
            raise e
+        

+ 45 - 0
utils/ragflow_sdk/conf/rag_parser_config.py

@@ -0,0 +1,45 @@
class RagParserDefaults:
    """Default dataset/parser settings used when creating RAGFlow datasets."""

    # Datasets are shared with the whole team by default.
    DATASET_PERMISSION="team"

    # RAGFlow's general-purpose ("naive") chunking method.
    DATASET_CHUNK_METHOD="naive"

    # Parser configuration passed to RAGFlow on dataset creation.
    # NOTE: this is a shared class-level dict — callers should copy it
    # before mutating.
    DATASET_CONFIG_DICT = {
            "chunk_token_num": 256,  # max tokens per chunk
            "delimiter": "\n!?;。;!?",  # sentence delimiters (CJK + ASCII)
            "html4excel": False,
            # vision model used for layout recognition
            "layout_recognize": "Pro/Qwen/Qwen2.5-VL-7B-Instruct@SILICONFLOW",
            "auto_keywords": 5,
            "tag_kb_ids": [],
            "topn_tags": 3,
            "task_page_size": 4,
            # RAPTOR hierarchical-summarization settings
            "raptor": {
                "max_cluster": 64,
                "max_token": 256,
                "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n      {cluster_content}\nThe above is the content you need to summarize.",
                "random_seed": 0,
                "threshold": 0.1,
                "use_raptor": True
            },
            # GraphRAG entity-extraction settings (book-domain entity types)
            "graphrag": {
                "resolution": True,
                "use_graphrag": True,
                "method": "general",
                "entity_types": [
                    "event",
                    "Book",
                    "Author",
                    "Illustrator",
                    "Series",
                    "Theme",
                    "Genre",
                    "Character",
                    "Setting",
                    "AgeGroup",
                    "Competency",
                    "ArtStyle",
                    "Award",
                    "Publisher"
                ]
            }
        }

+ 71 - 0
utils/ragflow_sdk/dataset_util.py

@@ -0,0 +1,71 @@
+from typing import Optional
+from ragflow_sdk import DataSet
+from conf.settings import ragflow_settings, model_settings
+from .conf.rag_parser_config import RagParserDefaults
+from .base_util import RAGFlowBaseUtil
+
+
+
class DataSetUtil(RAGFlowBaseUtil):
    """Dataset-level helpers built on the shared RAGFlow SDK client."""

    def __init__(self):
        super().__init__()

    def create_dataset(self, chunk_method: str = RagParserDefaults.DATASET_CHUNK_METHOD,
        dataset_name: Optional[str] = None, dataset_desc: str = "",
        embedding_model: str = f"{model_settings.embedding_model_name}@SILICONFLOW",
        parser_config: dict = None):
        """Create a dataset with team permission and the default parser config.

        Args:
            chunk_method: RAGFlow chunking method (default: naive).
            dataset_name: name of the new dataset.
            dataset_desc: human-readable description.
            embedding_model: embedding model id (provider-qualified).
            parser_config: parser configuration dict; when omitted, a copy of
                ``RagParserDefaults.DATASET_CONFIG_DICT`` is used.

        Fix: the default used to be the shared class-level dict itself
        (mutable default argument) — any mutation by the SDK or a caller
        would have corrupted the defaults for every later call. A deep copy
        is now taken per call.
        """
        from copy import deepcopy
        if parser_config is None:
            parser_config = deepcopy(RagParserDefaults.DATASET_CONFIG_DICT)
        # Wrap the plain dict in the SDK's ParserConfig object.
        parser_obj = DataSet.ParserConfig(self.ragflow_client, parser_config)
        dataset = self.ragflow_client.create_dataset(
            name=dataset_name,
            description=dataset_desc,
            embedding_model=embedding_model,
            permission=RagParserDefaults.DATASET_PERMISSION,
            chunk_method=chunk_method,
            parser_config=parser_obj
        )
        return dataset

    def delete_dataset(self, ids: list[str] | None = None):
        """Delete the datasets with the given ids (best-effort: errors are
        printed and swallowed, matching the original behaviour)."""
        try:
            if ids is not None:
                self.ragflow_client.delete_datasets(ids)
        except Exception as e:
            print(f"删除数据集失败: {e}")

    def get_dataset(self, name: Optional[str] = None):
        """Return the dataset with the given name, or None on failure."""
        try:
            dataset = self.ragflow_client.get_dataset(name=name)
        except Exception as e:
            print(f"获取数据集失败: {e}")
            return None
        return dataset

    def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
        """Return a page of datasets, or None on failure."""
        try:
            dataset_list = self.ragflow_client.list_datasets(
                page=page,
                page_size=page_size,
                orderby=orderby,
                desc=desc,
                id=id,
                name=name,
            )
        except Exception as e:
            print(f"获取数据集列表失败: {e}")
            return None
        return dataset_list

+ 93 - 0
utils/ragflow_sdk/document_util.py

@@ -0,0 +1,93 @@
+from token import OP
+from typing import Optional
+from ragflow_sdk import DataSet, Document
+from conf.settings import ragflow_settings
+from .base_util import RAGFlowBaseUtil
+from common.models.pagination import Pagination
+
+
+
class DocumentUtil(RAGFlowBaseUtil):
    """Document-level helpers built on the shared RAGFlow SDK client."""

    def __init__(self):
        super().__init__()

    def upload_documents(self, dataset_name: Optional[str] = None, document_list: list[dict] = None):
        """Upload ``document_list`` into the dataset named ``dataset_name``.

        Raises:
            ValueError: when the document list is empty.
        """
        try:
            if not document_list:
                raise ValueError("文档列表不能为空")
            # Resolve the dataset by name, then upload in one call.
            doc_list = self.ragflow_client.get_dataset(name=dataset_name).upload_documents(document_list)
            return doc_list
        except Exception as e:
            print(f"上传文档到数据集失败: {e}")
            raise e

    def list_documents(self, dataset_name: Optional[str] = None, pagination: Pagination = None, id: str = None, keywords: str = None) -> list[Document]:
        """List documents in a dataset, filtered by id and/or keywords.

        Fix: the default used to be a single ``Pagination()`` instance built
        once at import time and shared by every call (mutable default
        argument); a fresh instance is now created per call when omitted.
        """
        if pagination is None:
            pagination = Pagination()
        try:
            doc_list = self.ragflow_client.get_dataset(name=dataset_name).list_documents(id=id, keywords=keywords, **pagination.to_dict())
            return doc_list
        except Exception as e:
            print(f"列出数据集文档失败: {e}")
            raise e

    def get_document(self, dataset_name: Optional[str] = None, document_id: Optional[str] = None) -> Document:
        """Return the single document with ``document_id``; raise if absent."""
        _list = self.list_documents(dataset_name=dataset_name, id=document_id)
        if len(_list) > 0:
            return _list[0]
        raise Exception("Document %s not found" % document_id)

    def update_document(self, document: Document, update_message: dict):
        """Apply ``update_message`` to an already-fetched document."""
        try:
            document.update(update_message)
        except Exception as e:
            print(f"更新文档失败: {e}")
            raise e

    def delete_document(self, dataset_name: Optional[str] = None, ids: list[str] = None):
        """Delete the documents with the given ids from the dataset.

        Raises:
            ValueError: when ``ids`` is empty or None.
        """
        try:
            if not ids:
                raise ValueError("文档ID列表不能为空")
            self.ragflow_client.get_dataset(name=dataset_name).delete_documents(ids=ids)
        except Exception as e:
            print(f"删除文档失败: {e}")
            raise e

    def async_parse_documents(self, dataset_name: Optional[str] = None, document_ids: list[str] = None):
        """Kick off asynchronous parsing for the given documents."""
        try:
            self.ragflow_client.get_dataset(name=dataset_name).async_parse_documents(document_ids=document_ids)
        except Exception as e:
            print(f"异步解析文档失败: {e}")
            raise e

    def parse_documents(self, dataset_name: Optional[str] = None, document_ids: list[str] = None) -> list[tuple[str, str, int, int]]:
        """Parse the given documents synchronously and return the results."""
        try:
            doc_list = self.ragflow_client.get_dataset(name=dataset_name).parse_documents(document_ids=document_ids)
            return doc_list
        except Exception as e:
            print(f"解析文档失败: {e}")
            raise e
+

+ 0 - 0
utils/ragflow_sdk/param.json