Browse Source

feat: 新增RAGFlow SDK工具类、MySQL连接池和标签管理API

refactor: 重构配置管理,从config.py迁移到settings模块

feat(utils): 添加非对称加密工具类和Langfuse追踪装饰器

feat(api): 实现标签管理API及相关服务

feat(model): 扩展多模态嵌入模型功能,集成Langfuse追踪

fix(infinity): 修复Infinity客户端更新操作和结果转换问题

docs: 更新.env文件配置说明和测试用例
yingge 3 months ago
parent
commit
eafbea0aec
50 changed files with 2359 additions and 250 deletions
  1. 25 1
      .env
  2. 9 9
      agent/test_image_agent.py
  3. 3 55
      api/db/services/infinity_search_service.py
  4. 207 0
      api/db/services/tag_service.py
  5. 7 5
      api/hybrid_search_mcp.py
  6. 95 0
      api/tag_manage.py
  7. 11 0
      common/models/pagination.py
  8. 10 0
      conf/age_level.json
  9. 31 0
      conf/config.py
  10. 123 0
      conf/settings.py
  11. 0 0
      ee.json
  12. 18 3
      main.py
  13. 133 0
      model/jina_rerank.py
  14. 30 20
      model/multimodal_embedding.py
  15. 8 7
      model/qwen_vl.py
  16. 21 0
      model/tracked_multi_embedding.py
  17. 12 0
      model/tracked_openai_embeddings.py
  18. 0 0
      parser/image_parser/__init__.py
  19. 2 2
      parser/image_parser/image_parser_workflow.py
  20. 82 14
      parser/pdf_parser/pdf_parser_workflow.py
  21. 27 0
      private_key.pem
  22. 0 0
      prompt/parser/page_parse_prompt.py
  23. 9 0
      public_key.pem
  24. 82 0
      test/test_asymmetric_encryption.py
  25. 39 0
      test/test_excel_util.py
  26. 70 0
      test/test_mysql_config.py
  27. 95 0
      test/test_mysql_conn.py
  28. 14 0
      test/test_ragflow_sdk.py
  29. 85 0
      test/test_search_tag.py
  30. 25 0
      test_upload_tag.py
  31. 151 0
      utils/asymmetric_encryption.py
  32. 45 0
      utils/decorators/langfuse_trace_embedding.py
  33. 0 0
      utils/decorators/singleton.py
  34. 221 0
      utils/excel_util.py
  35. 5 5
      utils/file/minio/minio_util.py
  36. 21 7
      utils/infinity/client.py
  37. 4 4
      utils/infinity/pool.py
  38. 78 0
      utils/infinity/result_util.py
  39. 52 0
      utils/mysql/__init__.py
  40. 50 114
      utils/mysql/mysql_conn.py
  41. 138 0
      utils/mysql/mysql_pool.py
  42. 3 1
      utils/ragflow/dataset_service.py
  43. 3 3
      utils/ragflow/ragflow_service.py
  44. 9 0
      utils/ragflow_sdk/__init__.py
  45. 17 0
      utils/ragflow_sdk/base_util.py
  46. 80 0
      utils/ragflow_sdk/chunk_util.py
  47. 45 0
      utils/ragflow_sdk/conf/rag_parser_config.py
  48. 71 0
      utils/ragflow_sdk/dataset_util.py
  49. 93 0
      utils/ragflow_sdk/document_util.py
  50. 0 0
      utils/ragflow_sdk/param.json

+ 25 - 1
.env

@@ -7,10 +7,15 @@ DASHSCOPE=sk-bc0f1026a41c4c92beb014be8973e4e2
 # embedding模型配置
 # embedding模型配置
 EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-0.6B
 EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-0.6B
 MULTIMODAL_EMBEDDING_MODEL_NAME=qwen2.5-vl-embedding
 MULTIMODAL_EMBEDDING_MODEL_NAME=qwen2.5-vl-embedding
+# rank模型配置
+RANK_MODEL_NAME=Qwen/Qwen3-Reranker-0.6B
+
 # RAGFLOW配置
 # RAGFLOW配置
-RAGFLOW_API_URL=http://192.168.16.134/
+RAGFLOW_API_URL=http://192.168.16.134:9380/
 RAGFLOW_API_KEY=ragflow-sPJ06xiUdRrcfDRlOD-GN2gl-U2DLB-PbgNGckUu0KM
 RAGFLOW_API_KEY=ragflow-sPJ06xiUdRrcfDRlOD-GN2gl-U2DLB-PbgNGckUu0KM
 DATASET_ID=a0f1aa03ed2c11f08b8f0242c0a85002
 DATASET_ID=a0f1aa03ed2c11f08b8f0242c0a85002
+RAGFLOW_USER_NAME=O75u85uh+PwmwmJvNebYUCNKpD812xhfnQOvB+Mwy+cHQtGBV2dy0tMQKQwGHiW7MiLJkHPqSLn7ULTzav0c2w3yIze71PAcIfxUScautg6xMMgtjHd4ex8peVyXTQcWc0bmD+GxADaZoOMnDz/XNUtB8mggx/VZ1RBdhrZMylTEGQUcUE8ylbLAVgDVOR6iqJEjGNZYoWUFiuYZAB6bIfPdaPG8Kr0KQrqHj3Y0zZHXl92AloXnw5RsaBOc3p01PKc+xYujrb+mdSGs8flzNHMusWhB/0bJ+t0XxqgePymCCk6+7nmm7M6iZ6pRlYEeERsjiNd/1lcMaWxda5jv4A==
+RAGFLOW_PASSWD=ggR2p1L7DRp5v+VcfikmfMYWh99QgBjtQV1fexSFP5C18HVXhkTD2MeMrcnJft8ciGeRUPDStaYZft3CbHOBY3Bzm2+/WX8iVuZ+5kMkmiAL9iUu0RxK74g1x11zy2CAoASy344ZiDy1p1co7hE0ksPdL62U38dxPGPaK8ngqfTsMvpb0XztWBUxaK5Izl3fqPKaYp+eqC75vvj6PeeaajIc7I7dqSai3a6jipruZeA8VxF/cWjVSKubQrm/YLs8mge2mcSM0pyUU2t9LeNS+hOiUyxRAqE2s+yGmRpeCMKIcfeGaurlVesy8v7zjIjZZBbEPV8rSZOEskGRSAG38Q==
 
 
 # 应用配置
 # 应用配置
 LOG_LEVEL=INFO
 LOG_LEVEL=INFO
@@ -28,6 +33,15 @@ INFINITY_USER=admin
 INFINITY_PASSWORD=admin
 INFINITY_PASSWORD=admin
 INFINITY_TABLE_NAME=book_page_image
 INFINITY_TABLE_NAME=book_page_image
 
 
+# MySQL配置
+MYSQL_HOST=192.168.16.134
+MYSQL_PORT=5455
+MYSQL_USER=root
+MYSQL_PASSWORD=ronshin@1234
+MYSQL_DATABASE=book_page
+MYSQL_CHARSET=utf8mb4
+MYSQL_POOL_SIZE=5
+
 # MinIO配置
 # MinIO配置
 MINIO_ENDPOINT=192.168.16.134:9000
 MINIO_ENDPOINT=192.168.16.134:9000
 MINIO_ACCESS_KEY=ck7I8Esssx6rzZrXQ5uP
 MINIO_ACCESS_KEY=ck7I8Esssx6rzZrXQ5uP
@@ -36,4 +50,14 @@ MINIO_BUCKET_NAME=bookpage
 # 本地测试设为false;生产环境设为true
 # 本地测试设为false;生产环境设为true
 MINIO_SECURE=False
 MINIO_SECURE=False
 
 
+# TAG_SEARCH
+TAG_DB_NAME=default_db
+TAG_DATASET_ID=18caf531f04d11f095670242c0a85002
+TAG_DOCUMENT_ID=3dda0a90f1e211f0a3b80242c0a85002
+#TAG_TABLE_NAME=ragflow_92162247e93e11f084830242ac1d0002_52275b36f03611f0a5340242c0a85002
+TAG_TABLE_NAME=ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002
 
 
+# LANGFUSE
+LANGFUSE_PUBLIC_KEY=pk-lf-6918a148-be72-4211-a22d-183a23e6643e
+LANGFUSE_SECRET_KEY=sk-lf-9c64d7ed-1618-4da9-a775-33e39b05448e
+LANGFUSE_HOST=http://192.168.16.134:3000

+ 9 - 9
agent/test_image_agent.py

@@ -9,7 +9,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 from utils.infinity_util import InfinityVectorDB
 from utils.infinity_util import InfinityVectorDB
 from model.multimodal_embedding import Embedding
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, ragflow_settings, vector_db_settings
 
 
 system_prompt = """
 system_prompt = """
 【角色设定】 你是一位拥有丰富经验的儿童绘本编辑与阅读推广人,擅长从视觉美学、儿童心理学和文学创作三个维度深度解读绘本。
 【角色设定】 你是一位拥有丰富经验的儿童绘本编辑与阅读推广人,擅长从视觉美学、儿童心理学和文学创作三个维度深度解读绘本。
@@ -29,15 +29,15 @@ vector_db = InfinityVectorDB()
 
 
 # 初始化多模态嵌入模型
 # 初始化多模态嵌入模型
 embedding_model = Embedding(
 embedding_model = Embedding(
-    model_name=ModelConfig.get_multimodal_embedding_model_name(),
-    api_key=ModelConfig.get_dashscope_api_key()
+    model_name=model_settings.multimodal_embedding_model_name,
+    api_key=model_settings.dashscope_api_key
 )
 )
 
 
 llm = init_chat_model(
 llm = init_chat_model(
-    model_provider=ModelConfig.get_model_provider(),
-    model=ModelConfig.get_model_name(),
-    base_url=ModelConfig.get_base_url(),
-    api_key=ModelConfig.get_api_key(),
+    model_provider=model_settings.model_provider,
+    model=model_settings.model_name,
+    base_url=model_settings.base_url,
+    api_key=model_settings.api_key,
     temperature=0.3
     temperature=0.3
 )
 )
 
 
@@ -52,10 +52,10 @@ def vector_search(text_query: str, image_path: str, topn: int = 2):
     :return: 检索结果
     :return: 检索结果
     """
     """
         
         
-    print(f"开始混合检索,数据库: {VectorDBConfig.get_infinity_database}, 知识库id: {ModelConfig.get_dataset_id()}, 文本查询: {text_query}, 返回数量: {topn}")
+    print(f"开始混合检索,数据库: {vector_db_settings.infinity_database}, 知识库id: {ragflow_settings.dataset_id}, 文本查询: {text_query}, 返回数量: {topn}")
         
         
     # 构建索引名称
     # 构建索引名称
-    index_name = f"{VectorDBConfig.get_infinity_table_name()}" 
+    index_name = f"{vector_db_settings.infinity_table_name}" 
     print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
     print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
         
         
     # 处理image_path为image: Image.Image
     # 处理image_path为image: Image.Image

+ 3 - 55
api/db/services/infinity_search_service.py

@@ -1,61 +1,9 @@
 from typing import Dict, Any, List
 from typing import Dict, Any, List
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 from utils.infinity import InfinityClient
 from utils.infinity import InfinityClient
 from utils.file.image_util import image_util
 from utils.file.image_util import image_util
 from model.multimodal_embedding import get_embedding_model
 from model.multimodal_embedding import get_embedding_model
-
-
-def convert_to_basic_types(obj: Any) -> Any:
-    """
-    递归将对象转换为基本类型,以便Pydantic能够序列化
-    
-    特殊处理:当字典中的子项包含相同长度的数组时,将其转换为数组对象结构
-    例如:{"a": [1,2], "b": [3,4]} -> [{"a":1, "b":3}, {"a":2, "b":4}]
-    
-    Args:
-        obj: 要转换的对象
-    
-    Returns:
-        转换后的基本类型对象
-    """
-    if obj is None:
-        return None
-    elif isinstance(obj, (str, int, float, bool)):
-        return obj
-    elif isinstance(obj, dict):
-        # 先递归转换所有值
-        converted = {k: convert_to_basic_types(v) for k, v in obj.items()}
-        
-        # 检查是否需要转换为数组对象结构
-        # 条件:所有值都是列表,且长度一致,且长度大于0
-        values = list(converted.values())
-        if all(isinstance(v, list) for v in values):
-            lengths = [len(v) for v in values]
-            if len(set(lengths)) == 1 and lengths[0] > 0:
-                # 转换为数组对象结构
-                result = []
-                keys = list(converted.keys())
-                for i in range(lengths[0]):
-                    item = {}
-                    for key in keys:
-                        # 处理数组中可能存在的None值
-                        if i < len(converted[key]):
-                            item[key] = converted[key][i]
-                        else:
-                            item[key] = None
-                    result.append(item)
-                return result
-        
-        return converted
-    elif isinstance(obj, (list, tuple)):
-        return [convert_to_basic_types(item) for item in obj]
-    else:
-        # 对于其他类型,尝试将其转换为字符串或字典
-        try:
-            return dict(obj)
-        except:
-            return str(obj)
-
+from utils.infinity.result_util import convert_to_basic_types
 
 
 class InfinitySearchService:
 class InfinitySearchService:
     def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
     def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
@@ -72,7 +20,7 @@ class InfinitySearchService:
         self.vector_field = vector_field or "dense_vector_1024"
         self.vector_field = vector_field or "dense_vector_1024"
         self.match_field = match_field or "content"
         self.match_field = match_field or "content"
         self.match_type = match_type or "cosine"
         self.match_type = match_type or "cosine"
-        self.table_name = table_name or VectorDBConfig.get_infinity_table_name()
+        self.table_name = table_name or vector_db_settings.infinity_table_name
 
 
     def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
     def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
         """
         """

+ 207 - 0
api/db/services/tag_service.py

@@ -0,0 +1,207 @@
+from typing import List, Dict, Any, Optional
+from abc import ABC, abstractmethod
+from utils.ragflow_sdk.chunk_util import ChunkUtil
+from utils.infinity import InfinityClient
+from conf.settings import tag_search_settings
+
+
+
class TagService(ABC):
    """Abstract interface for tag management (create / bulk upload / delete).

    Implementations persist each tag as a RAGFlow chunk and mirror the tag
    name into the Infinity vector store.
    """

    @abstractmethod
    def create_tag(self, tag_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a single tag.

        Args:
            tag_data: Mapping with keys ``name``, ``description`` and
                ``age_range``.

        Returns:
            Dict[str, Any]: Information about the created tag.
        """

    @abstractmethod
    def upload_tags(self, tags_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-upload tags.

        Args:
            tags_data: Raw rows as parsed from the Excel sheet. The concrete
                implementation expects the Chinese column headers
                ``年龄分级`` / ``一级标签`` / ``二级标签`` / ``三级标签`` /
                ``标签描述`` (see ``TagServiceImpl.get_tag_data``), not
                pre-normalized name/description dicts.

        Returns:
            Dict[str, Any]: Upload summary (total, success and failure
            counts, plus the rows that failed).
        """

    @abstractmethod
    def delete_tag(self, tag_id: str) -> Dict[str, Any]:
        """Delete a tag.

        Args:
            tag_id: Identifier of the tag to delete.

        Returns:
            Dict[str, Any]: Deletion result.
        """
+
+
class TagServiceImpl(TagService):
    """Tag management backed by RAGFlow chunks plus an Infinity table.

    Each tag is persisted as one RAGFlow chunk (the description becomes the
    chunk content, the age range an "important keyword"); the tag name is
    then written into the chunk row's ``tag_kwd`` column in Infinity.
    """

    def __init__(self, infinity_client: InfinityClient):
        """
        Args:
            infinity_client: Connected Infinity client used for the
                ``tag_kwd`` updates.
        """
        self.tag_dataset_id = tag_search_settings.tag_dataset_id
        self.tag_document_id = tag_search_settings.tag_document_id
        self.infinity_client = infinity_client
        self.chunk_util = ChunkUtil(self.tag_dataset_id, self.tag_document_id)

    def create_tag(self, tag_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a single tag.

        Args:
            tag_data: Mapping with keys ``name``, ``description`` and
                ``age_range``; a missing key raises ``KeyError``.

        Returns:
            Dict[str, Any]: The created tag, including the RAGFlow chunk id
            and the raw Infinity update result.

        Raises:
            Exception: If the Infinity ``tag_kwd`` update reports an error.
        """
        tag_name = tag_data["name"]
        tag_desc = tag_data["description"]
        age_range = tag_data["age_range"]
        # Step 1: write the chunk into RAGFlow.
        chunk = self.chunk_util.add_chunk(
            document_id=self.tag_document_id,
            content=tag_desc,
            important_keywords=[age_range],
        )
        # NOTE(review): upload_tags() reads ``chunk.id`` while this method
        # used ``chunk["id"]`` — confirm ChunkUtil.add_chunk's return type.
        chunk_id = chunk["id"]
        # Step 2: mirror the tag name into the chunk row in Infinity.
        # BUGFIX: the chunk id is a string, so it must be quoted in the
        # condition (as upload_tags already does); table and database now
        # come from settings instead of the client defaults.
        res = self.infinity_client.update(
            table_name=tag_search_settings.tag_table_name,
            cond=f"id = '{chunk_id}'",
            data={"tag_kwd": tag_name},
            database_name=tag_search_settings.tag_db_name,
        )
        # NOTE(review): upload_tags() checks ``res.error_code`` — confirm
        # which result shape InfinityClient.update actually returns.
        if res["code"] != 0:
            raise Exception(f"更新标签到infinity失败: {res}")

        return {
            "name": tag_name,
            "description": tag_desc,
            "age_range": age_range,
            "chunk_id": chunk_id,
            "ragflow_chunk": chunk,
            "infinity_update_result": res,
        }

    def upload_tags(self, tags_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Bulk-upload tags parsed from the Excel sheet.

        Args:
            tags_data: Raw rows with the Chinese headers 年龄分级 / 一级标签 /
                二级标签 / 三级标签 / 标签描述; see :meth:`get_tag_data`.

        Returns:
            Dict[str, Any]: Summary with total / success / failure counts
            and the rows that failed.
        """
        success_count = 0
        fail_count = 0
        failed_tags = []

        tag_data_list = self.get_tag_data(tags_data)
        for tag_data in tag_data_list:
            tag_name = tag_data["name"]
            tag_desc = tag_data["description"]
            age_range = tag_data["age_range"]
            # Step 1: write the chunk into RAGFlow.
            chunk = self.chunk_util.add_chunk(
                dataset_name="标签",
                document_id=self.tag_document_id,
                content=tag_desc,
                important_keywords=[age_range],
            )
            print(f"分块数据写入成功, chunk_id: {chunk.id}")
            # Step 2: mirror the tag name into Infinity.
            # BUGFIX: table and database names were hard-coded copies of the
            # TAG_TABLE_NAME / TAG_DB_NAME values from .env; read them from
            # settings so environment changes take effect.
            res = self.infinity_client.update(
                table_name=tag_search_settings.tag_table_name,
                cond=f"id = '{chunk.id}'",
                data={"tag_kwd": tag_name},
                database_name=tag_search_settings.tag_db_name,
            )
            print(f"更新标签 {tag_name} 到 infinity 结果: {res}")
            if res.error_code == 0:
                print(f"标签 {tag_name} 更新到 infinity 成功")
                success_count += 1
            else:
                fail_count += 1
                failed_tags.append({
                    "age_range": age_range,
                    "tag_name": tag_name,
                    "tag_desc": tag_desc,
                    "error": res,
                })
        return {
            "success": True,
            "total": len(tags_data),
            "success_count": success_count,
            "fail_count": fail_count,
            "failed_tags": failed_tags,
        }

    def delete_tag(self, tag_id: str) -> Dict[str, Any]:
        """Delete a tag.

        Not implemented yet; currently returns ``None``.

        Args:
            tag_id: Identifier of the tag to delete.
        """
        # TODO: remove the RAGFlow chunk and clear tag_kwd in Infinity.
        pass

    def get_tag_data(self, tags_data: List[Dict[str, Any]], age_level_map: Optional[Dict[str, str]] = None) -> List[Dict[str, Any]]:
        """Normalize raw Excel rows into tag dicts.

        Args:
            tags_data: Rows with the Chinese headers 年龄分级 / 一级标签 /
                二级标签 / 三级标签 / 标签描述; missing cells become "".
            age_level_map: Optional mapping from age-level code (L1..L8) to
                an age-range string; when omitted it is loaded from
                ``conf/age_level.json`` at the project root.

        Returns:
            List of ``{"name", "description", "age_range"}`` dicts, where
            ``name`` is ``一级_二级_三级`` and unknown level codes yield an
            empty ``age_range``.
        """
        import json
        import os

        if age_level_map is None:
            # conf/age_level.json lives at <project root>/conf relative to
            # this file (api/db/services/tag_service.py); abspath() makes the
            # lookup robust when __file__ is relative.
            age_level_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
                "conf",
                "age_level.json",
            )
            with open(age_level_path, "r", encoding="utf-8") as f:
                age_level_map = json.load(f)

        formatted_tags = []
        for tag_data in tags_data:
            age_level_code = tag_data.get("年龄分级", "")
            one_tag = tag_data.get("一级标签", "")
            two_tag = tag_data.get("二级标签", "")
            three_tag = tag_data.get("三级标签", "")
            tag_desc = tag_data.get("标签描述", "")

            # Tag name format: level1_level2_level3.
            tag_name = f"{one_tag}_{two_tag}_{three_tag}"
            age_range = age_level_map.get(age_level_code, "")

            formatted_tags.append({
                "name": tag_name,
                "description": tag_desc,
                "age_range": age_range,
            })

        return formatted_tags
+
class TagServiceFactory:
    """Factory producing ready-to-use :class:`TagService` instances."""

    @staticmethod
    def create_tag_service(db_client) -> TagService:
        """Build the default tag service implementation.

        Args:
            db_client: Infinity client instance, passed straight through to
                :class:`TagServiceImpl`.

        Returns:
            TagService: A tag service wired to the given client.
        """
        return TagServiceImpl(db_client)

+ 7 - 5
api/hybrid_search_mcp.py

@@ -15,7 +15,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from PIL import Image
 from PIL import Image
 from utils.infinity_util import InfinityVectorDB
 from utils.infinity_util import InfinityVectorDB
 from model.multimodal_embedding import Embedding
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, ragflow_settings, vector_db_settings
+
+
 
 
 # 初始化fastmcp应用
 # 初始化fastmcp应用
 mcp = FastMCP("Multi_Vector_Search")
 mcp = FastMCP("Multi_Vector_Search")
@@ -25,8 +27,8 @@ vector_db = InfinityVectorDB()
 
 
 # 初始化多模态嵌入模型
 # 初始化多模态嵌入模型
 embedding_model = Embedding(
 embedding_model = Embedding(
-    model_name=ModelConfig.get_multimodal_embedding_model_name(),
-    api_key=ModelConfig.get_dashscope_api_key()
+    model_name=model_settings.multimodal_embedding_model_name,
+    api_key=model_settings.dashscope_api_key
 )
 )
 
 
 @mcp.tool(name="hybrid_search")
 @mcp.tool(name="hybrid_search")
@@ -41,10 +43,10 @@ def hybrid_search(request: Dict[str, Any]) -> Dict[str, Any]:
         image_url = request["image"]
         image_url = request["image"]
         topn = request.get("topn", 2)
         topn = request.get("topn", 2)
         
         
-        print(f"开始混合检索,数据库: {VectorDBConfig.get_infinity_database}, 知识库id: {ModelConfig.get_dataset_id()}, 文本查询: {text_query}, 返回数量: {topn}")
+        print(f"开始混合检索,数据库: {vector_db_settings.infinity_database}, 知识库id: {ragflow_settings.dataset_id}, 文本查询: {text_query}, 返回数量: {topn}")
         
         
         # 构建索引名称
         # 构建索引名称
-        index_name = f"pdf_documents_{ModelConfig.get_dataset_id()}"        
+        index_name = f"pdf_documents_{ragflow_settings.dataset_id}"        
         print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
         print(f"开始生成多模态嵌入,文本长度: {len(text_query)}")
         
         
         # 处理image_url为image: Image.Image
         # 处理image_url为image: Image.Image

+ 95 - 0
api/tag_manage.py

@@ -0,0 +1,95 @@
+# 标签管理API服务
+
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from typing import List, Dict, Any, Optional
+from api.db.services.tag_service import TagServiceFactory
+from utils.infinity import get_client
+from utils.excel_util import excel_util
+from pydantic import BaseModel
+import os
+import tempfile
+
+
# FastAPI sub-application for tag management; mounted under /tag by main.py.
app = FastAPI(
    title="标签管理 API",
    description="标签管理服务,提供标签上传、查询、删除等功能",
    version="1.0.0",
)
+
+
+# 请求模型
class TagCreateRequest(BaseModel):
    """Request payload for creating a single tag."""

    # Tag name in the form "<level1>_<level2>_<level3>".
    name: str
    # Human-readable description of the tag.
    description: str
    # Age range string such as "0-2" or "2-3".
    age_range: str
+
+
# 1. Tag bulk-upload endpoint.
@app.post("/upload", response_model=Dict[str, Any])
async def upload_tags(file: UploadFile = File(...)):
    """Bulk-upload tags from an Excel file.

    - **file**: Excel file with the columns 年龄分级、一级标签、二级标签、
      三级标签、标签描述.

    File requirements:
    - .xlsx or .xls format
    - first row is the header row listed above
    - 年龄分级 values are one of L1..L8

    Raises:
        HTTPException: 400 for an unsupported extension, 500 for any other
        processing failure.
    """
    try:
        # Validate the extension before touching the payload; filename may
        # legitimately be missing, which yields an empty (rejected) suffix.
        file_ext = os.path.splitext(file.filename or "")[1].lower()
        if file_ext not in [".xlsx", ".xls"]:
            raise HTTPException(status_code=400, detail=f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")

        # Spool the upload to a temp file so excel_util can read a path.
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        try:
            # Parse the sheet into raw rows.
            tags_data = excel_util.parse_excel(file_path=temp_file_path)

            # Build the service and run the bulk upload.
            tag_service = TagServiceFactory.create_tag_service(get_client())
            result = tag_service.upload_tags(tags_data)

            return {"success": True, "result": result}
        finally:
            # Always remove the spooled temp file.
            os.unlink(temp_file_path)
    except HTTPException:
        # Re-raise explicit HTTP errors untouched instead of wrapping them
        # in a 500; bare `raise` (not `raise e`) preserves the traceback.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"标签上传失败: {str(e)}")
+
+
# 2. Tag creation endpoint.
@app.post("/create", response_model=Dict[str, Any])
async def create_tag(request: TagCreateRequest):
    """Create one tag.

    - **name**: tag name, formatted "<level1>_<level2>_<level3>"
    - **description**: what the tag means and when to apply it
    - **age_range**: age range string, e.g. "0-2", "2-3", "3-4"
    """
    try:
        # Build the service, then hand the validated payload to it.
        service = TagServiceFactory.create_tag_service(get_client())
        created = service.create_tag(request.model_dump())
        return {"success": True, "tag": created}
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"缺少必填字段: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"创建标签失败: {str(e)}")

+ 11 - 0
common/models/pagination.py

@@ -0,0 +1,11 @@
+from pydantic import BaseModel, Field
+from typing import Optional
+
class Pagination(BaseModel):
    """Common pagination parameters shared by list endpoints."""

    # 1-based page index.
    page: int = Field(default=1, ge=1, description="当前页码")
    # Items per page, capped at 100.
    page_size: int = Field(default=30, ge=1, le=100, description="每页条数")

    def to_dict(self):
        """Dump to a plain dict, dropping None values as the API expects."""
        dumped = self.model_dump()
        return {key: value for key, value in dumped.items() if value is not None}

+ 10 - 0
conf/age_level.json

@@ -0,0 +1,10 @@
+{
+    "L1": "0-2",
+    "L2": "2-3",
+    "L3": "3-4",
+    "L4": "4-5",
+    "L5": "5-6",
+    "L6": "6-10",
+    "L7": "10-14",
+    "L8": "14-100"
+}

+ 31 - 0
conf/config.py

@@ -33,6 +33,11 @@ class ModelConfig:
     def get_api_key() -> str:
     def get_api_key() -> str:
         """获取模型API密钥"""
         """获取模型API密钥"""
         return os.getenv("API_KEY", "")
         return os.getenv("API_KEY", "")
+
+    @staticmethod
+    def get_rank_model_name() -> str:
+        """获取模型名称"""
+        return os.getenv("RANK_MODEL_NAME", "Qwen/Qwen3-Reranker-0.6B")
     
     
     @staticmethod
     @staticmethod
     def get_model_config() -> dict:
     def get_model_config() -> dict:
@@ -64,6 +69,9 @@ class ModelConfig:
         """获取DASHSCOPE API密钥"""
         """获取DASHSCOPE API密钥"""
         return os.getenv("DASHSCOPE", "")
         return os.getenv("DASHSCOPE", "")
 
 
+class RagflowConfig:
+    """RAGFLOW配置类"""
+
     # RAGFLOW配置
     # RAGFLOW配置
     @staticmethod
     @staticmethod
     def get_ragflow_api_url() -> str:
     def get_ragflow_api_url() -> str:
@@ -80,6 +88,16 @@ class ModelConfig:
         """获取数据集ID"""
         """获取数据集ID"""
         return os.getenv("DATASET_ID", "")
         return os.getenv("DATASET_ID", "")
 
 
+    @staticmethod
+    def get_ragflow_user_name() -> str:
+        """获取RAGFLOW用户名"""
+        return os.getenv("RAGFLOW_USER_NAME", "")
+    
+    @staticmethod
+    def get_ragflow_passwd() -> str:
+        """获取RAGFLOW密码"""
+        return os.getenv("RAGFLOW_PASSWD", "")
+
 class AppConfig:
 class AppConfig:
     """应用配置类"""
     """应用配置类"""
     
     
@@ -159,6 +177,19 @@ class VectorDBConfig:
         """获取Infinity向量数据库表名"""
         """获取Infinity向量数据库表名"""
         return os.getenv("INFINITY_TABLE_NAME", "test")
         return os.getenv("INFINITY_TABLE_NAME", "test")
 
 
+class TagSearchConfig:
+    """标签搜索配置类"""
+
+    @staticmethod
+    def get_tag_db_name() -> str:
+        """获取标签数据库名称"""
+        return os.getenv("TAG_DB_NAME", "tag_db")
+    
+    @staticmethod
+    def get_tag_table_name() -> str:
+        """获取标签数据库表名"""
+        return os.getenv("TAG_TABLE_NAME", "tag_table")
+
 # 导出配置实例
 # 导出配置实例
 model_config = ModelConfig.get_model_config()
 model_config = ModelConfig.get_model_config()
 app_config = {
 app_config = {

+ 123 - 0
conf/settings.py

@@ -0,0 +1,123 @@
+"""配置管理模块,使用 pydantic-settings 从.env文件加载环境变量"""
+
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import Field
+
class _EnvSettings(BaseSettings):
    """Shared base for every settings group.

    Loads values from the project ``.env`` file; unknown keys are ignored so
    all groups can read the same file. Subclasses inherit this
    ``model_config`` (pydantic v2 merges it down the hierarchy), removing the
    seven duplicated copies the module previously carried.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )


class AppSettings(_EnvSettings):
    """Application-level settings."""

    log_level: str = Field(default="INFO", alias="LOG_LEVEL")


class ModelSettings(_EnvSettings):
    """Model settings: chat, embedding and rerank models."""

    model_provider: str = Field(default="openai", alias="MODEL_PROVIDER")
    model_name: str = Field(default="Qwen/Qwen3-VL-8B-Instruct", alias="MODEL_NAME")
    embedding_model_name: str = Field(default="Qwen/Qwen3-Embedding-0.6B", alias="EMBEDDING_MODEL_NAME")
    base_url: str = Field(default="https://api.openai.com/v1", alias="BASE_URL")
    api_key: str = Field(default="", alias="API_KEY")
    rank_model_name: str = Field(default="Qwen/Qwen3-Reranker-0.6B", alias="RANK_MODEL_NAME")
    multimodal_embedding_model_name: str = Field(default="qwen2.5-vl-embedding", alias="MULTIMODAL_EMBEDDING_MODEL_NAME")
    # The DASHSCOPE env var holds the DashScope API key.
    dashscope: str = Field(default="", alias="DASHSCOPE")


class RagflowSettings(_EnvSettings):
    """RAGFlow connection settings."""

    ragflow_api_url: str = Field(default="http://192.168.16.134/", alias="RAGFLOW_API_URL")
    ragflow_api_key: str = Field(default="", alias="RAGFLOW_API_KEY")
    dataset_id: str = Field(default="", alias="DATASET_ID")
    ragflow_user_name: str = Field(default="", alias="RAGFLOW_USER_NAME")
    ragflow_passwd: str = Field(default="", alias="RAGFLOW_PASSWD")


class MinioSettings(_EnvSettings):
    """MinIO object-storage settings."""

    # NOTE(review): .env uses a scheme-less "host:port" endpoint while this
    # default carries "http://" — confirm which form the MinIO client expects.
    minio_endpoint: str = Field(default="http://localhost:9000", alias="MINIO_ENDPOINT")
    minio_access_key: str = Field(default="minioadmin", alias="MINIO_ACCESS_KEY")
    minio_secret_key: str = Field(default="minioadmin", alias="MINIO_SECRET_KEY")
    minio_bucket_name: str = Field(default="ragflow", alias="MINIO_BUCKET_NAME")
    minio_secure: bool = Field(default=False, alias="MINIO_SECURE")


class VectorDBSettings(_EnvSettings):
    """Vector database (Infinity) settings."""

    vector_db_type: str = Field(default="es", alias="VECTOR_DB_TYPE")
    infinity_host: str = Field(default="192.168.16.134", alias="INFINITY_HOST")
    infinity_port: int = Field(default=23820, alias="INFINITY_PORT")
    infinity_sdk_port: int = Field(default=23817, alias="INFINITY_SDK_PORT")
    infinity_user: str = Field(default="admin", alias="INFINITY_USER")
    infinity_password: str = Field(default="admin", alias="INFINITY_PASSWORD")
    infinity_database: str = Field(default="test", alias="INFINITY_DATABASE")
    infinity_table_name: str = Field(default="test", alias="INFINITY_TABLE_NAME")


class MysqlSettings(_EnvSettings):
    """MySQL settings."""

    mysql_host: str = Field(default="localhost", alias="MYSQL_HOST")
    mysql_port: int = Field(default=3306, alias="MYSQL_PORT")
    mysql_user: str = Field(default="root", alias="MYSQL_USER")
    mysql_password: str = Field(default="", alias="MYSQL_PASSWORD")
    mysql_database: str = Field(default="", alias="MYSQL_DATABASE")
    mysql_charset: str = Field(default="utf8mb4", alias="MYSQL_CHARSET")
    mysql_pool_size: int = Field(default=5, alias="MYSQL_POOL_SIZE")


class TagSearchSettings(_EnvSettings):
    """Tag-search settings (RAGFlow dataset/document and Infinity table)."""

    tag_db_name: str = Field(default="tag_db", alias="TAG_DB_NAME")
    tag_table_name: str = Field(default="tag_table", alias="TAG_TABLE_NAME")
    tag_document_id: str = Field(default="", alias="TAG_DOCUMENT_ID")
    tag_dataset_id: str = Field(default="", alias="TAG_DATASET_ID")


# Shared, eagerly-created settings instances used across the project.
model_settings = ModelSettings()
ragflow_settings = RagflowSettings()
app_settings = AppSettings()
minio_settings = MinioSettings()
vector_db_settings = VectorDBSettings()
mysql_settings = MysqlSettings()
tag_search_settings = TagSearchSettings()

+ 0 - 0
ee.json


+ 18 - 3
main.py

@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
 
 
 # 导入所有子应用
 # 导入所有子应用
 from api.search_infinity import app as search_app
 from api.search_infinity import app as search_app
+from api.tag_manage import app as tag_app
 
 
 # 定义主应用的生命周期管理
 # 定义主应用的生命周期管理
 @asynccontextmanager
 @asynccontextmanager
@@ -12,15 +13,26 @@ async def main_lifespan(app: FastAPI):
     """主应用生命周期管理"""
     """主应用生命周期管理"""
     from utils.infinity import get_client, close_client
     from utils.infinity import get_client, close_client
     print("=== Infinity API Gateway 启动 ===")
     print("=== Infinity API Gateway 启动 ===")
-    # 1. 初始化全局客户端(在服务启动时)
+    # 1. 初始化Infinity全局客户端(在服务启动时)
     get_client(database="book_image_db")
     get_client(database="book_image_db")
     print("✅ Infinity客户端已初始化")
     print("✅ Infinity客户端已初始化")
+    
+    # 2. 初始化MySQL全局客户端
+    from utils.mysql import init_global_mysql_client
+    init_global_mysql_client()
+    print("✅ MySQL客户端已初始化")
+    
     yield
     yield
 
 
     print("=== Infinity API Gateway 关闭 ===")
     print("=== Infinity API Gateway 关闭 ===")
-     # 2. 关闭全局客户端(在服务关闭时)
+    # 1. 关闭Infinity全局客户端(在服务关闭时)
     close_client()
     close_client()
     print("✅ Infinity客户端已关闭")
     print("✅ Infinity客户端已关闭")
+    
+    # 2. 关闭MySQL全局客户端
+    from utils.mysql import close_global_mysql_client
+    close_global_mysql_client()
+    print("✅ MySQL客户端已关闭")
 
 
 # 创建主应用
 # 创建主应用
 main_app = FastAPI(
 main_app = FastAPI(
@@ -33,6 +45,8 @@ main_app = FastAPI(
 # 挂载子应用
 # 挂载子应用
 # 1. 搜索 API - 访问路径: /search/*
 # 1. 搜索 API - 访问路径: /search/*
 main_app.mount("/search", search_app, name="search_api")
 main_app.mount("/search", search_app, name="search_api")
+# 2. 标签管理 API - 访问路径: /tag/*
+main_app.mount("/tag", tag_app, name="tag_api")
 
 
 # 主应用根路径
 # 主应用根路径
 @main_app.get("/")
 @main_app.get("/")
@@ -42,7 +56,8 @@ async def root():
         "message": "Welcome to GRAPH_RAG API Gateway",
         "message": "Welcome to GRAPH_RAG API Gateway",
         "available_apps": {
         "available_apps": {
             "search_api": "访问路径: /search, 文档: /search/docs",
             "search_api": "访问路径: /search, 文档: /search/docs",
-            "hybrid_http_api": "访问路径: /hybrid, 文档: /hybrid/docs"
+            "hybrid_http_api": "访问路径: /hybrid, 文档: /hybrid/docs",
+            "tag_api": "访问路径: /tag, 文档: /tag/docs"
         }
         }
     }
     }
 
 

+ 133 - 0
model/jina_rerank.py

@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import requests
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.utils import get_from_dict_or_env
+from pydantic import ConfigDict, model_validator
+
+JINA_API_URL: str = "https://api.jina.ai/v1/rerank"
+
+
class JinaRerank(BaseDocumentCompressor):
    """Document compressor backed by the Jina Rerank API.

    Unlike the upstream LangChain integration, ``base_url`` is configurable so
    the class can target self-hosted or proxied deployments of the rerank
    service. ``base_url`` may be either the service root or the full
    ``/rerank`` endpoint; see :meth:`rerank`.
    """

    session: Any = None
    """Requests session used to communicate with the API."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model: str = "jina-reranker-v1-base-en"
    """Model to use for reranking."""
    jina_api_key: Optional[str] = None
    """Jina API key. Must be specified directly or via the environment
    variable JINA_API_KEY."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""
    base_url: str = JINA_API_URL
    """Base URL for the Jina API. Defaults to JINA_API_URL."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that an API key is available and prepare the HTTP session.

        The session is created once here so every request reuses the same
        connection pool and carries the auth / user-agent headers.
        """
        jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
        user_agent = values.get("user_agent", "langchain")
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {jina_api_key}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
                "user-agent": user_agent,
            }
        )
        values["session"] = session
        return values

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        model: Optional[str] = None,
        top_n: Optional[int] = -1,
        max_chunks_per_doc: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return documents ordered by their relevance to the provided query.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.
            model: The model to use for re-ranking. Defaults to self.model.
            top_n: The number of results to return. If None returns all
                results. Defaults to self.top_n.
            max_chunks_per_doc: The maximum number of chunks derived from a
                document (currently forwarded for interface compatibility
                only; not sent to the API).

        Returns:
            A list of ``{"index": ..., "relevance_score": ...}`` dicts, where
            ``index`` refers to the position in ``documents``.

        Raises:
            RuntimeError: If the API response contains no ``results`` key.
        """  # noqa: E501
        if len(documents) == 0:  # to avoid empty api call
            return []
        docs = [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]
        model = model or self.model
        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
        data = {
            "query": query,
            "documents": docs,
            "model": model,
            "top_n": top_n,
        }

        # Build the endpoint URL. BUGFIX: only append "/rerank" when the
        # configured base_url does not already end with it — the default
        # JINA_API_URL already is the full endpoint, and the previous code
        # produced ".../rerank/rerank".
        api_url = self.base_url.rstrip("/")
        if not api_url.endswith("/rerank"):
            api_url += "/rerank"
        resp = self.session.post(
            api_url,
            json=data,
        ).json()

        if "results" not in resp:
            # Surface the server-provided detail when present, otherwise the
            # whole payload — avoids a KeyError on unexpected error shapes.
            raise RuntimeError(resp.get("detail", resp))

        return [
            {
                "index": res["index"],
                "relevance_score": res["relevance_score"],
            }
            for res in resp["results"]
        ]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        top_n: Optional[int] = None,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using Jina's Rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            top_n: Number of results to keep; falls back to self.top_n.
            callbacks: Callbacks to run during the compression process
                (accepted for interface compatibility; not used).

        Returns:
            A sequence of reranked document copies, each annotated with a
            ``relevance_score`` metadata entry.
        """
        compressed = []
        for res in self.rerank(documents=documents, query=query, top_n=top_n):
            doc = documents[res["index"]]
            # Copy so the caller's documents are never mutated.
            doc_copy = Document(page_content=doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed

+ 30 - 20
model/multimodal_embedding.py

@@ -4,8 +4,9 @@ import base64
 import io
 import io
 from langchain_openai import OpenAIEmbeddings
 from langchain_openai import OpenAIEmbeddings
 from dashscope import MultiModalEmbedding
 from dashscope import MultiModalEmbedding
-from conf.config import ModelConfig
-from utils.file.image_util import image_util as ImageUtil
+from conf.settings import model_settings
+from langfuse import observe
+from utils.file.image_util import image_util
 
 
 class Embedding:
 class Embedding:
     """Embedding模型工具"""
     """Embedding模型工具"""
@@ -19,18 +20,12 @@ class Embedding:
             api_key: API密钥,若为None则使用配置文件中的值
             api_key: API密钥,若为None则使用配置文件中的值
         """
         """
         # 获取模型配置
         # 获取模型配置
-        self.model_provider = ModelConfig.get_model_provider()
-        self.model_name = model_name or ModelConfig.get_model_name()
-        self.base_url = ModelConfig.get_base_url()
-        self.api_key = api_key or ModelConfig.get_api_key()
-        
-        # 使用langchain_openai初始化OpenAI Embeddings模型
-        self.embeddings = OpenAIEmbeddings(
-            model=self.model_name,
-            base_url=self.base_url,
-            api_key=self.api_key
-        )
+        self.model_provider = model_settings.model_provider
+        self.model_name = model_name or model_settings.model_name
+        self.base_url = model_settings.base_url
+        self.api_key = api_key or model_settings.api_key
     
     
+    @observe(name="text_embedding", as_type="embedding")
     def get_text_embedding(self, text: str) -> List[float]:
     def get_text_embedding(self, text: str) -> List[float]:
         """
         """
         获取文本的embedding
         获取文本的embedding
@@ -42,11 +37,18 @@ class Embedding:
             List[float]: 文本的embedding向量
             List[float]: 文本的embedding向量
         """
         """
         try:
         try:
-            embedding = self.embeddings.embed_query(text)
+            # 使用langchain_openai初始化OpenAI Embeddings模型
+            embeddings = OpenAIEmbeddings(
+                model=self.model_name,
+                base_url=self.base_url,
+                api_key=self.api_key
+            )
+            embedding = embeddings.embed_query(text)
             return embedding
             return embedding
         except Exception as e:
         except Exception as e:
             raise Exception(f"文本embedding生成失败: {str(e)}")
             raise Exception(f"文本embedding生成失败: {str(e)}")
     
     
+    @observe(name="texts_embedding", as_type="embedding")
     def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
     def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
         """
         """
         获取多个文本的embedding
         获取多个文本的embedding
@@ -58,11 +60,18 @@ class Embedding:
             List[List[float]]: 文本列表的embedding向量列表
             List[List[float]]: 文本列表的embedding向量列表
         """
         """
         try:
         try:
-            embeddings = self.embeddings.embed_documents(texts)
+            # 使用langchain_openai初始化OpenAI Embeddings模型
+            embeddings = OpenAIEmbeddings(
+                model=self.model_name,
+                base_url=self.base_url,
+                api_key=self.api_key
+            )
+            embeddings = embeddings.embed_documents(texts)
             return embeddings
             return embeddings
         except Exception as e:
         except Exception as e:
             raise Exception(f"多个文本embedding生成失败: {str(e)}")
             raise Exception(f"多个文本embedding生成失败: {str(e)}")
     
     
+    @observe(name="image_embedding", as_type="embedding")
     def get_image_embedding(self, image: Image.Image) -> List[float]:
     def get_image_embedding(self, image: Image.Image) -> List[float]:
         """
         """
         获取图像的embedding
         获取图像的embedding
@@ -80,7 +89,7 @@ class Embedding:
             buffer.seek(0)
             buffer.seek(0)
             
             
             # 压缩图片字节流
             # 压缩图片字节流
-            compressed_bytes = ImageUtil._compress_image_to_bytes(buffer)
+            compressed_bytes = image_util._compress_image_to_bytes(buffer)
             
             
             image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
             image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
             # 构建输入项
             # 构建输入项
@@ -98,7 +107,8 @@ class Embedding:
                 return response.output["embeddings"][0]["embedding"]
                 return response.output["embeddings"][0]["embedding"]
         except Exception as e:
         except Exception as e:
             raise Exception(f"图像embedding生成失败: {str(e)}")
             raise Exception(f"图像embedding生成失败: {str(e)}")
-    
+
+    @observe(name="multimodal_embedding", as_type="embedding")
     def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
     def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
         """
         """
         获取多模态(文本+图像)的embedding
         获取多模态(文本+图像)的embedding
@@ -118,7 +128,7 @@ class Embedding:
                 buffer.seek(0)
                 buffer.seek(0)
             
             
                 # 压缩图片字节流
                 # 压缩图片字节流
-                compressed_bytes = ImageUtil()._compress_image_to_bytes(buffer)
+                compressed_bytes = image_util._compress_image_to_bytes(buffer)
             
             
                 image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
                 image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
                 item.append({'image': f"data:image/png;base64,{image_base64}"})
                 item.append({'image': f"data:image/png;base64,{image_base64}"})
@@ -148,6 +158,6 @@ def get_embedding_model() -> Embedding:
         Embedding: 全局单例的Embedding模型
         Embedding: 全局单例的Embedding模型
     """
     """
     return Embedding(
     return Embedding(
-        model_name=ModelConfig.get_multimodal_embedding_model_name(),
-        api_key=ModelConfig.get_dashscope_api_key()
+        model_name=model_settings.embedding_model_name,
+        api_key=model_settings.api_key
     )
     )

+ 8 - 7
model/qwen_vl.py

@@ -3,7 +3,8 @@ from PIL import Image
 import base64
 import base64
 import io
 import io
 from langchain.chat_models import init_chat_model
 from langchain.chat_models import init_chat_model
-from conf.config import ModelConfig
+from conf.settings import model_settings
+from langfuse.callback import CallbackHandler
 
 
 class QWenVLParser:
 class QWenVLParser:
     """QWEN VL模型图像解析工具"""
     """QWEN VL模型图像解析工具"""
@@ -16,11 +17,11 @@ class QWenVLParser:
             model_name: 模型名称,若为None则使用配置文件中的值
             model_name: 模型名称,若为None则使用配置文件中的值
         """
         """
         # 获取模型配置
         # 获取模型配置
-        self.model_provider = ModelConfig.get_model_provider()
-        self.model_name = model_name or ModelConfig.get_model_name()
-        self.base_url = ModelConfig.get_base_url()
-        self.api_key = ModelConfig.get_api_key()
-        
+        self.model_provider = model_settings.model_provider
+        self.model_name = model_name or model_settings.model_name
+        self.base_url = model_settings.base_url
+        self.api_key = model_settings.api_key
+        self.langfuse_handler = CallbackHandler()
         # 使用langchain的init_chat_model初始化模型
         # 使用langchain的init_chat_model初始化模型
         self.model = init_chat_model(
         self.model = init_chat_model(
             model_provider=self.model_provider,
             model_provider=self.model_provider,
@@ -82,7 +83,7 @@ class QWenVLParser:
             ]
             ]
             
             
             # 使用langchain模型调用OpenAI API
             # 使用langchain模型调用OpenAI API
-            response = self.model.invoke(messages)
+            response = self.model.invoke(input=messages, config={"callbacks": [self.langfuse_handler]})
             
             
             # 提取解析结果
             # 提取解析结果
             content = response.content
             content = response.content

+ 21 - 0
model/tracked_multi_embedding.py

@@ -0,0 +1,21 @@
+from model.multimodal_embedding import Embedding
+from PIL import Image
+from utils.decorators.langfuse_trace_embedding import langfuse_trace_embedding
+
class TrackedMultiEmbeddings(Embedding):
    """Multimodal embedding wrapper whose calls are traced via Langfuse.

    Exposes LangChain-style ``embed_query`` / ``embed_documents`` names (plus
    image and multimodal variants) on top of the project's :class:`Embedding`
    base class, with each call wrapped by the Langfuse tracing decorator.
    """

    @langfuse_trace_embedding(name="multi-embed-query")
    def embed_query(self, text: str) -> list[float]:
        """Embed a single text string; delegates to get_text_embedding."""
        return super().get_text_embedding(text)

    @langfuse_trace_embedding(name="multi-embed-documents")
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts; delegates to get_texts_embedding."""
        return super().get_texts_embedding(texts)

    # BUGFIX: annotate with Image.Image (the class) — bare ``Image`` is the
    # PIL *module*, not a type, and breaks runtime annotation inspection.
    @langfuse_trace_embedding(name="multi-embed-image")
    def embed_image(self, image: Image.Image) -> list[float]:
        """Embed a PIL image; delegates to get_image_embedding."""
        return super().get_image_embedding(image)

    @langfuse_trace_embedding(name="multi-embed-multimodal")
    def embed_multimodal(self, text: str, image: Image.Image) -> list[float]:
        """Embed a text + image pair; delegates to get_multimodal_embedding."""
        return super().get_multimodal_embedding(text, image)

+ 12 - 0
model/tracked_openai_embeddings.py

@@ -0,0 +1,12 @@
+from langchain_openai import OpenAIEmbeddings
+from utils.decorators.langfuse_trace_embedding import langfuse_trace_embedding
+
class TrackedOpenAIEmbeddings(OpenAIEmbeddings):
    """OpenAIEmbeddings subclass whose embedding calls are traced through the
    project's Langfuse decorator; behavior is otherwise identical to the
    parent class."""

    @langfuse_trace_embedding(name="openai-embed-query")
    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string, recording the call in Langfuse."""
        return super().embed_query(text)

    @langfuse_trace_embedding(name="openai-embed-documents")
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents, recording the call in Langfuse."""
        return super().embed_documents(texts)

+ 0 - 0
parser/image_parser/__init__.py


+ 2 - 2
workflow/image_parsing_workflow.py → parser/image_parser/image_parser_workflow.py

@@ -19,7 +19,7 @@ from model.qwen_vl import QWenVLParser
 from utils.ragflow.ragflow_service import RAGFlowService
 from utils.ragflow.ragflow_service import RAGFlowService
 from model.multimodal_embedding import Embedding
 from model.multimodal_embedding import Embedding
 from utils.minio.image_util import image_util
 from utils.minio.image_util import image_util
-from conf.config import ModelConfig
+from conf.settings import model_settings
 from utils.infinity import get_client
 from utils.infinity import get_client
 
 
 # 定义工作流状态类
 # 定义工作流状态类
@@ -298,7 +298,7 @@ class ImageParsingWorkflow:
             zip_file_path=zip_file_path,
             zip_file_path=zip_file_path,
             book_name=book_name,
             book_name=book_name,
             dataset_id=dataset_id,
             dataset_id=dataset_id,
-            embedding_model=Embedding(model_name=ModelConfig.get_multimodal_embedding_model_name(), api_key=ModelConfig.get_dashscope_api_key()),
+            embedding_model=Embedding(model_name=model_settings.multimodal_embedding_model_name, api_key=model_settings.dashscope_api_key),
             ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
             ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
         )
         )
         result = self.workflow.invoke(initial_state)
         result = self.workflow.invoke(initial_state)

+ 82 - 14
workflow/workflow.py → parser/pdf_parser/pdf_parser_workflow.py

@@ -1,19 +1,15 @@
-import sys
 import os
 import os
 import concurrent.futures
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
-# 添加项目根目录到Python路径
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 from langgraph.graph import StateGraph, START, END
 from langgraph.graph import StateGraph, START, END
-from langgraph.graph.message import add_messages
 from typing import List, Dict, Any
 from typing import List, Dict, Any
 from pydantic import BaseModel, Field, ConfigDict
 from pydantic import BaseModel, Field, ConfigDict
 from parser.pdf_parser.pdf_splitter import PDFSplitter
 from parser.pdf_parser.pdf_splitter import PDFSplitter
 from model.qwen_vl import QWenVLParser
 from model.qwen_vl import QWenVLParser
+from utils.ragflow_sdk import DataSetUtil, DocumentUtil, ChunkUtil
 from utils.ragflow.ragflow_service import RAGFlowService
 from utils.ragflow.ragflow_service import RAGFlowService
 from model.multimodal_embedding import Embedding
 from model.multimodal_embedding import Embedding
-from conf.config import ModelConfig, VectorDBConfig
+from conf.settings import model_settings, vector_db_settings
 from utils.infinity import get_client
 from utils.infinity import get_client
 
 
 # 定义工作流状态类
 # 定义工作流状态类
@@ -22,7 +18,11 @@ class PDFParsingState(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     pdf_path: str = Field(..., description="PDF文件路径")
     pdf_path: str = Field(..., description="PDF文件路径")
     dataset_id: str = Field(..., description="数据集ID")
     dataset_id: str = Field(..., description="数据集ID")
-    ragflow_service: RAGFlowService = Field(default_factory=RAGFlowService, description="RAGFLOW服务")
+    page_dataset_id: str = Field(..., description="页面数据集ID")
+    ragflow_service: RAGFlowService = Field(default_factory=RAGFlowService, description="RAGFlow服务实例")
+    dataset_util: DataSetUtil = Field(default_factory=DataSetUtil, description="数据集工具类实例")
+    document_util: DocumentUtil = Field(default_factory=DocumentUtil, description="文档工具类实例")
+    chunk_util: ChunkUtil = Field(default_factory=ChunkUtil, description="文档工具类实例")
     embedding_model: Embedding = Field(default_factory=Embedding, description="多模态嵌入模型实例")
     embedding_model: Embedding = Field(default_factory=Embedding, description="多模态嵌入模型实例")
     document_id: str = Field(default="", description="上传后的文档ID")
     document_id: str = Field(default="", description="上传后的文档ID")
     split_pages: List[Dict[str, Any]] = Field(default_factory=list, description="拆分后的页面列表")
     split_pages: List[Dict[str, Any]] = Field(default_factory=list, description="拆分后的页面列表")
@@ -46,6 +46,7 @@ class PDFParsingWorkflow:
         """
         """
         self.model_name = model_name
         self.model_name = model_name
         self.workflow = self._build_workflow()
         self.workflow = self._build_workflow()
+        
     
     
     def _build_workflow(self):
     def _build_workflow(self):
         """构建langgraph工作流,实现基于条件路由的并行处理"""
         """构建langgraph工作流,实现基于条件路由的并行处理"""
@@ -97,6 +98,37 @@ class PDFParsingWorkflow:
         # 编译工作流
         # 编译工作流
         return graph.compile()
         return graph.compile()
     
     
    def get_ragflow_dataset(self, dataset_name: str) -> str:
        """Look up a RAGFlow dataset ID by name.

        Args:
            dataset_name: Name of the dataset to look up.

        Returns:
            str: The dataset's ID.

        Raises:
            Exception: Re-raises whatever the dataset util raises on failure.
        """
        try:
            # NOTE(review): ``self.dataset_util`` is never assigned in this
            # class's __init__ (a ``dataset_util`` exists only on
            # PDFParsingState) — as written this raises AttributeError.
            # Confirm where dataset_util should live.
            dataset_id = self.dataset_util.get_dataset(name=dataset_name)
            print(f"数据集 {dataset_name} 的ID为: {dataset_id}")
            return dataset_id
        except Exception as e:
            print(f"获取数据集ID时出错: {str(e)}")
            raise
+
    def create_ragflow_dataset(self, state: PDFParsingState, dataset_name: str) -> str:
        """Create a RAGFlow dataset, reusing ``state.dataset_id`` if present.

        Args:
            state: Workflow state; if it already carries a dataset_id, the
                existing dataset is reused and no API call is made.
            dataset_name: Name for the dataset to create.

        Returns:
            str: The (existing or newly created) dataset ID.

        Raises:
            Exception: Re-raises whatever the dataset util raises on failure.
        """
        if state.dataset_id:
            print(f"数据集 {dataset_name} 已存在,数据集ID: {state.dataset_id}")
            return state.dataset_id
        
        print(f"开始创建数据集: {dataset_name}")
        
        try: 
            # NOTE(review): ``self.dataset_util`` is never assigned in
            # __init__ (only PDFParsingState has one) — AttributeError as
            # written; confirm the intended source of dataset_util.
            dataset_id = self.dataset_util.create_dataset(
                chunk_method="naive",
                dataset_name=dataset_name,
                dataset_desc="",
            )
            print(f"数据集创建成功,数据集ID: {dataset_id}")
            return dataset_id
        except Exception as e:
            print(f"创建数据集时出错: {str(e)}")
            raise
+
     def _upload_document_node(self, state: PDFParsingState) -> Dict[str, Any]:
     def _upload_document_node(self, state: PDFParsingState) -> Dict[str, Any]:
         """RAGFLOW上传文档节点"""
         """RAGFLOW上传文档节点"""
         print(f"开始上传文档到数据集 {state.dataset_id}: {state.pdf_path}")
         print(f"开始上传文档到数据集 {state.dataset_id}: {state.pdf_path}")
@@ -107,13 +139,20 @@ class PDFParsingWorkflow:
                 dataset_id=state.dataset_id,
                 dataset_id=state.dataset_id,
                 file_path=state.pdf_path
                 file_path=state.pdf_path
             )
             )
+            # 上传文档
+            document_info_list2 = state.ragflow_service.upload_document(
+                dataset_id=state.page_dataset_id,
+                file_path=state.pdf_path
+            )
             
             
             # 检查响应
             # 检查响应
             if document_info_list and len(document_info_list) > 0:
             if document_info_list and len(document_info_list) > 0:
                 document_id = document_info_list[0]["id"]
                 document_id = document_info_list[0]["id"]
+                page_document_id = document_info_list2[0]["id"]
                 print(f"文档上传成功,文档ID: {document_id}")
                 print(f"文档上传成功,文档ID: {document_id}")
                 return {
                 return {
-                    "document_id": document_id
+                    "document_id": document_id,
+                    "page_document_id": page_document_id
                 }
                 }
             else:
             else:
                 print("文档上传失败: 未返回有效的文档信息")
                 print("文档上传失败: 未返回有效的文档信息")
@@ -278,12 +317,38 @@ class PDFParsingWorkflow:
         # 所以这里总是返回"complete"
         # 所以这里总是返回"complete"
         return "complete"
         return "complete"
     
     
    def create_ragflow_chunk(self, state: PDFParsingState):
        """Upload each parsed page as an individual RAGFlow chunk and tag it
        in Infinity.

        NOTE(review): this method cannot run as written — it references
        several undefined names (see inline notes). It needs a fix before the
        workflow wires it in.
        """
        print(f"开始单页上传,共 {len(state.parsed_results)} 页")
        
        # Upload every parsed page as its own chunk.
        for parsed_result in state.parsed_results:
            page_number = parsed_result.get("page_number")
            text = parsed_result.get("content", "")
            # NOTE(review): ``image`` is assigned but never used below —
            # dead code or a missing image-upload step?
            image = state.split_pages[page_number - 1].get("image")
            
            # NOTE(review): PDFParsingState declares no ``dataset_name``
            # field, and ``page_document_id`` is only produced as a dict key
            # by _upload_document_node — confirm both exist on the state
            # model before this runs.
            chunk = state.chunk_util.add_chunk(
                dataset_name=state.dataset_name,
                document_id=state.page_document_id,
                content=text,
            )

            infinity_client = get_client()
            # NOTE(review): ``chunk_id`` and ``tag_name`` are undefined here
            # (NameError at runtime); presumably chunk_id should come from
            # ``chunk``. The empty table_name also looks like a placeholder.
            infinity_client.update(database_name=state.dataset_name, table_name="", cond=f"id = {chunk_id}", data={"tag_kwd": tag_name})
            
            # NOTE(review): ``document_info`` is undefined (NameError) —
            # probably meant to inspect ``chunk`` instead.
            if document_info and document_info.get("id"):
                print(f"第 {page_number} 页上传成功,文档ID: {document_info['id']}")
            else:
                print(f"第 {page_number} 页上传失败")
+
     def _vectorize_store_node(self, state: PDFParsingState) -> Dict[str, Any]:
     def _vectorize_store_node(self, state: PDFParsingState) -> Dict[str, Any]:
         """向量化入库节点"""
         """向量化入库节点"""
         print(f"开始向量化入库,共 {len(state.parsed_results)} 页")
         print(f"开始向量化入库,共 {len(state.parsed_results)} 页")
         
         
         # 创建索引(如果不存在)
         # 创建索引(如果不存在)
-        index_name = f"{VectorDBConfig.get_infinity_table_name()}"
+        index_name = f"{vector_db_settings.infinity_table_name}"
         state.vector_db.create_index(index_name)
         state.vector_db.create_index(index_name)
         
         
         # 准备要入库的文档列表
         # 准备要入库的文档列表
@@ -349,13 +414,13 @@ class PDFParsingWorkflow:
             "is_complete": True
             "is_complete": True
         }
         }
     
     
-    def run(self, pdf_path: str, dataset_id: str, ragflow_api_url: str, rag_flow_api_key: str) -> Dict[str, Any]:
+    def run(self, pdf_path: str, page_dataset_id: str, ragflow_api_url: str, rag_flow_api_key: str) -> Dict[str, Any]:
         """
         """
         运行PDF解析工作流
         运行PDF解析工作流
         
         
         Args:
         Args:
             pdf_path: PDF文件路径
             pdf_path: PDF文件路径
-            dataset_id: 数据集ID
+            page_dataset_id: 数据集ID
             ragflow_api_url: RAGFLOW API URL
             ragflow_api_url: RAGFLOW API URL
             rag_flow_api_key: RAGFLOW API密钥
             rag_flow_api_key: RAGFLOW API密钥
             
             
@@ -364,9 +429,12 @@ class PDFParsingWorkflow:
         """
         """
         initial_state = PDFParsingState(
         initial_state = PDFParsingState(
             pdf_path=pdf_path,
             pdf_path=pdf_path,
-            dataset_id=dataset_id,
-            embedding_model=Embedding(model_name=ModelConfig.get_multimodal_embedding_model_name(), api_key=ModelConfig.get_dashscope_api_key()),
-            ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
+            page_dataset_id=page_dataset_id,
+            embedding_model=Embedding(model_name=model_settings.multimodal_embedding_model_name, api_key=model_settings.dashscope_api_key),
+            dataset_util=DataSetUtil(),
+            document_util=DocumentUtil(),
+            chunk_util=ChunkUtil(),
+            ragflow_service=RAGFlowService(api_url=ragflow_api_url, api_key=rag_flow_api_key)
         )
         )
         result = self.workflow.invoke(initial_state)
         result = self.workflow.invoke(initial_state)
         
         

+ 27 - 0
private_key.pem

@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAk737JTysaX2ZnsMtvU8uOB8sSTqoL4/z7hcO01IwYDFTZPUO
+UeBcTbwsIxoyh3IQReW/SL9kay3ZeZzg3xRZYaiKkMVtbnu1ijt3/t4/K+rmIBs/
+8Lg9vpJGVyY43rPlaJP0IlrttnGAOiurDSnSM3iJN3BiZsdBz5fMgEShp6PQpWpJ
+ggOjK+2wuMlI1GP2nrJs/9/FN5yIcoqbJtxOkXiHTRLvn/6V0iwO0KR8n5WWUseA
+ziVtIEDTZxIGkYyfK1DrRP3WZYmXS3gDbV5wqtNDLhnrbYvarZA6ntC1aAa2S8YA
+Ey3dgGee66547mF8nML7qWzxcVrLpRGYdcz0wQIDAQABAoIBACknHyC1MNVFjhfi
+CAiWWKp7TGutchGliG3GcHO+jgkNVExn0jBOTYe9pZG1h2OhCxVdNJX/8enE4LH2
+V9DrqyX/u3Et/RhsTxoNDJSqtW/3Q1ioHnxf5ZsCxaMa23Z1J66kn7k+GL7cCLqr
+VT8aahm6NoaBwHceg59eijEkbN7z3vEdpcoM7Y/T9vsIiRxof7XbIkoMAR14ZZsF
+51ht/Uw6lcY+ZBil4m+PSGkTF00WrTYYySVpLNeOg7gHOZhCvkegu5ZjAgJf4468
++eG6OW57UP+Fvx1cy1YY+2rO1FiNclmcNKEj6m8kvbcsGDI84NzXx2Z5UNMehTZm
+TsaJBG0CgYEAzgpzsPK/NyZmLwmoty9h+7cIGsVXGDbJBSbHDKLbt1B/QEtUGvVZ
+MdilQBWmYPw74F+dyl/pOiaXppEwwCej/5aIQVc2LruekOj2nlC6aAE+vI1fFJwn
+MPWqHgQKv3Q0oC/0jOA4WfDAuPMAlWgCmPOHX+dpn60Ehd2Gg9jzdGsCgYEAt5DC
+hx5ag7X3YAqm16o63ztDDLmyA9SWuD41wmXqxiVANV4aLOb5xOKhTGePN+vL4uYF
+3XD+SBWeaksIlp8jKgQ6lo+E7ODvFCWp7anUE4joqRYDLX76YfYaV55ftn2QG13g
+FlSJrcHW/LDW2K4VljZ/RMe9EjvwZpwXTLHxpoMCgYBljPc6dlpSZ0Wj12Hlb6Ae
+fTasYykrEQ2A0Xz3uGkMb8tYOAhrBRQCImJINKN9V7R3QQk3TpBYDd9exF6/xonp
+FhtwAPj+jZB4oamzm9xUGJEAQ7liIhia4+oFwz484Jxze09QA88iXSvPvE1EVjA5
+jIXjed4wSglZpTaAU3SiowKBgQCq6eH1pNFqKRc3GTPXQeymxKrdQlj5MAtVaLbt
+Va+O5kguaHu+qyCLrOe6GAJFcwrXLIzxba+UtmxVU5+ffnRjPqS+YsrFM3nnLFpX
+qvSHa8YVCSXUSpIee+JzCkbbx3VRxcpLgFLgscbWaGDiRtRalEtTYxYr0Bs5rsQS
+6OuyuQKBgBtf5/D46K+gnzEHKxvrRNSWvzkWa9t89v/N6ndOhfTPanWfvwCeqEzQ
+nnWfYaQkB3vG54Rzv0pW7/UJnPxRLg6OuCaYBjYYXR7rzyBXuHotAH7MNf8nLiC1
+0ymLnAopyhJK0K17jzPh4vIMNFcjYMnVg7/NB825Ufj88VRow5TE
+-----END RSA PRIVATE KEY-----

+ 0 - 0
prompt/parser/page_parse_prompt.py


+ 9 - 0
public_key.pem

@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAk737JTysaX2ZnsMtvU8u
+OB8sSTqoL4/z7hcO01IwYDFTZPUOUeBcTbwsIxoyh3IQReW/SL9kay3ZeZzg3xRZ
+YaiKkMVtbnu1ijt3/t4/K+rmIBs/8Lg9vpJGVyY43rPlaJP0IlrttnGAOiurDSnS
+M3iJN3BiZsdBz5fMgEShp6PQpWpJggOjK+2wuMlI1GP2nrJs/9/FN5yIcoqbJtxO
+kXiHTRLvn/6V0iwO0KR8n5WWUseAziVtIEDTZxIGkYyfK1DrRP3WZYmXS3gDbV5w
+qtNDLhnrbYvarZA6ntC1aAa2S8YAEy3dgGee66547mF8nML7qWzxcVrLpRGYdcz0
+wQIDAQAB
+-----END PUBLIC KEY-----

+ 82 - 0
test/test_asymmetric_encryption.py

@@ -0,0 +1,82 @@
+from utils.asymmetric_encryption import AsymmetricEncryption
+import os
+
+
def test_asymmetric_encryption():
    """Exercise AsymmetricEncryption: key generation, round trip, file I/O.

    Keys are written to a temporary directory instead of the repository root,
    so real key material can never be committed by accident (the previous
    version wrote private_key.pem / public_key.pem into the working tree and
    left them there).
    """
    import tempfile  # local import: only needed by this test helper

    print("开始测试非对称加密功能...")
    
    # 1. Key-pair generation.
    print("\n1. 生成密钥对测试...")
    private_pem, public_pem = AsymmetricEncryption.generate_key_pair()
    print(f"私钥长度: {len(private_pem)} 字节")
    print(f"公钥长度: {len(public_pem)} 字节")
    print(f"私钥前100字节: {private_pem[:100]!r}")
    print(f"公钥前100字节: {public_pem[:100]!r}")
    
    # 2. Encrypt / decrypt round trip.
    print("\n2. 加密解密测试...")
    test_message = "这是一个测试消息,用于测试非对称加密功能!"
    print(f"原始消息: {test_message}")
    
    encrypted = AsymmetricEncryption.encrypt(test_message, public_pem)
    print(f"加密后: {encrypted}")
    
    decrypted = AsymmetricEncryption.decrypt(encrypted, private_pem)
    print(f"解密后: {decrypted}")
    
    # Verify the round trip restored the original message.
    assert decrypted == test_message, "解密失败,结果与原始消息不符!"
    print("✓ 加密解密测试通过!")
    
    # 3. Key file persistence — into a temp dir that is removed on exit.
    print("\n3. 密钥文件保存和加载测试...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        private_key_path = os.path.join(tmp_dir, "private_key.pem")
        public_key_path = os.path.join(tmp_dir, "public_key.pem")
        
        AsymmetricEncryption.save_key_to_file(private_pem, private_key_path)
        AsymmetricEncryption.save_key_to_file(public_pem, public_key_path)
        print(f"✓ 密钥已保存到文件: {private_key_path}, {public_key_path}")
        
        # Confirm both files actually exist before the directory is cleaned up.
        assert os.path.exists(private_key_path) and os.path.exists(public_key_path)
    
    # TODO(review): the commented-out load_key_from_file /
    # generate_key_pair_and_save round-trip tests were removed as dead code;
    # re-add real assertions for them once those APIs are confirmed stable.
    print("\n所有测试通过!非对称加密功能正常工作。")
+
+
+if __name__ == "__main__":
+    test_asymmetric_encryption()

+ 39 - 0
test/test_excel_util.py

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Smoke test for the ExcelUtil class.

NOTE(review): the actual ExcelUtil calls are commented out because they need a
real Excel file on disk; as written this only verifies that the import works.
"""

from utils.excel_util import ExcelUtil

def test_excel_util():
    """Run the ExcelUtil smoke test; returns True on success, False on failure."""
    print("测试 ExcelUtil 类的功能...")
    
    try:
        # 1. Sheet-name listing (needs a real .xlsx path to actually run).
        print("\n1. 测试获取工作表名称方法...")
        # sheet_names = ExcelUtil.get_sheet_names("test.xlsx")
        # print(f"工作表名称: {sheet_names}")
        print("✅ 工作表名称方法已实现")
        
        # 2. Excel parsing (needs a real .xlsx path to actually run).
        print("\n2. 测试解析 Excel 方法...")
        # data = ExcelUtil.parse_excel("test.xlsx")
        # print(f"解析结果数量: {len(data)}")
        # if data:
        #     print(f"第一条数据: {data[0]}")
        print("✅ 解析 Excel 方法已实现")
        
        print("\n🎉 所有测试通过!ExcelUtil 类功能正常。")
        return True
        
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    test_excel_util()

+ 70 - 0
test/test_mysql_config.py

@@ -0,0 +1,70 @@
"""Tests that the MySQL configuration is read from the .env file."""

from utils.mysql.mysql_pool import get_mysql_pool
from utils.mysql.mysql_conn import get_mysql_conn


def test_mysql_pool_from_env():
    """A pool built with no arguments should take its config from .env."""
    print("=== 测试从.env文件读取MySQL连接池配置 ===")
    
    # No explicit parameters: all values are expected to come from .env.
    pool = get_mysql_pool()
    print(f"连接池配置 - 主机: {pool.host}")
    print(f"连接池配置 - 端口: {pool.port}")
    print(f"连接池配置 - 用户名: {pool.user}")
    print(f"连接池配置 - 数据库: {pool.database}")
    print(f"连接池配置 - 字符集: {pool.charset}")
    print(f"连接池配置 - 连接池大小: {pool.pool_size}")
    
    print("\n✓ 从.env文件读取MySQL连接池配置成功!")


def test_mysql_conn_from_env():
    """A connection built with no arguments should take its config from .env."""
    print("\n=== 测试从.env文件读取MySQL连接配置 ===")
    
    # Default construction: connection parameters are expected from .env.
    conn = get_mysql_conn()
    
    # Run a trivial query to prove the connection is actually usable.
    try:
        # get_cursor() yields a context-managed cursor.
        with conn.get_cursor() as cursor:
            cursor.execute("SELECT 1 AS test")
            result = cursor.fetchone()
            print(f"执行简单查询结果: {result}")
        print("\n✓ 从.env文件读取MySQL连接配置成功!")
    except Exception as e:
        print(f"\n✗ 执行查询失败: {e}")
    finally:
        conn.close()


def test_mysql_conn_with_custom_params():
    """Explicit keyword arguments should override the .env defaults.

    NOTE(review): MySQLConnection is @singleton-decorated, so if an instance
    was already created by the previous test this call may return it unchanged
    and silently ignore database="test" — confirm this test really exercises
    the override path.
    """
    print("\n=== 测试自定义参数覆盖.env配置 ===")
    
    # Explicit database name should take precedence over .env.
    conn = get_mysql_conn(database="test")
    
    try:
        # get_cursor() yields a context-managed cursor.
        with conn.get_cursor() as cursor:
            cursor.execute("SELECT 1 AS test")
            result = cursor.fetchone()
            print(f"执行简单查询结果: {result}")
        print("\n✓ 自定义参数覆盖.env配置成功!")
    except Exception as e:
        print(f"\n✗ 执行查询失败: {e}")
    finally:
        conn.close()


if __name__ == "__main__":
    test_mysql_pool_from_env()
    test_mysql_conn_from_env()
    test_mysql_conn_with_custom_params()
    print("\n🎉 所有测试完成!")

+ 95 - 0
test/test_mysql_conn.py

@@ -0,0 +1,95 @@
"""End-to-end exercise of the MySQL connection manager: DDL, CRUD, batch insert, transaction."""
# Fixed import path: the module was moved into the utils.mysql package
# (utils/mysql_conn.py -> utils/mysql/mysql_conn.py) in this commit.
from utils.mysql.mysql_conn import get_mysql_conn


def test_mysql_connection():
    """Run a full CRUD + transaction smoke test against a local MySQL server."""
    print("开始测试MySQL连接池...")
    
    # Build the connection manager (singleton) with explicit parameters.
    conn_manager = get_mysql_conn(
        host="localhost",
        port=3306,
        user="root",
        password="password",
        database="test"
    )
    
    print("✓ MySQL连接管理器实例创建成功")
    
    try:
        # -- DDL: create a throwaway table --------------------------------
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS test_users (
            id INT AUTO_INCREMENT PRIMARY KEY,
            name VARCHAR(50) NOT NULL,
            email VARCHAR(100) NOT NULL,
            create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
        rowcount = conn_manager.execute(create_table_sql)
        print(f"✓ 创建测试表成功,受影响行数: {rowcount}")
        
        # -- single parameterized insert ----------------------------------
        insert_sql = "INSERT INTO test_users (name, email) VALUES (%s, %s)"
        params = ("测试用户", "test@example.com")
        rowcount = conn_manager.execute(insert_sql, params)
        print(f"✓ 插入测试数据成功,受影响行数: {rowcount}")
        
        # -- reads: fetch_one / fetch_all / fetch_many --------------------
        select_one_sql = "SELECT * FROM test_users ORDER BY id DESC LIMIT 1"
        result = conn_manager.fetch_one(select_one_sql)
        print(f"✓ 查询单条数据成功,结果: {result}")
        
        select_all_sql = "SELECT * FROM test_users"
        results = conn_manager.fetch_all(select_all_sql)
        print(f"✓ 查询所有数据成功,结果数量: {len(results)}")
        
        select_many_sql = "SELECT * FROM test_users"
        results = conn_manager.fetch_many(select_many_sql, size=2)
        print(f"✓ 查询多条数据成功,结果数量: {len(results)}")
        
        # -- batch insert --------------------------------------------------
        bulk_insert_sql = "INSERT INTO test_users (name, email) VALUES (%s, %s)"
        bulk_params = [
            ("批量用户1", "batch1@example.com"),
            ("批量用户2", "batch2@example.com"),
            ("批量用户3", "batch3@example.com")
        ]
        rowcount = conn_manager.bulk_insert(bulk_insert_sql, bulk_params)
        print(f"✓ 批量插入数据成功,受影响行数: {rowcount}")
        
        # -- transaction ---------------------------------------------------
        # Pre-bind conn/cursor so the except branch cannot raise NameError
        # when begin_transaction() itself fails (the old code referenced
        # both names unconditionally in the rollback path).
        conn = cursor = None
        try:
            conn, cursor = conn_manager.begin_transaction()
            cursor.execute("INSERT INTO test_users (name, email) VALUES (%s, %s)", ("事务用户", "transaction@example.com"))
            conn_manager.commit_transaction(conn, cursor)
            print("✓ 事务提交成功")
        except Exception as e:
            if conn is not None:
                conn_manager.rollback_transaction(conn, cursor)
            print(f"✗ 事务回滚: {e}")
        
        # -- cleanup: delete rows, then drop the table --------------------
        delete_sql = "DELETE FROM test_users"
        rowcount = conn_manager.execute(delete_sql)
        print(f"✓ 清理测试数据成功,受影响行数: {rowcount}")
        
        drop_table_sql = "DROP TABLE IF EXISTS test_users"
        rowcount = conn_manager.execute(drop_table_sql)
        print(f"✓ 删除测试表成功,受影响行数: {rowcount}")
        
        print("\n🎉 所有测试通过!MySQL连接池工作正常。")
        
    except Exception as e:
        print(f"\n✗ 测试失败: {e}")
    finally:
        # Always release the pool, even when the test body failed.
        conn_manager.close()
        print("✓ 连接池关闭成功")


if __name__ == "__main__":
    test_mysql_connection()

+ 14 - 0
test/test_ragflow_sdk.py

@@ -0,0 +1,14 @@
"""Manual smoke test: fetch tag data through TagService."""
from api.db.services.tag_service import TagService

if __name__ == "__main__":
    # Pull the full tag dataset and dump it for visual inspection.
    tag_service = TagService()
    tags_data = tag_service.get_tag_data()
    print(tags_data)
    
    # Commented-out multimodal embedding experiment, kept for reference:
    # multi_embeddings = Embedding(
    #     model_name=ModelConfig.get_multimodal_embedding_model_name(),
    #     api_key=ModelConfig.get_dashscope_api_key()
    # )

    # res = multi_embeddings.get_multimodal_embedding("这是一个测试文本", Image.open(r"C:\Users\EDY\Desktop\粉宝.png"))
    # print(res)

+ 85 - 0
test/test_search_tag.py

@@ -0,0 +1,85 @@
"""Manual tag-search test: embed a query, vector-search Infinity, then rerank with Jina."""
from langchain.chat_models import init_chat_model
from model.multimodal_embedding import Embedding
from conf.settings import model_settings, ragflow_settings, tag_search_settings
from utils.infinity import get_client
from model.jina_rerank import JinaRerank
from langchain_core.documents import Document  # NOTE(review): unused in this file
from utils.infinity.result_util import convert_to_langchain_docs

# Multimodal embedding model (module-level side effect: built at import time).
embedding_model = Embedding(
    model_name=model_settings.embedding_model_name,
    api_key=model_settings.api_key
)

# Chat model, currently unused by main()/main2() below.
llm = init_chat_model(
    model_provider=model_settings.model_provider,
    model=model_settings.model_name,
    base_url=model_settings.base_url,
    api_key=model_settings.api_key,
    temperature=0.3
)

compressor = JinaRerank(
    base_url=model_settings.base_url,
    jina_api_key=model_settings.api_key,
    model=model_settings.rank_model_name,  # TODO: confirm this model ID exists in the provider's console
    top_n=5
)

# Fields to pull back from the Infinity table for each hit.
output_fields = [
                    "docnm",
                    "tag_kwd",
                    "content",
                    "kb_id",
                    "doc_id"
                ]

# query_vector is filled in by main() after embedding matching_text.
search_query = {
    "matching_text": """
    桑尼最初因专注画画,对朋友求助表现出不耐烦并大声吼叫,后来意识到是小米生日,主动平复情绪转变态度帮助朋友,完成情绪调节
    """,
    "query_vector": [],
    "vector_field": "q_1024_vec",
    "match_field": "content",
    "topn": 10
}

def main():
    """Vector-search the tag table, rerank the hits, and print the ranking."""
    infinity_client = get_client(database=tag_search_settings.tag_db_name)
    # 1. Embed the query text and attach it to the search payload.
    embedding = embedding_model.get_text_embedding(search_query["matching_text"])
    search_query["query_vector"] = embedding
    # Hard-coded table name; TODO(review): derive from settings instead.
    # TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_52275b36f03611f0a5340242c0a85002"
    TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002"
    results = infinity_client.vector_search(TAG_TABLE_NAME, output_fields, search_query).to_result()
    candidate_docs = convert_to_langchain_docs(results)
    
    # 2. Rerank the candidates with Jina.
    reranked_docs = compressor.compress_documents(
        documents=candidate_docs, 
        query=search_query["matching_text"],
        top_n=5
    )
    print(reranked_docs)
    # 3. Print score / content / tags for each reranked hit.
    for i, doc in enumerate(reranked_docs):
        print(f"排名 {i+1}: 分数 {doc.metadata['relevance_score']:.4f}")
        print(f"内容: {doc.page_content}")
        print(f"标签: {doc.metadata['tag_kwd']}\n")

from utils.asymmetric_encryption import AsymmetricEncryption

def main2():
    """Decrypt the RAGFlow credential from settings with the local private key.

    NOTE(review): reads an absolute Windows path, and private_key.pem appears
    to be committed to the repository in this commit — the key should be kept
    out of version control and the path made configurable. The commented line
    below also contains a real-looking credential; scrub (and rotate) it.
    """
    # passwd = "zhangqi@lelequ.net"
    # loaded_public_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\public_key.pem")
    loaded_private_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\private_key.pem")
    # encrypted = AsymmetricEncryption.encrypt(passwd, loaded_public_pem)
    # print(encrypted)
    # NOTE(review): the field is named ragflow_user_name but holds an encrypted value — confirm naming.
    decrypted = AsymmetricEncryption.decrypt(ragflow_settings.ragflow_user_name, loaded_private_pem)
    print(decrypted)
    # assert decrypted2 == test_message, "使用加载的密钥解密失败!"
    # print("✓ 使用加载的密钥加密解密测试通过!")

if __name__ == "__main__":
    main2()

+ 25 - 0
test_upload_tag.py

@@ -0,0 +1,25 @@
"""Manual test: parse a tag spreadsheet and bulk-upload it through TagService."""
from api.db.services.tag_service import TagServiceFactory
from utils.excel_util import excel_util
from utils.infinity import get_client

def main():
    """Parse the local Excel file and upload its tag rows via the tag service."""
    infinity_client = get_client()
    # Build the tag service through its factory.
    tag_service = TagServiceFactory.create_tag_service(infinity_client)
    # Parse the spreadsheet with the shared ExcelUtil singleton.
    # NOTE(review): hard-coded absolute Windows path — make this configurable.
    tags_data = excel_util.parse_excel(file_path=r"D:\project\work\data_deal\测试标签.xlsx")
    # Bulk-upload the parsed rows.
    result = tag_service.upload_tags(tags_data)
    print(f"上传结果: {result}")

# def get_chunk():
#     infinity_client = get_client()
#     table_object = infinity_client.get_table("ragflow_92162247e93e11f084830242ac1d0002_18caf531f04d11f095670242c0a85002", "default_db")
#     res = table_object.output(["*"]).to_pl()
#     print(res)



if __name__ == "__main__":
    main()

+ 151 - 0
utils/asymmetric_encryption.py

@@ -0,0 +1,151 @@
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import rsa, padding
+from cryptography.hazmat.backends import default_backend
+from typing import Tuple, Optional
+import base64
+
+
class AsymmetricEncryption:
    """RSA asymmetric-encryption helper (PKCS#1 PEM keys, OAEP/SHA-256 padding)."""

    @staticmethod
    def generate_key_pair(key_size: int = 2048) -> Tuple[bytes, bytes]:
        """
        Generate an RSA key pair.

        Args:
            key_size: key size in bits, 2048 by default

        Returns:
            Tuple[bytes, bytes]: (private key PEM, public key PEM)

        Note:
            The private key is serialized with NoEncryption — the PEM is NOT
            passphrase-protected, so keep it out of version control.
        """
        # Generate the private key.
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend()
        )
        
        # Derive the matching public key.
        public_key = private_key.public_key()
        
        # Serialize the private key to (unencrypted) PEM.
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption()
        )
        
        # Serialize the public key to PEM (SubjectPublicKeyInfo).
        public_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        
        return private_pem, public_pem
    
    @staticmethod
    def encrypt(message: str, public_key_pem: bytes) -> str:
        """
        Encrypt a message with the public key.

        Args:
            message: plaintext to encrypt (UTF-8 encoded internally)
            public_key_pem: public key in PEM format

        Returns:
            str: base64-encoded ciphertext

        Note:
            OAEP with SHA-256 limits the plaintext to key_size/8 - 66 bytes
            (190 bytes for a 2048-bit key); longer input raises ValueError.
        """
        # Load the public key.
        public_key = serialization.load_pem_public_key(
            public_key_pem,
            backend=default_backend()
        )
        
        # Encrypt with OAEP(SHA-256) padding.
        encrypted = public_key.encrypt(
            message.encode('utf-8'),
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        
        # Base64-encode so the ciphertext is text-safe.
        return base64.b64encode(encrypted).decode('utf-8')
    
    @staticmethod
    def decrypt(encrypted_message: str, private_key_pem: bytes) -> str:
        """
        Decrypt a message with the private key.

        Args:
            encrypted_message: base64-encoded ciphertext produced by encrypt()
            private_key_pem: private key in PEM format

        Returns:
            str: decrypted plaintext (UTF-8)
        """
        # Load the private key (expects an unencrypted PEM; password=None).
        private_key = serialization.load_pem_private_key(
            private_key_pem,
            password=None,
            backend=default_backend()
        )
        
        # Undo the base64 transport encoding.
        encrypted = base64.b64decode(encrypted_message)
        
        # Decrypt with the same OAEP(SHA-256) padding used by encrypt().
        decrypted = private_key.decrypt(
            encrypted,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        
        # Decode back to str.
        return decrypted.decode('utf-8')
    
    @staticmethod
    def save_key_to_file(key_pem: bytes, file_path: str) -> None:
        """
        Write a PEM key to a file (binary mode).

        Args:
            key_pem: key in PEM format
            file_path: destination file path
        """
        with open(file_path, 'wb') as f:
            f.write(key_pem)
    
    @staticmethod
    def load_key_from_file(file_path: str) -> bytes:
        """
        Read a PEM key from a file.

        Args:
            file_path: source file path

        Returns:
            bytes: key in PEM format
        """
        with open(file_path, 'rb') as f:
            return f.read()
    
    @staticmethod
    def generate_key_pair_and_save(private_key_path: str, public_key_path: str, key_size: int = 2048) -> None:
        """
        Generate a key pair and write both halves to disk.

        Args:
            private_key_path: where to write the private key PEM
            public_key_path: where to write the public key PEM
            key_size: key size in bits, 2048 by default
        """
        private_pem, public_pem = AsymmetricEncryption.generate_key_pair(key_size)
        AsymmetricEncryption.save_key_to_file(private_pem, private_key_path)
        AsymmetricEncryption.save_key_to_file(public_pem, public_key_path)

+ 45 - 0
utils/decorators/langfuse_trace_embedding.py

@@ -0,0 +1,45 @@
+import functools
+import time
+from langfuse import Langfuse
+
+# 初始化 Langfuse 客户端
+langfuse = Langfuse()
+
def langfuse_trace_embedding(name="embedding-operation"):
    """
    Decorator that records an embedding call as a Langfuse span.

    Captures the input text(s), wraps the call, and reports vector count and
    dimension on success, or an ERROR-level span on failure.

    Args:
        name: span name shown in Langfuse.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Input extraction: for bound methods (embed_query(self, text) /
            # embed_documents(self, texts)) the payload is args[1]; otherwise
            # fall back to the conventional keyword names.
            if len(args) > 1:
                input_data = args[1]
            else:
                input_data = kwargs.get("text") or kwargs.get("texts")

            # Open a span; it attaches to the current trace if one exists,
            # otherwise Langfuse creates a new trace for it.
            span = langfuse.span(
                name=name,
                input=input_data,
                metadata={"function": func.__name__}
            )

            try:
                # Execute the real embedding operation.
                result = func(*args, **kwargs)

                # Distinguish a batch result (list of vectors) from a single
                # flat vector. The previous expression called len() on a
                # float (result[0]) for embed_query-style flat vectors and
                # raised TypeError, and reported the dimension as the count.
                if isinstance(result, list) and result and isinstance(result[0], (list, tuple)):
                    vector_count, dimension = len(result), len(result[0])
                elif isinstance(result, list):
                    vector_count, dimension = 1, len(result)
                else:
                    vector_count, dimension = 1, 0

                span.end(output={"vector_count": vector_count, "dimension": dimension})
                return result
            except Exception as e:
                # Record the failure on the span, then re-raise unchanged.
                span.end(level="ERROR", status_message=str(e))
                raise
        return wrapper
    return decorator

+ 0 - 0
utils/decorators.py → utils/decorators/singleton.py


+ 221 - 0
utils/excel_util.py

@@ -0,0 +1,221 @@
+"""
+Excel 工具类
+
+该文件提供 Excel 解析功能,支持:
+- 解析 .xlsx 和 .xls 格式的 Excel 文件
+- 将 Excel 数据转换为 List[Dict[str, Any]] 格式
+- 支持指定工作表
+- 支持自定义表头行
+"""
+
+from typing import List, Dict, Any, Optional
+import os
+
+
+try:
+    from openpyxl import load_workbook
+    from openpyxl.worksheet.worksheet import Worksheet
+except ImportError:
+    raise ImportError("openpyxl 库未安装,请使用 pip install openpyxl 安装")
+
+try:
+    import xlrd
+    from xlrd.sheet import Sheet
+except ImportError:
+    raise ImportError("xlrd 库未安装,请使用 pip install xlrd 安装")
+
+
class ExcelUtil:
    """
    Excel utility: parse .xlsx / .xls workbooks into List[Dict[str, Any]].

    Public entry points are parse_excel() and get_sheet_names(); the two
    per-format readers share one header/row-to-dict conversion helper.
    """

    @staticmethod
    def _rows_to_dicts(rows: List[Any], header_row: int,
                       treat_empty_str_as_blank: bool) -> List[Dict[str, Any]]:
        """
        Convert raw worksheet rows into a list of {header: value} dicts.

        Shared by _parse_xlsx and _parse_xls (the logic was previously
        duplicated in both).

        Args:
            rows: all rows of the sheet, each a sequence of cell values
            header_row: index of the header row (0-based)
            treat_empty_str_as_blank: if True, '' also counts as an empty cell
                when deciding whether a row is blank (xlrd returns '' for
                empty cells, openpyxl returns None)

        Returns:
            List[Dict[str, Any]]: one dict per non-blank data row

        Raises:
            ValueError: if header_row is outside the sheet
        """
        if not rows:
            return []
        if header_row >= len(rows):
            raise ValueError(f"表头行索引 {header_row} 超出数据总行数 {len(rows)}")
        headers = rows[header_row]
        if not headers:
            return []

        def is_blank(cell: Any) -> bool:
            # What counts as "no value" differs between the two backends.
            return cell is None or (treat_empty_str_as_blank and cell == '')

        data: List[Dict[str, Any]] = []
        for row in rows[header_row + 1:]:
            if not row or all(is_blank(cell) for cell in row):
                continue
            # Pad short rows with None so every header gets a key.
            data.append({
                header: (row[i] if i < len(row) else None)
                for i, header in enumerate(headers)
            })
        return data

    @staticmethod
    def _parse_xlsx(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """
        Parse an .xlsx file (openpyxl backend).

        Args:
            file_path: path to the workbook
            sheet_name: sheet to read; None means the active sheet
            header_row: index of the header row (0-based)

        Returns:
            List[Dict[str, Any]]: parsed rows
        """
        workbook = load_workbook(filename=file_path, read_only=True)
        try:
            sheet: Worksheet = workbook[sheet_name] if sheet_name else workbook.active
            rows = list(sheet.iter_rows(values_only=True))
        finally:
            # Close even when reading fails (the old code leaked the handle on error).
            workbook.close()
        return ExcelUtil._rows_to_dicts(rows, header_row, treat_empty_str_as_blank=False)

    @staticmethod
    def _parse_xls(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """
        Parse a legacy .xls file (xlrd backend).

        Args:
            file_path: path to the workbook
            sheet_name: sheet to read; None means the first sheet
            header_row: index of the header row (0-based)

        Returns:
            List[Dict[str, Any]]: parsed rows
        """
        workbook = xlrd.open_workbook(file_path)
        sheet: Sheet = workbook.sheet_by_name(sheet_name) if sheet_name else workbook.sheet_by_index(0)
        rows = [sheet.row_values(i) for i in range(sheet.nrows)]
        # xlrd reports empty cells as '', so '' must count as blank here.
        return ExcelUtil._rows_to_dicts(rows, header_row, treat_empty_str_as_blank=True)

    @staticmethod
    def parse_excel(file_path: str, sheet_name: Optional[str] = None, header_row: int = 0) -> List[Dict[str, Any]]:
        """
        Parse an Excel file; dispatches on extension (.xlsx or .xls).

        Args:
            file_path: path to the workbook
            sheet_name: sheet to read; None means the first/active sheet
            header_row: index of the header row (0-based)

        Returns:
            List[Dict[str, Any]]: parsed rows

        Raises:
            FileNotFoundError: if file_path does not exist
            ValueError: for unsupported extensions
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件 {file_path} 不存在")

        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.xlsx':
            return ExcelUtil._parse_xlsx(file_path, sheet_name, header_row)
        elif file_ext == '.xls':
            return ExcelUtil._parse_xls(file_path, sheet_name, header_row)
        raise ValueError(f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")

    @staticmethod
    def get_sheet_names(file_path: str) -> List[str]:
        """
        List the sheet names of an Excel file (.xlsx or .xls).

        Args:
            file_path: path to the workbook

        Returns:
            List[str]: sheet names in workbook order

        Raises:
            FileNotFoundError: if file_path does not exist
            ValueError: for unsupported extensions
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件 {file_path} 不存在")

        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.xlsx':
            workbook = load_workbook(filename=file_path, read_only=True)
            try:
                return workbook.sheetnames
            finally:
                workbook.close()
        elif file_ext == '.xls':
            return xlrd.open_workbook(file_path).sheet_names()
        raise ValueError(f"不支持的文件格式 {file_ext},只支持 .xlsx 和 .xls 格式")


# Module-level singleton instance
excel_util = ExcelUtil()

+ 5 - 5
utils/file/minio/minio_util.py

@@ -1,18 +1,18 @@
 from minio import Minio
 from minio import Minio
 from typing import BinaryIO
 from typing import BinaryIO
 from datetime import timedelta
 from datetime import timedelta
-from conf.config import MinioConfig
+from conf.settings import minio_settings
 from utils.file.file_utils import generate_unique_filename
 from utils.file.file_utils import generate_unique_filename
 
 
 class MinIOUtil:
 class MinIOUtil:
     def __init__(self):
     def __init__(self):
         self.client = Minio(
         self.client = Minio(
-            endpoint=MinioConfig.get_minio_endpoint(),
-            access_key=MinioConfig.get_minio_access_key(),
-            secret_key=MinioConfig.get_minio_secret_key(),
+            endpoint=minio_settings.minio_endpoint,
+            access_key=minio_settings.minio_access_key,
+            secret_key=minio_settings.minio_secret_key,
             secure=False
             secure=False
         )
         )
-        self.bucket_name = MinioConfig.get_minio_bucket_name()
+        self.bucket_name = minio_settings.minio_bucket_name
         self._ensure_bucket_exists()
         self._ensure_bucket_exists()
 
 
     def _ensure_bucket_exists(self):
     def _ensure_bucket_exists(self):

+ 21 - 7
utils/infinity/client.py

@@ -3,7 +3,7 @@ from infinity.common import ConflictType
 from typing import Dict, Any, List, Optional
 from typing import Dict, Any, List, Optional
 import threading
 import threading
 from contextlib import contextmanager
 from contextlib import contextmanager
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 from .pool import InfinityConnectionPool
 from .pool import InfinityConnectionPool
 
 
 class InfinityClient:
 class InfinityClient:
@@ -19,9 +19,9 @@ class InfinityClient:
     
     
     def __init__(
     def __init__(
         self,
         self,
-        host: str = VectorDBConfig.get_infinity_host(),
-        port: str = VectorDBConfig.get_infinity_sdk_port(),
-        database: str = VectorDBConfig.get_infinity_database(),
+        host: str = vector_db_settings.infinity_host,
+        port: str = vector_db_settings.infinity_sdk_port,
+        database: str = vector_db_settings.infinity_database,
         min_connections: int = 2,
         min_connections: int = 2,
         max_connections: int = 10
         max_connections: int = 10
     ):
     ):
@@ -179,6 +179,20 @@ class InfinityClient:
             table = self._get_table(conn, table_name, database_name)
             table = self._get_table(conn, table_name, database_name)
             return table.insert(documents)
             return table.insert(documents)
     
     
+    def update(
+        self, 
+        table_name: str, 
+        cond: str, 
+        data: Dict[str, Any], 
+        database_name: Optional[str] = None
+    ):
+        """更新文档"""
+        with self.pool.get_connection() as conn:
+            # 获取Table对象
+            table = self._get_table(conn, table_name, database_name)
+            # 更新文档
+            return table.update(cond, data)
+
     def search(
     def search(
         self,
         self,
         table_name: str,
         table_name: str,
@@ -238,9 +252,9 @@ _global_client: Optional[InfinityClient] = None
 _client_lock = threading.Lock()
 _client_lock = threading.Lock()
 
 
 def get_client(
 def get_client(
-    host: str = VectorDBConfig.get_infinity_host(),
-    port: str = VectorDBConfig.get_infinity_sdk_port(),
-    database: str = VectorDBConfig.get_infinity_database(),
+    host: str = vector_db_settings.infinity_host,
+    port: str = vector_db_settings.infinity_sdk_port,
+    database: str = vector_db_settings.infinity_database,
     min_connections: int = 2,
     min_connections: int = 2,
     max_connections: int = 10
     max_connections: int = 10
 ) -> InfinityClient:
 ) -> InfinityClient:

+ 4 - 4
utils/infinity/pool.py

@@ -4,7 +4,7 @@ from typing import Dict, Any, List, Optional
 import threading
 import threading
 import time
 import time
 from contextlib import contextmanager
 from contextlib import contextmanager
-from conf.config import VectorDBConfig
+from conf.settings import vector_db_settings
 
 
 class InfinityConnectionPool:
 class InfinityConnectionPool:
     """
     """
@@ -20,9 +20,9 @@ class InfinityConnectionPool:
     
     
     def __init__(
     def __init__(
         self,
         self,
-        host: str = VectorDBConfig.get_infinity_host(),
-        port: str = VectorDBConfig.get_infinity_sdk_port(),
-        database: str = VectorDBConfig.get_infinity_database(),
+        host: str = vector_db_settings.infinity_host,
+        port: str = vector_db_settings.infinity_sdk_port,
+        database: str = vector_db_settings.infinity_database,
         min_connections: int = 2,
         min_connections: int = 2,
         max_connections: int = 10,
         max_connections: int = 10,
         connection_timeout: int = 30,
         connection_timeout: int = 30,

+ 78 - 0
utils/infinity/result_util.py

@@ -0,0 +1,78 @@
+from typing import Dict, Any, List
+from langchain_core.documents import Document
+
+
def convert_to_basic_types(obj: Any) -> Any:
    """
    Recursively convert *obj* into plain Python types so it can be serialized
    (e.g. by Pydantic).

    Special case: a dict whose values are all lists sharing one common
    non-zero length is treated as column-oriented data and pivoted into a
    list of row dicts, e.g.
    {"a": [1, 2], "b": [3, 4]} -> [{"a": 1, "b": 3}, {"a": 2, "b": 4}].

    Args:
        obj: object to convert

    Returns:
        The converted basic-type object (None/str/int/float/bool, list, dict,
        or str(obj) as a last resort).
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        # Convert every value first, then decide whether to pivot.
        converted = {k: convert_to_basic_types(v) for k, v in obj.items()}
        values = list(converted.values())
        # Pivot only when the dict is non-empty, every value is a list, and
        # all lists share one length > 0.
        if values and all(isinstance(v, list) for v in values):
            lengths = {len(v) for v in values}
            if len(lengths) == 1:
                n = next(iter(lengths))
                if n > 0:
                    return [{k: converted[k][i] for k in converted} for i in range(n)]
        return converted
    if isinstance(obj, (list, tuple)):
        return [convert_to_basic_types(item) for item in obj]
    # Fallback for arbitrary objects: dict() if possible, else repr as string.
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
    try:
        return dict(obj)
    except (TypeError, ValueError):
        return str(obj)
+
def convert_to_langchain_docs(obj: Any) -> List[Document]:
    """
    Convert an Infinity search result into LangChain Document objects.

    Args:
        obj: raw search result (typically the (rows, ...) tuple returned by
             vector_search(...).to_result())

    Returns:
        List[Document]: one Document per row; empty list when the payload is
        empty or not in the expected shape.
    """
    res = convert_to_basic_types(obj=obj)
    # After conversion the row dicts are expected in res[0]. Guard against an
    # empty or unexpectedly shaped payload instead of raising IndexError /
    # KeyError like the previous unconditional res[0] access did.
    if not isinstance(res, (list, tuple)) or not res or not isinstance(res[0], list):
        return []
    return [
        Document(
            page_content=item.get("content", ""),
            metadata={
                "docnm": item.get("docnm"),
                "tag_kwd": item.get("tag_kwd"),
                "kb_id": item.get("kb_id"),
                "doc_id": item.get("doc_id"),
            },
        )
        for item in res[0]
    ]

+ 52 - 0
utils/mysql/__init__.py

@@ -0,0 +1,52 @@
+from .mysql_conn import MySQLConnection
+
+
# Global MySQL client instance (module-level singleton handle)
_global_mysql_client = None

def init_global_mysql_client(host: str = None, port: int = None, 
                           user: str = None, password: str = None, 
                           database: str = None, charset: str = None,
                           pool_size: int = None, **kwargs) -> None:
    """
    Initialize the global MySQL client.

    Args:
        host: MySQL host address
        port: MySQL port
        user: MySQL user name
        password: MySQL password
        database: database name
        charset: character set
        pool_size: connection-pool size
        **kwargs: additional MySQL connection parameters

    NOTE(review): if the client is already initialized this is a silent no-op —
    any newly supplied parameters are ignored. MySQLConnection is also
    @singleton-decorated, which has the same effect one level down; confirm
    this double layer of "first call wins" is intended.
    """
    global _global_mysql_client
    if _global_mysql_client is None:
        _global_mysql_client = MySQLConnection(
            host=host, port=port, user=user, password=password, 
            database=database, charset=charset, pool_size=pool_size, **kwargs
        )


def get_global_mysql_client() -> MySQLConnection:
    """
    Return the global MySQL client instance.

    Returns:
        MySQLConnection: the shared SQL executor.

    Raises:
        RuntimeError: if init_global_mysql_client() has not been called yet.
    """
    global _global_mysql_client
    if _global_mysql_client is None:
        raise RuntimeError("Global MySQL client has not been initialized. Call init_global_mysql_client() first.")
    return _global_mysql_client


def close_global_mysql_client() -> None:
    """
    Close and discard the global MySQL client so it can be re-initialized.
    """
    global _global_mysql_client
    if _global_mysql_client is not None:
        _global_mysql_client.close()
        _global_mysql_client = None

+ 50 - 114
utils/mysql_conn.py → utils/mysql/mysql_conn.py

@@ -1,49 +1,41 @@
 """
 """
-MySQL 连接管理
+MySQL SQL 执行
 
 
-该文件提供 MySQL 数据库连接管理功能,支持:
+该文件提供 MySQL 数据库 SQL 执行功能,支持:
 - 单例模式
 - 单例模式
-- 连接池管理
 - 基本 CRUD 操作
 - 基本 CRUD 操作
 - 事务支持
 - 事务支持
 - 连接错误处理
 - 连接错误处理
+- 全局客户端管理
 """
 """
 import pymysql
 import pymysql
 from pymysql.cursors import DictCursor
 from pymysql.cursors import DictCursor
 from typing import Any, List, Dict, Optional, Union
 from typing import Any, List, Dict, Optional, Union
 from contextlib import contextmanager
 from contextlib import contextmanager
-
-# 单例装饰器
-class singleton:
-    def __init__(self, cls):
-        self.cls = cls
-        self._instance = None
-    
-    def __call__(self, *args, **kwargs):
-        if self._instance is None:
-            self._instance = self.cls(*args, **kwargs)
-        return self._instance
+from .mysql_pool import get_mysql_pool, MySQLPool
+from utils.decorators.singleton import singleton
 
 
 @singleton
 @singleton
 class MySQLConnection:
 class MySQLConnection:
     """
     """
-    MySQL 连接管理
+    MySQL SQL 执行器
     支持:
     支持:
     - 单例模式
     - 单例模式
-    - 连接池管理
     - 基本 CRUD 操作
     - 基本 CRUD 操作
     - 事务支持
     - 事务支持
     - 连接错误处理
     - 连接错误处理
     """
     """
     
     
-    def __init__(self, host: str = "localhost", port: int = 3306, 
+    def __init__(self, mysql_pool: Optional[MySQLPool] = None,
+                 host: str = None, port: int = None, 
                  user: str = None, password: str = None, 
                  user: str = None, password: str = None, 
-                 database: str = None, charset: str = "utf8mb4",
-                 pool_size: int = 5, **kwargs):
+                 database: str = None, charset: str = None,
+                 pool_size: int = None, **kwargs):
         """
         """
-        初始化 MySQL 连接池
+        初始化 MySQL SQL 执行器
         
         
         Args:
         Args:
+            mysql_pool: 可选的 MySQL 连接池实例,如果提供则使用该实例,否则创建新实例
             host: MySQL 主机地址
             host: MySQL 主机地址
             port: MySQL 端口号
             port: MySQL 端口号
             user: MySQL 用户名
             user: MySQL 用户名
@@ -53,76 +45,20 @@ class MySQLConnection:
             pool_size: 连接池大小
             pool_size: 连接池大小
             **kwargs: 其他 MySQL 连接参数
             **kwargs: 其他 MySQL 连接参数
         """
         """
-        self.host = host
-        self.port = port
-        self.user = user
-        self.password = password
-        self.database = database
-        self.charset = charset
-        self.pool_size = pool_size
-        self.kwargs = kwargs
-        
-        # 初始化连接池
-        self._connection_pool = []
-        self._init_connection_pool()
-    
-    def _init_connection_pool(self):
-        """
-        初始化连接池
-        """
-        for _ in range(self.pool_size):
-            conn = self._create_connection()
-            if conn:
-                self._connection_pool.append(conn)
-    
-    def _create_connection(self) -> Optional[pymysql.connections.Connection]:
-        """
-        创建新的 MySQL 连接
-        
-        Returns:
-            MySQL 连接对象,失败返回 None
-        """
-        try:
-            conn = pymysql.connect(
-                host=self.host,
-                port=self.port,
-                user=self.user,
-                password=self.password,
-                database=self.database,
-                charset=self.charset,
-                cursorclass=DictCursor,
-                **self.kwargs
-            )
-            return conn
-        except pymysql.Error as e:
-            print(f"创建 MySQL 连接失败: {e}")
-            return None
+        # 如果提供了连接池实例,则使用该实例,否则创建新实例
+        if mysql_pool:
+            self._pool = mysql_pool
+        else:
+            self._pool = get_mysql_pool(host, port, user, password, database, charset, pool_size, **kwargs)
     
     
-    def _get_connection(self) -> Optional[pymysql.connections.Connection]:
+    def _get_connection(self) -> pymysql.connections.Connection:
         """
         """
         从连接池获取连接
         从连接池获取连接
         
         
         Returns:
         Returns:
-            MySQL 连接对象,失败返回 None
+            MySQL 连接对象
         """
         """
-        if self._connection_pool:
-            return self._connection_pool.pop()
-        else:
-            # 连接池为空,创建新连接
-            return self._create_connection()
-    
-    def _return_connection(self, conn: pymysql.connections.Connection):
-        """
-        将连接返回连接池
-        
-        Args:
-            conn: MySQL 连接对象
-        """
-        if len(self._connection_pool) < self.pool_size:
-            self._connection_pool.append(conn)
-        else:
-            # 连接池已满,关闭连接
-            conn.close()
+        return self._pool.get_connection()
     
     
     @contextmanager
     @contextmanager
     def get_cursor(self, cursorclass=DictCursor):
     def get_cursor(self, cursorclass=DictCursor):
@@ -136,11 +72,9 @@ class MySQLConnection:
             MySQL 游标对象
             MySQL 游标对象
         """
         """
         conn = self._get_connection()
         conn = self._get_connection()
-        if not conn:
-            raise Exception("无法获取 MySQL 连接")
+        cursor = conn.cursor(cursorclass=cursorclass)
         
         
         try:
         try:
-            cursor = conn.cursor(cursorclass=cursorclass)
             yield cursor
             yield cursor
             conn.commit()
             conn.commit()
         except Exception as e:
         except Exception as e:
@@ -148,7 +82,7 @@ class MySQLConnection:
             raise e
             raise e
         finally:
         finally:
             cursor.close()
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     
     def execute(self, sql: str, params: Union[List, Dict] = None) -> int:
     def execute(self, sql: str, params: Union[List, Dict] = None) -> int:
         """
         """
@@ -234,16 +168,9 @@ class MySQLConnection:
             连接对象和游标对象
             连接对象和游标对象
         """
         """
         conn = self._get_connection()
         conn = self._get_connection()
-        if not conn:
-            raise Exception("无法获取 MySQL 连接")
-        
-        try:
-            conn.begin()
-            cursor = conn.cursor()
-            return conn, cursor
-        except Exception as e:
-            self._return_connection(conn)
-            raise e
+        conn.begin()
+        cursor = conn.cursor()
+        return conn, cursor
     
     
     def commit_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
     def commit_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
         """
         """
@@ -257,7 +184,7 @@ class MySQLConnection:
             conn.commit()
             conn.commit()
         finally:
         finally:
             cursor.close()
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     
     def rollback_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
     def rollback_transaction(self, conn: pymysql.connections.Connection, cursor: pymysql.cursors.Cursor):
         """
         """
@@ -271,26 +198,22 @@ class MySQLConnection:
             conn.rollback()
             conn.rollback()
         finally:
         finally:
             cursor.close()
             cursor.close()
-            self._return_connection(conn)
+            conn.close()
     
     
     def close(self):
     def close(self):
         """
         """
-        关闭连接池中的所有连接
+        关闭 SQL 执行器
         """
         """
-        for conn in self._connection_pool:
-            try:
-                conn.close()
-            except Exception as e:
-                print(f"关闭 MySQL 连接失败: {e}")
-        self._connection_pool.clear()
+        # 关闭连接池
+        self._pool.close()
 
 
 # 简化的接口函数,便于快速使用
 # 简化的接口函数,便于快速使用
-def get_mysql_conn(host: str = "localhost", port: int = 3306, 
-                   user: str = None, password: str = None, 
-                   database: str = None, charset: str = "utf8mb4",
-                   pool_size: int = 5, **kwargs) -> MySQLConnection:
+def get_mysql_conn(host: str = None, port: int = None, 
+                  user: str = None, password: str = None, 
+                  database: str = None, charset: str = None,
+                  pool_size: int = None, **kwargs) -> MySQLConnection:
     """
     """
-    获取 MySQL 连接管理器实例
+    获取 MySQL SQL 执行器实例
     
     
     Args:
     Args:
         host: MySQL 主机地址
         host: MySQL 主机地址
@@ -303,6 +226,19 @@ def get_mysql_conn(host: str = "localhost", port: int = 3306,
         **kwargs: 其他 MySQL 连接参数
         **kwargs: 其他 MySQL 连接参数
     
     
     Returns:
     Returns:
-        MySQL 连接管理器实例
+        MySQL SQL 执行器实例
+    """
+    return MySQLConnection(host=host, port=port, user=user, password=password, 
+                          database=database, charset=charset, pool_size=pool_size, **kwargs)
+
+def get_mysql_conn_with_pool(mysql_pool: MySQLPool) -> MySQLConnection:
+    """
+    使用指定的连接池获取 MySQL SQL 执行器实例
+    
+    Args:
+        mysql_pool: MySQL 连接池实例
+    
+    Returns:
+        MySQL SQL 执行器实例
     """
     """
-    return MySQLConnection(host, port, user, password, database, charset, pool_size, **kwargs)
+    return MySQLConnection(mysql_pool=mysql_pool)

+ 138 - 0
utils/mysql/mysql_pool.py

@@ -0,0 +1,138 @@
+"""
+MySQL 连接池配置
+
+该文件提供 MySQL 数据库连接池配置功能,支持:
+- 单例模式
+- DBUtils 连接池管理
+- 连接错误处理
+- 从.env文件读取配置
+"""
+import pymysql
+from pymysql.cursors import DictCursor
+from dbutils.pooled_db import PooledDB
+from conf.settings import mysql_settings
+
+# 单例装饰器
+# NOTE(review): duplicates utils.decorators.singleton (which mysql_conn.py already imports in this
+# commit) — consider importing that shared decorator here instead of re-declaring it.
+class singleton:
+    def __init__(self, cls):
+        self.cls = cls
+        self._instance = None
+    
+    def __call__(self, *args, **kwargs):
+        if self._instance is None:
+            self._instance = self.cls(*args, **kwargs)
+        return self._instance
+
+@singleton
+class MySQLPool:
+    """
+    MySQL 连接池管理器
+    支持:
+    - 单例模式
+    - DBUtils 连接池管理
+    - 连接错误处理
+    """
+    
+    def __init__(self, host: str = None, port: int = None, 
+                 user: str = None, password: str = None, 
+                 database: str = None, charset: str = None,
+                 pool_size: int = None, **kwargs):
+        """
+        初始化 MySQL 连接池
+        
+        Args:
+            host: MySQL 主机地址
+            port: MySQL 端口号
+            user: MySQL 用户名
+            password: MySQL 密码
+            database: 数据库名称
+            charset: 字符集
+            pool_size: 连接池大小
+            **kwargs: 其他 MySQL 连接参数
+        """
+        # 从环境变量读取配置,优先级:传入参数 > 环境变量 > 默认值
+        self.host = host or mysql_settings.mysql_host
+        self.port = int(port or mysql_settings.mysql_port)
+        self.user = user or mysql_settings.mysql_user
+        self.password = password or mysql_settings.mysql_password
+        self.database = database or mysql_settings.mysql_database
+        self.charset = charset or mysql_settings.mysql_charset
+        self.pool_size = int(pool_size or mysql_settings.mysql_pool_size)
+        self.kwargs = kwargs
+        
+        # 初始化 DBUtils 连接池
+        self._pool = PooledDB(
+            creator=pymysql,
+            maxconnections=self.pool_size,
+            mincached=2,
+            maxcached=5,
+            maxshared=3,
+            blocking=True,
+            maxusage=None,
+            setsession=[],
+            ping=0,
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            database=self.database,
+            charset=self.charset,
+            cursorclass=DictCursor,
+            **kwargs
+        )
+    
+    def get_connection(self) -> pymysql.connections.Connection:
+        """
+        从连接池获取连接
+        
+        Returns:
+            MySQL 连接对象
+        """
+        return self._pool.connection()
+    
+    def close(self):
+        """
+        关闭连接池
+        """
+        # DBUtils 连接池会自动管理连接,无需手动关闭
+        pass
+
+# 简化的接口函数,便于快速使用
+def get_mysql_pool(host: str = None, port: int = None, 
+                  user: str = None, password: str = None, 
+                  database: str = None, charset: str = None,
+                  pool_size: int = None, **kwargs) -> MySQLPool:
+    """
+    获取 MySQL 连接池实例
+    
+    Args:
+        host: MySQL 主机地址
+        port: MySQL 端口号
+        user: MySQL 用户名
+        password: MySQL 密码
+        database: 数据库名称
+        charset: 字符集
+        pool_size: 连接池大小
+        **kwargs: 其他 MySQL 连接参数
+    
+    Returns:
+        MySQL 连接池实例
+    """
+    # 从环境变量读取默认值,与MySQLPool.__init__保持一致
+    default_host = mysql_settings.mysql_host
+    default_port = int(mysql_settings.mysql_port)
+    default_user = mysql_settings.mysql_user
+    default_database = mysql_settings.mysql_database
+    default_charset = mysql_settings.mysql_charset
+    default_pool_size = int(mysql_settings.mysql_pool_size)
+    
+    return MySQLPool(
+        host=host or default_host,
+        port=port or default_port,
+        user=user or default_user,
+        password=password or mysql_settings.mysql_password,
+        database=database or default_database,
+        charset=charset or default_charset,
+        pool_size=pool_size or default_pool_size,
+        **kwargs
+    )

+ 3 - 1
utils/ragflow/dataset_service.py

@@ -6,7 +6,7 @@ class DatasetService:
     
     
     def create_dataset(self, name: str, description: str = None, 
     def create_dataset(self, name: str, description: str = None, 
                       embedding_model: str = None, permission: str = None,
                       embedding_model: str = None, permission: str = None,
-                      chunk_method: str = None) -> Dict[str, Any]:
+                      chunk_method: str = None, parser_config: dict = None) -> Dict[str, Any]:
         endpoint = "/api/v1/datasets"
         endpoint = "/api/v1/datasets"
         
         
         data = {"name": name}
         data = {"name": name}
@@ -18,6 +18,8 @@ class DatasetService:
             data["permission"] = permission
             data["permission"] = permission
         if chunk_method is not None:
         if chunk_method is not None:
             data["chunk_method"] = chunk_method
             data["chunk_method"] = chunk_method
+        if parser_config is not None:
+            data["parser_config"] = parser_config
         
         
         response = self.http_client.post(endpoint, json_data=data)
         response = self.http_client.post(endpoint, json_data=data)
         
         

+ 3 - 3
utils/ragflow/ragflow_service.py

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 
 from utils.http_client import HTTPClient
 from utils.http_client import HTTPClient
-from conf.config import ModelConfig
+from conf.settings import ragflow_settings
 from utils.ragflow.dataset_service import DatasetService
 from utils.ragflow.dataset_service import DatasetService
 from utils.ragflow.document_service import DocumentService
 from utils.ragflow.document_service import DocumentService
 from utils.ragflow.chunk_service import ChunkService
 from utils.ragflow.chunk_service import ChunkService
@@ -78,8 +78,8 @@ class FileInfo:
 
 
 class RAGFlowService:
 class RAGFlowService:
     def __init__(self, base_url: str = None, api_key: str = None):
     def __init__(self, base_url: str = None, api_key: str = None):
-        base_url = base_url or ModelConfig.get_ragflow_api_url()
-        api_key = api_key or ModelConfig.get_ragflow_api_key()
+        base_url = base_url or ragflow_settings.ragflow_api_url
+        api_key = api_key or ragflow_settings.ragflow_api_key
         self.http_client = HTTPClient(base_url=base_url, api_key=api_key)
         self.http_client = HTTPClient(base_url=base_url, api_key=api_key)
         
         
         self.dataset_service = DatasetService(self.http_client)
         self.dataset_service = DatasetService(self.http_client)

+ 9 - 0
utils/ragflow_sdk/__init__.py

@@ -0,0 +1,9 @@
+from .dataset_util import DataSetUtil
+from .document_util import DocumentUtil
+from .chunk_util import ChunkUtil
+
+__all__ = [
+    "DataSetUtil",
+    "DocumentUtil",
+    "ChunkUtil"
+]

+ 17 - 0
utils/ragflow_sdk/base_util.py

@@ -0,0 +1,17 @@
+from ragflow_sdk import RAGFlow
+from conf.settings import ragflow_settings
+
+
+class RAGFlowBaseUtil:
+    """
+    RAGFlow 工具类基础类,封装了 RAGFlow 客户端初始化逻辑
+    """
+    
+    def __init__(self):
+        """
+        初始化 RAGFlow 客户端
+        """
+        self.ragflow_client = RAGFlow(
+            api_key=ragflow_settings.ragflow_api_key,
+            base_url=ragflow_settings.ragflow_api_url
+        )

+ 80 - 0
utils/ragflow_sdk/chunk_util.py

@@ -0,0 +1,80 @@
+from typing import Optional
+from ragflow_sdk import Chunk
+from utils.ragflow_sdk.document_util import DocumentUtil
+from .base_util import RAGFlowBaseUtil
+
+
+
+class ChunkUtil(RAGFlowBaseUtil):
+
+    def __init__(self, dataset_id: str, document_id: str):
+        super().__init__()
+        self.dataset_id = dataset_id
+        self.document_id = document_id
+
+    def add_chunk(self, dataset_name: Optional[str] = None, document_id: Optional[str] = None, content: str = None, important_keywords: Optional[list[str]] = None):
+        """
+        添加文档分块
+        """
+        try:
+            if not dataset_name or not document_id or not content:
+                raise ValueError("数据集名称、文档ID和内容不能为空")
+            # NOTE(review): important_keywords previously used a mutable default ([]); pass a fresh list instead
+            chunk = DocumentUtil().get_document(dataset_name=dataset_name, document_id=document_id).add_chunk(content=content, important_keywords=important_keywords or [])
+            return chunk
+        except Exception as e:
+            print(f"添加文档分块失败: {e}")
+            raise e
+    
+    def delete_chunks(self, chunk_ids: list[str], dataset_name: Optional[str] = None, document_id: Optional[str] = None):
+        """
+        删除文档分块
+        """
+        try:
+            # NOTE(review): original body referenced undefined locals dataset_name/document_id (NameError);
+            # accept them as parameters like add_chunk, falling back to the id bound at construction time.
+            DocumentUtil().get_document(dataset_name=dataset_name, document_id=document_id or self.document_id).delete_chunks(chunk_ids)
+        except Exception as e:
+            print(f"删除文档分块失败: {e}")
+            raise e
+    
+
+    def list_chunks(self, keywords: str = None, page: int = 1, page_size: int = 30, id : str = None, dataset_name: Optional[str] = None, document_id: Optional[str] = None) -> list[Chunk]:
+        """
+        列出文档分块
+        """
+        try:
+            # NOTE(review): original body referenced undefined locals dataset_name/document_id (NameError);
+            # accept them as parameters, falling back to the id bound at construction time.
+            chunks = DocumentUtil().get_document(dataset_name=dataset_name, document_id=document_id or self.document_id).list_chunks(keywords=keywords, page=page, page_size=page_size, id=id)
+            return chunks
+        except Exception as e:
+            print(f"列出文档分块失败: {e}")
+            raise e
+
+    def get_chunk(self, chunk_id: str) -> Chunk:
+        """
+        获取文档分块
+        """
+        _list = self.list_chunks(id=chunk_id)
+        if len(_list) > 0:
+            return _list[0]
+        raise Exception("Chunk %s not found" % chunk_id)
+    
+
+    def update_chunk(self, chunk_id: str, update_message: dict):
+        """
+        更新文档分块
+        """
+        try:
+            self.get_chunk(chunk_id=chunk_id).update(update_message)
+        except Exception as e:
+            print(f"更新文档分块失败: {e}")
+            raise e
+    
+    def retrieve(self, question: str, dataset_ids: list[str] = None, document_ids: list[str] = None, page: int = 1, page_size: int = 30, similarity_threshold: float = 0.2, vector_similarity_weight: float = 0.3, top_k: int = 1024, rerank_id: str = None, keyword: bool = False, cross_languages: list[str] = None, metadata_condition: dict = None) -> list[Chunk]:
+        """
+        检索文档分块
+        """
+        try:
+            chunks = self.ragflow_client.retrieve(question=question, dataset_ids=dataset_ids, document_ids=document_ids, page=page, page_size=page_size, similarity_threshold=similarity_threshold, vector_similarity_weight=vector_similarity_weight, top_k=top_k, rerank_id=rerank_id, keyword=keyword, cross_languages=cross_languages, metadata_condition=metadata_condition)
+            return chunks
+        except Exception as e:
+            print(f"检索文档分块失败: {e}")
+            raise e
+        

+ 45 - 0
utils/ragflow_sdk/conf/rag_parser_config.py

@@ -0,0 +1,45 @@
+class RagParserDefaults:
+
+    DATASET_PERMISSION="team"
+
+    DATASET_CHUNK_METHOD="naive"
+
+    DATASET_CONFIG_DICT = {
+            "chunk_token_num": 256,
+            "delimiter": "\n!?;。;!?",
+            "html4excel": False,
+            "layout_recognize": "Pro/Qwen/Qwen2.5-VL-7B-Instruct@SILICONFLOW",
+            "auto_keywords": 5,
+            "tag_kb_ids": [],
+            "topn_tags": 3,
+            "task_page_size": 4,
+            "raptor": {
+                "max_cluster": 64,
+                "max_token": 256,
+                "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n      {cluster_content}\nThe above is the content you need to summarize.",
+                "random_seed": 0,
+                "threshold": 0.1,
+                "use_raptor": True
+            },
+            "graphrag": {
+                "resolution": True,
+                "use_graphrag": True,
+                "method": "general",
+                "entity_types": [
+                    "event",
+                    "Book",
+                    "Author",
+                    "Illustrator",
+                    "Series",
+                    "Theme",
+                    "Genre",
+                    "Character",
+                    "Setting",
+                    "AgeGroup",
+                    "Competency",
+                    "ArtStyle",
+                    "Award",
+                    "Publisher"
+                ]
+            }
+        }

+ 71 - 0
utils/ragflow_sdk/dataset_util.py

@@ -0,0 +1,71 @@
+from typing import Optional
+from ragflow_sdk import DataSet
+from conf.settings import ragflow_settings, model_settings
+from .conf.rag_parser_config import RagParserDefaults
+from .base_util import RAGFlowBaseUtil
+
+
+
+class DataSetUtil(RAGFlowBaseUtil):
+
+    def __init__(self):
+        super().__init__()
+
+    def create_dataset(self, chunk_method: str = RagParserDefaults.DATASET_CHUNK_METHOD,
+        dataset_name: Optional[str] = None, dataset_desc: str = "", 
+        embedding_model: str = f"{model_settings.embedding_model_name}@SILICONFLOW",
+        parser_config: dict = RagParserDefaults.DATASET_CONFIG_DICT):
+        """
+        创建数据集
+        """
+        # 封装数据集参数
+        parser_obj = DataSet.ParserConfig(self.ragflow_client, parser_config)
+        # 创建数据集
+        dataset = self.ragflow_client.create_dataset(
+            name=dataset_name,
+            description=dataset_desc,
+            embedding_model=embedding_model,
+            permission=RagParserDefaults.DATASET_PERMISSION, 
+            chunk_method= chunk_method,
+            parser_config=parser_obj
+        )
+        return dataset
+
+    def delete_dataset(self, ids: list[str] | None = None):
+        """
+        删除数据集
+        """
+        try:
+            if ids is not None:
+                self.ragflow_client.delete_datasets(ids)
+        except Exception as e:
+            print(f"删除数据集失败: {e}")
+
+    def get_dataset(self, name: Optional[str] = None):
+        """
+        获取数据集
+        """
+        try:
+            dataset = self.ragflow_client.get_dataset(name=name)
+        except Exception as e:
+            print(f"获取数据集失败: {e}")
+            return None
+        return dataset
+                
+    def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
+        """
+        获取数据集列表
+        """
+        try:
+            dataset_list = self.ragflow_client.list_datasets(
+                page=page,
+                page_size=page_size,
+                orderby=orderby,
+                desc=desc,
+                id=id,
+                name=name,
+            )
+        except Exception as e:
+            print(f"获取数据集列表失败: {e}")
+            return None
+        return dataset_list

+ 93 - 0
utils/ragflow_sdk/document_util.py

@@ -0,0 +1,93 @@
+from typing import Optional
+from ragflow_sdk import DataSet, Document
+from conf.settings import ragflow_settings
+from .base_util import RAGFlowBaseUtil
+from common.models.pagination import Pagination
+
+
+
+class DocumentUtil(RAGFlowBaseUtil):
+
+    def __init__(self):
+        super().__init__()
+        
+
+    def upload_documents(self, dataset_name: Optional[str] = None, document_list: list[dict] = None):
+        """
+        上传文档到数据集
+        """
+        try:
+            if not document_list:
+                raise ValueError("文档列表不能为空")
+            # 上传文档到数据集
+            doc_list = self.ragflow_client.get_dataset(name=dataset_name).upload_documents(document_list)
+            return doc_list
+        except Exception as e:
+            print(f"上传文档到数据集失败: {e}")
+            raise e
+
+    def list_documents(self, dataset_name: Optional[str] = None, pagination: Pagination = Pagination(), id: str = None, keywords: str = None) -> list[Document]:
+        """
+        列出数据集的文档
+        """
+        try:
+            doc_list = self.ragflow_client.get_dataset(name=dataset_name).list_documents(id=id, keywords=keywords, **pagination.to_dict())
+            return doc_list
+        except Exception as e:
+            print(f"列出数据集文档失败: {e}")
+            raise e
+
+    def get_document(self, dataset_name: Optional[str] = None, document_id: Optional[str] = None) -> Document:
+        """
+        获取文档
+        """
+        _list = self.list_documents(dataset_name=dataset_name, id=document_id)
+        if len(_list) > 0:
+            return _list[0]
+        raise Exception("Document %s not found" % document_id)
+
+    # Document.update(update_message:dict)
+    def update_document(self, document: Document, update_message: dict):
+        """
+        更新文档
+        """
+        try:
+            document.update(update_message)
+        except Exception as e:
+            print(f"更新文档失败: {e}")
+            raise e 
+
+    def delete_document(self, dataset_name: Optional[str] = None, ids: list[str] = None):
+        """
+        删除文档
+        """
+        try:
+            if not ids:
+                raise ValueError("文档ID列表不能为空")
+            self.ragflow_client.get_dataset(name=dataset_name).delete_documents(ids=ids)
+        except Exception as e:
+            print(f"删除文档失败: {e}")
+            raise e
+
+    def async_parse_documents(self, dataset_name: Optional[str] = None,document_ids: list[str] = None):
+        """
+        异步解析文档
+        """
+        try:
+            self.ragflow_client.get_dataset(name=dataset_name).async_parse_documents(document_ids=document_ids)
+        except Exception as e:
+            print(f"异步解析文档失败: {e}")
+            raise e
+    
+    def parse_documents(self, dataset_name: Optional[str] = None,document_ids: list[str] = None) -> list[tuple[str, str, int, int]]:
+        """
+        解析文档
+        """
+        try:
+            doc_list = self.ragflow_client.get_dataset(name=dataset_name).parse_documents(document_ids=document_ids)
+            return doc_list
+        except Exception as e:
+            print(f"解析文档失败: {e}")
+            raise e
+

+ 0 - 0
utils/ragflow_sdk/param.json