Pārlūkot izejas kodu

多维度流程增加书、页拆分

yingge 3 mēneši atpakaļ
vecāks
revīzija
d8499a205e
31 mainītis faili ar 2756 papildinājumiem un 172 dzēšanām
  1. BIN
      requirements.txt
  2. 180 0
      requirements.txt.bak.1
  3. 7 3
      src/api/db/models/search_request_models.py
  4. 13 0
      src/api/db/repositories/__init__.py
  5. 254 0
      src/api/db/repositories/example_usage.py
  6. 105 0
      src/api/db/repositories/prompt_dimension_repository.py
  7. 46 0
      src/api/db/repositories/ragflow_user_repository.py
  8. 2 2
      src/api/db/services/infinity_search_service.py
  9. 4 4
      src/api/db/services/vector_search_service.py
  10. 2 2
      src/api/sdk/search_infinity.py
  11. 2 0
      src/datasets/parser/nodes/__init__.py
  12. 133 55
      src/datasets/parser/nodes/image_parse_node.py
  13. 105 0
      src/datasets/parser/nodes/picture_stitching_node.py
  14. 111 0
      src/datasets/parser/nodes/picture_stitching_node_example.md
  15. 205 29
      src/datasets/parser/nodes/qa_generate_node.py
  16. 14 3
      src/datasets/parser/nodes/ragflow_nodes.py
  17. 12 6
      src/datasets/parser/states/parser_states.py
  18. 5 3
      src/datasets/parser/workflow_nodes/__init__.py
  19. 152 0
      src/datasets/parser/workflow_nodes/dimension_book_split_node.py
  20. 11 11
      src/datasets/parser/workflow_nodes/dimension_page_split_node.py
  21. 6 3
      src/datasets/parser/workflows/dynamic_dimension_workflow.py
  22. 7 4
      src/datasets/parser/workflows/pdf_workflow.py
  23. 7 3
      src/datasets/parser/workflows/qa_workflow.py
  24. 2 2
      src/model/multimodal_embedding.py
  25. 188 0
      src/utils/mysql/ARCHITECTURE.md
  26. 160 0
      src/utils/mysql/README.md
  27. 23 1
      src/utils/mysql/__init__.py
  28. 198 0
      src/utils/mysql/base_repository.py
  29. 306 0
      src/utils/mysql/mysql_repository.py
  30. 442 0
      src/utils/mysql/query_builder.py
  31. 54 41
      src/utils/vector_db/elasticsearch_adapter.py

BIN
requirements.txt


+ 180 - 0
requirements.txt.bak.1

@@ -0,0 +1,180 @@
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.3
+aiosignal==1.4.0
+annotated-doc==0.0.4
+annotated-types==0.7.0
+anyio==4.12.0
+APScheduler==3.11.2
+argon2-cffi==25.1.0
+argon2-cffi-bindings==25.1.0
+attrs==25.4.0
+Authlib==1.6.6
+backoff==2.2.1
+beartype==0.22.9
+cachetools==6.2.4
+certifi==2026.1.4
+cffi==2.0.0
+charset-normalizer==3.4.4
+click==8.3.1
+cloudpickle==3.1.2
+colorama==0.4.6
+cryptography==46.0.3
+cyclopts==4.4.4
+dashscope==1.25.5
+dataclasses-json==0.6.7
+datrie==0.8.3
+DBUtils==3.1.2
+diskcache==5.6.3
+distro==1.9.0
+dnspython==2.8.0
+docstring_parser==0.17.0
+docutils==0.22.4
+elastic-transport==8.17.1
+elasticsearch==8.11.1
+email-validator==2.3.0
+et_xmlfile==2.0.0
+exceptiongroup==1.3.1
+fakeredis==2.33.0
+fastapi==0.128.0
+fastmcp==2.14.2
+filelock==3.20.2
+frozenlist==1.8.0
+fsspec==2025.12.0
+googleapis-common-protos==1.72.0
+greenlet==3.3.0
+h11==0.16.0
+hanziconv==0.3.2
+hf-xet==1.2.0
+httpcore==1.0.9
+httpx==0.28.1
+httpx-sse==0.4.3
+huggingface_hub==1.2.3
+idna==3.11
+importlib_metadata==8.7.1
+infinity-sdk==0.6.15
+infinity_emb==0.0.77
+jaraco.classes==3.4.0
+jaraco.context==6.0.2
+jaraco.functools==4.4.0
+jiter==0.12.0
+joblib==1.5.3
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.26.0
+jsonschema-path==0.3.4
+jsonschema-specifications==2025.9.1
+keyring==25.7.0
+langchain==1.2.0
+langchain-classic==1.0.1
+langchain-community==0.4.1
+langchain-core==1.2.6
+langchain-mcp-adapters==0.2.1
+langchain-openai==1.1.6
+langchain-text-splitters==1.1.0
+langfuse==3.12.0
+langgraph==1.0.5
+langgraph-checkpoint==3.0.1
+langgraph-prebuilt==1.0.5
+langgraph-sdk==0.3.1
+langsmith==0.6.0
+lupa==2.6
+markdown-it-py==4.0.0
+marshmallow==3.26.2
+mcp==1.25.0
+mdurl==0.1.2
+minio==7.2.20
+more-itertools==10.8.0
+multidict==6.7.0
+mypy_extensions==1.1.0
+nltk==3.9.2
+numpy==1.26.4
+ollama==0.6.1
+openai==2.14.0
+openapi-pydantic==0.5.1
+openpyxl==3.1.5
+opentelemetry-api==1.39.1
+opentelemetry-exporter-otlp-proto-common==1.39.1
+opentelemetry-exporter-otlp-proto-http==1.39.1
+opentelemetry-exporter-prometheus==0.60b1
+opentelemetry-instrumentation==0.60b1
+opentelemetry-proto==1.39.1
+opentelemetry-sdk==1.39.1
+opentelemetry-semantic-conventions==0.60b1
+orjson==3.11.5
+ormsgpack==1.12.1
+packaging==25.0
+pandas==2.3.3
+pathable==0.4.4
+pathvalidate==3.3.1
+pdf2image==1.17.0
+pillow==12.1.0
+platformdirs==4.5.1
+polars-lts-cpu==1.33.1
+prometheus_client==0.23.1
+propcache==0.4.1
+protobuf==6.33.4
+py-key-value-aio==0.3.0
+py-key-value-shared==0.3.0
+pyarrow==22.0.0
+pycparser==2.23
+pycryptodome==3.23.0
+pydantic==2.12.5
+pydantic-settings==2.12.0
+pydantic_core==2.41.5
+pydocket==0.16.3
+Pygments==2.19.2
+PyJWT==2.10.1
+PyMuPDF==1.26.7
+PyMySQL==1.1.2
+pyperclip==1.11.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.2.1
+python-json-logger==4.0.0
+python-multipart==0.0.21
+pytz==2025.2
+pywin32==311
+pywin32-ctypes==0.2.3
+PyYAML==6.0.3
+ragflow-sdk==0.23.1
+readerwriterlock==1.0.9
+redis==7.1.0
+referencing==0.36.2
+regex==2025.11.3
+requests==2.32.5
+requests-toolbelt==1.0.0
+rich==14.2.0
+rich-rst==1.3.2
+rpds-py==0.30.0
+setuptools==80.9.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+sortedcontainers==2.4.0
+SQLAlchemy==2.0.45
+sqlglot==28.5.0
+sqlglotrs==0.10.0
+sse-starlette==3.1.2
+starlette==0.50.0
+tenacity==9.1.2
+thrift==0.22.0
+tiktoken==0.12.0
+tqdm==4.67.1
+typer==0.21.1
+typer-slim==0.21.0
+typing-inspect==0.9.0
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+tzdata==2025.3
+tzlocal==5.3.1
+urllib3==2.6.2
+uuid_utils==0.12.0
+uvicorn==0.40.0
+websocket-client==1.9.0
+websockets==15.0.1
+wheel==0.45.1
+wrapt==1.17.3
+xlrd==2.0.2
+xxhash==3.6.0
+yarl==1.22.0
+zipp==3.23.0
+zstandard==0.25.0

+ 7 - 3
src/api/db/models/search_request_models.py

@@ -19,7 +19,11 @@ class SearchRequest(BaseModel):
         matching_text: 用于文本匹配搜索的查询文本 (可选)
         topn: 返回的最大结果数量,默认为10,范围1-100
     """
-    
+    database_ids: List[str] = Field(
+        default=[],
+        description="知识库id",
+        examples=["database_id1", "database_id2"]
+    )
     image_url: Optional[str] = Field(
         default=None,
         description="用于图像搜索的图片URL",
@@ -30,12 +34,12 @@ class SearchRequest(BaseModel):
         description="用于文本匹配搜索的查询文本",
         examples=["搜索关键词"]
     )
-    topn: int = Field(
+    topk: int = Field(
         default=10,
         ge=1,
         le=100,
         description="返回的最大结果数量"
-    ),
+    )
     score_threshold: float = Field(
         default=0.0,
         ge=0.0,

+ 13 - 0
src/api/db/repositories/__init__.py

@@ -0,0 +1,13 @@
+"""
+数据库 Repository 模块
+
+提供各业务表的 Repository 实现。
+"""
+
+from .prompt_dimension_repository import PromptDimensionRepository
+from .ragflow_user_repository import RagflowUserRepository
+
+__all__ = [
+    "PromptDimensionRepository",
+    "RagflowUserRepository",
+]

+ 254 - 0
src/api/db/repositories/example_usage.py

@@ -0,0 +1,254 @@
+"""
+Repository 使用示例
+
+展示如何使用新的 Repository 层重构现有服务。
+"""
+
+from typing import List, Dict, Any, Optional
+from src.api.db.repositories import PromptDimensionRepository, RagflowUserRepository
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+# ========== 示例 1: 基础 CRUD 操作 ==========
+
+def example_basic_crud():
+    """Basic CRUD examples (results are assigned for illustration only)."""
+    repo = PromptDimensionRepository()
+    
+    # 1. Point and list lookups
+    dimension = repo.find_by_id(1)
+    dimension_by_name = repo.find_by_name("知识实体与百科拆解")
+    all_dimensions = repo.find_all(order_by="created_at DESC")
+    
+    # 2. Conditional query
+    active_dimensions = repo.find_all({"status": 1})
+    
+    # 3. Insert
+    new_dimension = repo.insert({
+        "name": "新维度",
+        "description": "描述",
+        "dataset_id": "dataset_123"
+    })
+    
+    # 4. Update
+    repo.update_by_id(1, {"description": "更新后的描述"})
+    
+    # 5. Delete
+    repo.delete_by_id(1)
+    
+    # 6. Count
+    count = repo.count({"status": 1})
+
+
+# ========== 示例 2: 分页查询 ==========
+
+def example_pagination():
+    """Paginated query example."""
+    repo = PromptDimensionRepository()
+    
+    # Page through active dimensions, 10 rows per page.
+    result = repo.paginate(
+        conditions={"status": 1},
+        order_by="created_at DESC",
+        page=1,
+        page_size=10
+    )
+    
+    # `result` is shaped like:
+    # {
+    #     "total": 100,
+    #     "items": [...],
+    #     "page": 1,
+    #     "page_size": 10,
+    #     "total_pages": 10
+    # }
+    
+    logger.info(f"总共 {result['total']} 条记录,当前第 {result['page']} 页")
+
+
+# ========== 示例 3: 关联查询 ==========
+
+def example_join_query():
+    """Join query example."""
+    repo = PromptDimensionRepository()
+    
+    # Dimensions joined with their active prompt versions.
+    dimensions_with_prompts = repo.get_dimensions_with_active_prompts()
+    
+    for item in dimensions_with_prompts:
+        logger.info(f"维度: {item['name']}, 提示词版本: {item['version_number']}")
+
+
+# ========== 示例 4: 复杂条件查询 ==========
+
+def example_complex_query():
+    """Complex-condition query example using QueryBuilder directly."""
+    from src.utils.mysql import QueryBuilder, get_global_mysql_client
+    
+    db = get_global_mysql_client()
+    
+    # Build a multi-condition SELECT with the QueryBuilder fluent API.
+    builder = QueryBuilder("prompt_dimensions")
+    builder.select(["id", "name", "description", "created_at"])
+    builder.where_eq("status", 1)
+    builder.where_like("name", "%维度%")
+    builder.where_in("id", [1, 2, 3, 4, 5])
+    builder.where_gt("created_at", "2024-01-01")
+    builder.order_by("created_at", "DESC")
+    builder.limit(20)
+    builder.offset(0)
+    
+    sql, params = builder.build_select()
+    results = db.fetch_all(sql, params)
+    
+    logger.info(f"查询到 {len(results)} 条记录")
+
+
+# ========== 示例 5: 重构现有服务 ==========
+
+class PromptServiceRefactored:
+    """
+    重构后的提示词服务示例
+    
+    展示如何使用 Repository 重构原有服务。
+    """
+    
+    def __init__(self):
+        """初始化服务"""
+        self.dimension_repo = PromptDimensionRepository()
+        # 可以继续使用原有的向量数据库客户端
+        from src.utils.vector_db import get_vector_db_client
+        self._vector_client = None
+    
+    def get_dimension_by_name(self, name: str) -> Optional[Dict[str, Any]]:
+        """
+        根据名称获取维度(使用 Repository)
+        
+        Args:
+            name: 维度名称
+            
+        Returns:
+            维度信息字典,不存在返回 None
+        """
+        return self.dimension_repo.find_by_name(name)
+    
+    def get_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
+        """
+        根据ID获取维度(使用 Repository)
+        
+        Args:
+            dimension_id: 维度ID
+            
+        Returns:
+            维度信息字典,不存在返回 None
+        """
+        return self.dimension_repo.find_by_id(dimension_id)
+    
+    def get_all_dimensions(self) -> List[Dict[str, Any]]:
+        """
+        获取所有维度(使用 Repository)
+        
+        Returns:
+            维度列表
+        """
+        return self.dimension_repo.find_all(order_by="created_at DESC")
+    
+    def add_dimension(self, name: str, description: str = None, dataset_id: str = None) -> Dict[str, Any]:
+        """
+        添加维度(使用 Repository)
+        
+        Args:
+            name: 维度名称
+            description: 维度描述
+            dataset_id: 数据集ID
+            
+        Returns:
+            新建的维度信息
+        """
+        # 插入新维度
+        self.dimension_repo.insert({
+            "name": name,
+            "description": description,
+            "dataset_id": dataset_id
+        })
+        
+        # 返回新建的维度
+        return self.dimension_repo.find_by_name(name)
+    
+    def update_dimension(self, dimension_id: int, name: str = None, description: str = None) -> int:
+        """
+        更新维度信息(使用 Repository)
+        
+        Args:
+            dimension_id: 维度ID
+            name: 维度名称
+            description: 维度描述
+            
+        Returns:
+            受影响的行数
+        """
+        update_data = {}
+        if name is not None:
+            update_data["name"] = name
+        if description is not None:
+            update_data["description"] = description
+        
+        if not update_data:
+            return 0
+        
+        return self.dimension_repo.update_by_id(dimension_id, update_data)
+    
+    def delete_dimension(self, dimension_id: int) -> int:
+        """
+        删除维度(使用 Repository)
+        
+        Args:
+            dimension_id: 维度ID
+            
+        Returns:
+            受影响的行数
+        """
+        return self.dimension_repo.delete_by_id(dimension_id)
+
+
+# ========== 示例 6: 批量操作 ==========
+
+def example_batch_operations():
+    """Batch insert/query example."""
+    repo = PromptDimensionRepository()
+    
+    # Bulk insert
+    data_list = [
+        {"name": "维度1", "description": "描述1"},
+        {"name": "维度2", "description": "描述2"},
+        {"name": "维度3", "description": "描述3"},
+    ]
+    affected_rows = repo.insert_many(data_list)
+    logger.info(f"批量插入了 {affected_rows} 条记录")
+    
+    # Bulk lookup by primary keys
+    ids = [1, 2, 3, 4, 5]
+    dimensions = repo.find_by_ids(ids)
+    logger.info(f"批量查询到 {len(dimensions)} 条记录")
+
+
+# ========== 示例 7: 自定义 SQL ==========
+
+def example_custom_sql():
+    """Raw SQL example for queries the builder cannot express."""
+    repo = PromptDimensionRepository()
+    
+    # Run a hand-written aggregate query through the repository.
+    sql = """
+        SELECT pd.*, COUNT(pv.id) as version_count
+        FROM prompt_dimensions pd
+        LEFT JOIN prompt_versions pv ON pd.id = pv.dimension_id
+        GROUP BY pd.id
+        HAVING version_count > 0
+        ORDER BY version_count DESC
+    """
+    results = repo.execute_query(sql)
+    
+    logger.info(f"查询到 {len(results)} 条记录")

+ 105 - 0
src/api/db/repositories/prompt_dimension_repository.py

@@ -0,0 +1,105 @@
+"""
+提示词维度 Repository
+
+提供提示词维度相关的数据库操作。
+"""
+
+from typing import List, Dict, Any, Optional
+from src.utils.mysql import MySQLRepository
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class PromptDimensionRepository(MySQLRepository):
+    """
+    提示词维度 Repository
+    
+    封装 prompt_dimensions 表的数据库操作。
+    """
+    
+    def __init__(self, db_client=None):
+        """
+        初始化提示词维度 Repository
+        
+        Args:
+            db_client: MySQL 连接客户端,如果为 None 则使用全局客户端
+        """
+        super().__init__("prompt_dimensions", db_client)
+    
+    def find_by_name(self, name: str) -> Optional[Dict[str, Any]]:
+        """
+        根据名称查找维度
+        
+        Args:
+            name: 维度名称
+            
+        Returns:
+            维度信息字典,不存在返回 None
+        """
+        return self.find_one({"name": name})
+    
+    def find_active_dimensions(self) -> List[Dict[str, Any]]:
+        """
+        查找所有激活的维度
+        
+        Returns:
+            维度列表
+        """
+        # 注意:如果表中没有 is_active 字段,需要根据实际表结构调整
+        return self.find_all(order_by="created_at DESC")
+    
+    def get_dimensions_with_active_prompts(self) -> List[Dict[str, Any]]:
+        """
+        关联查询维度和激活的提示词版本
+        
+        Returns:
+            包含维度信息和提示词内容的列表
+        """
+        builder = self._create_query_builder()
+        builder.select([
+            "pd.id", 
+            "pd.name", 
+            "pd.description",
+            "pd.dataset_id",
+            "pv.content as prompt_content", 
+            "pv.version_number",
+            "pv.is_active"
+        ])
+        builder.join(
+            "prompt_versions pv",
+            "pd.id = pv.dimension_id AND pv.is_active = 1",
+            "LEFT"
+        )
+        builder.order_by("pd.created_at", "DESC")
+        sql, params = builder.build_select()
+        return self.execute_query(sql, params)
+    
+    def get_dimension_with_prompt_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
+        """
+        根据维度ID获取维度信息和激活的提示词
+        
+        Args:
+            dimension_id: 维度ID
+            
+        Returns:
+            包含维度信息和提示词内容的字典,不存在返回 None
+        """
+        builder = self._create_query_builder()
+        builder.select([
+            "pd.id", 
+            "pd.name", 
+            "pd.description",
+            "pd.dataset_id",
+            "pv.content as prompt_content", 
+            "pv.version_number"
+        ])
+        builder.join(
+            "prompt_versions pv",
+            "pd.id = pv.dimension_id AND pv.is_active = 1",
+            "LEFT"
+        )
+        builder.where_eq("pd.id", dimension_id)
+        builder.limit(1)
+        sql, params = builder.build_select()
+        return self.execute_query(sql, params)[0] if self.execute_query(sql, params) else None

+ 46 - 0
src/api/db/repositories/ragflow_user_repository.py

@@ -0,0 +1,46 @@
+"""
+RAGFlow 用户 Repository
+
+提供 RAGFlow 用户相关的数据库操作。
+"""
+
+from typing import Optional, Dict, Any
+from src.utils.mysql import MySQLRepository
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class RagflowUserRepository(MySQLRepository):
+    """
+    Repository encapsulating database access for the ragflow_user table.
+    """
+    
+    def __init__(self, db_client=None):
+        """
+        Initialize the repository.
+        
+        Args:
+            db_client: MySQL client; when None the global client is used.
+        """
+        super().__init__("ragflow_user", db_client)
+    
+    # NOTE(review): the parameter name `id` shadows the builtin; renaming
+    # would break keyword callers, so it is only flagged here.
+    def get_ragflow_id_and_api_key(self, id: int) -> Optional[Dict[str, Any]]:
+        """
+        Fetch the ragflow_id (user id) and api_key for a row.
+        
+        Args:
+            id: primary key of the ragflow_user row
+            
+        Returns:
+            Dict with keys "ragflow_id" and "api_key", or None when the
+            row does not exist.
+        """
+        result = self.find_by_id(id)
+        if result:
+            return {
+                "ragflow_id": result.get("ragflow_id"),
+                "api_key": result.get("api_key")
+            }
+        return None

+ 2 - 2
src/api/db/services/infinity_search_service.py

@@ -138,9 +138,9 @@ class InfinitySearchService:
                 "matching_text": request.matching_text,
                 "vector_field": self.vector_field,
                 "query_vector": query_vector,
-                "topn": request.topn,
+                "topn": request.topk,
                 "knn_params": {
-                    "ef": str(request.topn * 10),
+                    "ef": str(request.topk * 10),
                     "threshold": str(request.score_threshold)
                 }
             }

+ 4 - 4
src/api/db/services/vector_search_service.py

@@ -59,7 +59,7 @@ class VectorSearchService:
         self.vector_field = vector_field or "dense_vector_1024"
         self.match_field = match_field or "content"
         self.match_type = match_type or "cosine"
-        self.table_name = table_name or "ragbook_1_4981b64cf8d611f095d95a49fdb98e7b"
+        self.table_name = table_name or None
 
     @property
     def client(self) -> VectorDBClient:
@@ -114,7 +114,7 @@ class VectorSearchService:
         """
         try:
             search_query = self._convert_search_request_to_search_query(request)
-            result = self._client.hybrid_search(self.table_name, self.output_fields, search_query)
+            result = self._client.hybrid_search(table_name=request.database_ids, output_fields=self.output_fields, query=search_query)
             return result.to_result()
         except Exception as e:
             logger.error(f"混合搜索失败: {str(e)}")
@@ -138,9 +138,9 @@ class VectorSearchService:
                 "matching_text": request.matching_text,
                 "vector_field": self.vector_field,
                 "query_vector": query_vector,
-                "topn": request.topn,
+                "topn": request.topk,
                 "knn_params": {
-                    "ef": str(request.topn * 10),
+                    "ef": str(request.topk * 10),
                     "threshold": str(request.score_threshold) if request.score_threshold else "0"
                 }
             }

+ 2 - 2
src/api/sdk/search_infinity.py

@@ -81,10 +81,10 @@ async def question_search(request: SearchRequest):
     - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
     """
     try:
-        output_fields = ["content"]
+        output_fields = ["content_with_weight"]
         search_service = VectorSearchService(
             client=get_vector_db_client(database="ragflow_db"), 
-            table_name="ragflow_92162247e93e11f084830242ac1d0002_1c4f7a82f66c11f09c750242c0a8d002", 
+            table_name="ragflow_f3abf26bf80c11f0953d0242ac180002", 
             vector_field="q_1024_vec",
             output_fields=output_fields
         )

+ 2 - 0
src/datasets/parser/nodes/__init__.py

@@ -21,6 +21,7 @@ from src.datasets.parser.nodes.export_csv_node import ExportCSVNode
 from src.datasets.parser.nodes.prompt_retrieval_node import PromptRetrievalNode
 from src.datasets.parser.nodes.table_name_generation_node import TableNameGenerationNode
 from src.datasets.parser.nodes.dimension_result_node import DimensionResultNode
+from src.datasets.parser.nodes.picture_stitching_node import PictureStitchingNode
 
 __all__ = [
     "PDFSplitNode",
@@ -38,4 +39,5 @@ __all__ = [
     "PromptRetrievalNode",
     "TableNameGenerationNode",
     "DimensionResultNode",
+    "PictureStitchingNode",
 ]

+ 133 - 55
src/datasets/parser/nodes/image_parse_node.py

@@ -22,10 +22,13 @@ class ImageParseNode(BaseNode):
     """
     图像解析节点
     
-    使用VL模型(如Qwen-VL)解析图像内容,支持并行处理多张图片。
+    使用VL模型(如Qwen-VL)解析图像内容,支持两种模式:
+    1. 分页模式:解析多张分页图片(split_pages)
+    2. 书本模式:解析一张完整的长图(book_image)
     
     需要的状态字段:
-        - split_pages: 待解析的页面列表(包含image字段)
+        - split_pages: 待解析的页面列表(包含image字段,分页模式)
+        - book_image: 完整书本图片(书本模式)
         
     更新的状态字段:
         - parsed_results: 解析结果列表
@@ -35,37 +38,56 @@ class ImageParseNode(BaseNode):
     def __init__(
         self,
         model_name: Optional[str] = None,
-        max_workers: int = 5
+        max_workers: int = 5,
+        use_book_image: bool = False
     ):
         """
         初始化图像解析节点
         
         Args:
             model_name: VL模型名称
-            max_workers: 并行处理的最大工作线程数(已废弃,使用全局线程池)
-            prompt_template: 自定义提示词模板
+            max_workers: 并行处理的最大工作线程数(已废弃,使用全局线程池,仅分页模式使用)
+            use_book_image: 是否使用book_image模式(True=解析完整长图,False=解析分页图片)
         """
         self.model_name = model_name or model_settings.vl_model_name
         self.max_workers = max_workers  # 保留兼容性但不再使用
+        self.use_book_image = use_book_image
     
     @property
     def name(self) -> str:
         return "image_parse"
     
-    def _default_prompt_template(self) -> str:
-        """默认的图像解析提示词模板"""
-        return """
-            角色定位:你是一位顶尖的儿童绘本分析师与视觉工程专家,擅长将插画视觉信息转化为高精度的结构化元数据。
-            任务描述:请深度解析提供的绘本页面,提取基本要素和特征。
-            当前提取页码为:{page_number}
-            
-            输出格式:JSON
-            {{
-                "page_number": {page_number},
-                "content": "页面内容描述",
-                "elements": []
-            }}
+    def _default_prompt_template(self, is_book_image: bool = False) -> str:
         """
+        默认的图像解析提示词模板
+        
+        Args:
+            is_book_image: 是否为完整书本图片模式
+        """
+        if is_book_image:
+            return """
+                角色定位:你是一位顶尖的儿童绘本分析师与视觉工程专家,擅长将插画视觉信息转化为高精度的结构化元数据。
+                任务描述:请深度解析提供的完整绘本图片(包含所有页面),提取基本要素和特征。
+                
+                输出格式:JSON
+                {{
+                    "content": "完整绘本内容描述",
+                    "elements": []
+                }}
+            """
+        else:
+            return """
+                角色定位:你是一位顶尖的儿童绘本分析师与视觉工程专家,擅长将插画视觉信息转化为高精度的结构化元数据。
+                任务描述:请深度解析提供的绘本页面,提取基本要素和特征。
+                当前提取页码为:{page_number}
+                
+                输出格式:JSON
+                {{
+                    "page_number": {page_number},
+                    "content": "页面内容描述",
+                    "elements": []
+                }}
+            """
     
     def _parse_single_page(self, page: Dict[str, Any], prompt_template: str) -> Dict[str, Any]:
         """
@@ -106,49 +128,105 @@ class ImageParseNode(BaseNode):
                 "error": str(e)
             }
     
-    def execute(self, state: BaseState) -> Dict[str, Any]:
+    def _parse_book_image(self, book_image, prompt_template: str) -> Dict[str, Any]:
         """
-        执行图像解析
+        解析完整书本图片
         
         Args:
-            state: 包含split_pages或image_pages的状态
+            book_image: PIL图像对象,完整的书本长图
+            prompt_template: 提示词模板
             
         Returns:
-            包含parsed_results的更新字典
+            解析结果字典
         """
-        # 支持多种状态字段名称
-        pages = getattr(state, 'split_pages', None) or getattr(state, 'image_pages', [])
-        prompt_template = state.dimension_prompt or self._default_prompt_template()
-        if not pages:
-            logger.warning("没有待解析的页面")
-            return {"parsed_results": [], "processed_pages": 0}
-        
-        logger.info(f"开始并行解析 {len(pages)} 页")
-        
-        parsed_results = []
+        logger.info("开始解析完整书本图片")
         
-        # 使用全局线程池
-        pool = ThreadPoolManager.get_pool("parser")
-        future_to_page = {
-            pool.submit(self._parse_single_page, page, prompt_template): page
-            for page in pages
-        }
-        
-        for future in concurrent.futures.as_completed(future_to_page):
-            try:
-                result = future.result()
-                parsed_results.append(result)
-            except Exception as e:
-                page = future_to_page[future]
-                logger.error(f"解析第 {page.get('page_number', '?')} 页时出错: {str(e)}")
-        
-        # 按页码排序结果
-        parsed_results.sort(key=lambda x: x.get("page_number", 0))
+        try:
+            parser = QWenVLParser(self.model_name)
+            # 对于完整书本图片,page_number设为0或None
+            result = parser.parse_image(book_image, 0, prompt_template)
+            
+            logger.info("完整书本图片解析完成")
+            return result
+        except Exception as e:
+            logger.error(f"解析完整书本图片时出错: {str(e)}")
+            return {
+                "content": "",
+                "error": str(e)
+            }
+    
+    def execute(self, state: BaseState) -> Dict[str, Any]:
+        """
+        执行图像解析
         
-        logger.info(f"所有页面解析完成,共解析 {len(parsed_results)} 页")
+        支持两种模式:
+        1. 分页模式(use_book_image=False):解析split_pages中的多张分页图片
+        2. 书本模式(use_book_image=True):解析book_image中的完整长图
         
-        return {
-            "parsed_results": parsed_results,
-            "processed_pages": len(parsed_results),
-            "is_complete": True
-        }
+        Args:
+            state: 包含split_pages或book_image的状态
+            
+        Returns:
+            包含parsed_results的更新字典
+        """
+        # 检查使用哪种模式
+        if self.use_book_image:
+            # 书本模式:解析完整长图
+            book_image = getattr(state, 'book_image', None)
+            if book_image is None:
+                logger.warning("book_image为空,无法解析")
+                return {"parsed_results": [], "processed_pages": 0}
+            
+            prompt_template = state.dimension_prompt or self._default_prompt_template(is_book_image=True)
+            
+            logger.info("使用书本模式解析完整长图")
+            result = self._parse_book_image(book_image, prompt_template)
+            
+            # 将结果包装成列表格式以保持兼容性
+            parsed_results = [result] if result else []
+            
+            logger.info("完整书本图片解析完成")
+            
+            return {
+                "parsed_results": parsed_results,
+                "processed_pages": 1 if parsed_results else 0,
+                "is_complete": True
+            }
+        else:
+            # 分页模式:解析多张分页图片
+            pages = getattr(state, 'split_pages', None) or getattr(state, 'image_pages', [])
+            prompt_template = state.dimension_prompt or self._default_prompt_template(is_book_image=False)
+            
+            if not pages:
+                logger.warning("没有待解析的页面")
+                return {"parsed_results": [], "processed_pages": 0}
+            
+            logger.info(f"使用分页模式,开始并行解析 {len(pages)} 页")
+            
+            parsed_results = []
+            
+            # 使用全局线程池
+            pool = ThreadPoolManager.get_pool("parser")
+            future_to_page = {
+                pool.submit(self._parse_single_page, page, prompt_template): page
+                for page in pages
+            }
+            
+            for future in concurrent.futures.as_completed(future_to_page):
+                try:
+                    result = future.result()
+                    parsed_results.append(result)
+                except Exception as e:
+                    page = future_to_page[future]
+                    logger.error(f"解析第 {page.get('page_number', '?')} 页时出错: {str(e)}")
+            
+            # 按页码排序结果
+            parsed_results.sort(key=lambda x: x.get("page_number", 0))
+            
+            logger.info(f"所有页面解析完成,共解析 {len(parsed_results)} 页")
+            
+            return {
+                "parsed_results": parsed_results,
+                "processed_pages": len(parsed_results),
+                "is_complete": True
+            }

+ 105 - 0
src/datasets/parser/nodes/picture_stitching_node.py

@@ -0,0 +1,105 @@
+"""
+图片拼接节点
+
+将拆分的PDF页面图片按页码顺序拼接成一张长图。
+"""
+
+from typing import Dict, Any
+from PIL import Image
+from src.datasets.parser.core.base import BaseNode, BaseState
+from src.datasets.parser.core.registry import register_node
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+@register_node()
+class PictureStitchingNode(BaseNode):
+    """
+    图片拼接节点
+    
+    将拆分后的PDF页面图片按页码顺序垂直拼接成一张长图。
+    
+    需要的状态字段:
+        - split_pages: 拆分后的页面列表,每个元素包含:
+            - page_number: 页码
+            - image: PIL图像对象
+            
+    更新的状态字段:
+        - book_image: 拼接后的完整书本图片
+    """
+    
+    @property
+    def name(self) -> str:
+        return "picture_stitching"
+    
+    def execute(self, state: BaseState) -> Dict[str, Any]:
+        """
+        执行图片拼接
+        
+        Args:
+            state: 包含split_pages的状态
+            
+        Returns:
+            包含book_image的更新字典
+        """
+        split_pages = getattr(state, 'split_pages', None)
+        
+        if not split_pages:
+            raise ValueError("State must contain 'split_pages' field with image data")
+        
+        if not split_pages:
+            raise ValueError("split_pages is empty, no images to stitch")
+        
+        logger.info(f"开始拼接图片,共 {len(split_pages)} 页")
+        
+        # 按页码排序
+        sorted_pages = sorted(split_pages, key=lambda x: x.get('page_number', 0))
+        
+        # 提取所有图片
+        images = []
+        for page in sorted_pages:
+            image = page.get('image')
+            if image is None:
+                logger.warning(f"页码 {page.get('page_number')} 的图片为空,跳过")
+                continue
+            if not isinstance(image, Image.Image):
+                logger.warning(f"页码 {page.get('page_number')} 的图片类型不正确: {type(image)},跳过")
+                continue
+            images.append(image)
+        
+        if not images:
+            raise ValueError("没有有效的图片可以拼接")
+        
+        logger.info(f"有效图片数量: {len(images)}")
+        
+        # 计算拼接后图片的尺寸
+        # 宽度取所有图片的最大宽度
+        max_width = max(img.width for img in images)
+        # 高度为所有图片高度之和
+        total_height = sum(img.height for img in images)
+        
+        logger.info(f"拼接后图片尺寸: {max_width}x{total_height}")
+        
+        # 创建新的空白图片
+        stitched_image = Image.new('RGB', (max_width, total_height), color='white')
+        
+        # 垂直拼接所有图片
+        current_y = 0
+        for idx, img in enumerate(images):
+            # 如果图片宽度小于最大宽度,将其居中放置
+            x_offset = (max_width - img.width) // 2
+            
+            # 将图片粘贴到目标位置
+            stitched_image.paste(img, (x_offset, current_y))
+            
+            # 更新当前y坐标
+            current_y += img.height
+            
+            logger.debug(f"已拼接第 {idx + 1}/{len(images)} 页,当前高度: {current_y}")
+        
+        logger.info(f"图片拼接完成,最终尺寸: {stitched_image.size}")
+        
+        return {
+            "book_image": stitched_image
+        }

+ 111 - 0
src/datasets/parser/nodes/picture_stitching_node_example.md

@@ -0,0 +1,111 @@
+# PictureStitchingNode 使用示例
+
+## 功能说明
+
+`PictureStitchingNode` 节点用于将 PDF 拆分后的多个页面图片按页码顺序垂直拼接成一张长图。
+
+## 输入要求
+
+节点需要 state 中包含以下字段:
+
+- `split_pages`: List[Dict] - 拆分后的页面列表,每个元素包含:
+  - `page_number`: int - 页码
+  - `image`: PIL.Image.Image - 图片对象
+
+## 输出
+
+节点会更新 state,添加以下字段:
+
+- `book_image`: PIL.Image.Image - 拼接后的完整书本图片
+
+## 使用示例
+
+### 1. 在工作流中使用
+
+```python
+from src.datasets.parser.core.workflow_builder import WorkflowBuilder
+from src.datasets.parser.nodes.pdf_split_node import PDFSplitNode
+from src.datasets.parser.nodes.picture_stitching_node import PictureStitchingNode
+
+# 创建工作流
+workflow = WorkflowBuilder()
+
+# 添加节点
+workflow.add_node(PDFSplitNode())
+workflow.add_node(PictureStitchingNode())
+
+# 连接节点
+workflow.add_edge("pdf_split", "picture_stitching")
+workflow.set_entry_point("pdf_split")
+workflow.set_finish_point("picture_stitching")
+
+# 编译工作流
+app = workflow.compile()
+
+# 执行工作流
+result = app.invoke({
+    "pdf_path": "path/to/your.pdf",
+    "dataset_name": "my_dataset",
+    "page_dataset_id": "dataset_123"
+})
+
+# 获取拼接后的图片
+book_image = result["book_image"]
+book_image.save("stitched_book.png")
+```
+
+### 2. 单独使用节点
+
+```python
+from src.datasets.parser.nodes.picture_stitching_node import PictureStitchingNode
+from src.datasets.parser.states.parser_states import PDFParsingState
+
+# 准备状态(假设已经有拆分后的页面)
+state = PDFParsingState(
+    dataset_name="my_dataset",
+    page_dataset_id="dataset_123",
+    split_pages=[
+        {"page_number": 1, "image": image1},
+        {"page_number": 2, "image": image2},
+        {"page_number": 3, "image": image3},
+    ]
+)
+
+# 执行拼接
+node = PictureStitchingNode()
+result = node.execute(state)
+
+# 获取拼接后的图片
+book_image = result["book_image"]
+```
+
+## 特性
+
+1. **自动排序**: 节点会自动按 `page_number` 对页面进行排序,确保拼接顺序正确
+2. **居中对齐**: 如果某些页面宽度不同,较窄的页面会被居中放置
+3. **容错处理**: 自动跳过缺失或无效的图片,只拼接有效的页面
+4. **详细日志**: 提供详细的处理日志,便于调试和监控
+
+## 注意事项
+
+1. 拼接后的图片可能会很大,请注意内存使用
+2. 如果 PDF 页面很多(如超过 100 页),建议考虑分批处理
+3. 图片宽度以所有页面中的最大宽度为准
+4. 图片高度为所有页面高度之和
+
+## 性能优化建议
+
+对于大型 PDF(如超过 50 页),可以考虑:
+
+1. 在拼接前降低图片分辨率
+2. 使用分段拼接,每次处理固定数量的页面
+3. 异步处理和保存结果
+
+## 错误处理
+
+节点会在以下情况抛出异常:
+
+- `split_pages` 字段不存在或为空
+- 所有页面的图片都无效或缺失
+
+在这些情况下,请检查 PDF 拆分节点是否正常执行。

+ 205 - 29
src/datasets/parser/nodes/qa_generate_node.py

@@ -68,23 +68,41 @@ class QAGenerateNode(BaseNode):
             QA对列表
         """
         prompt = f"""
-            请根据以下文本内容,生成{count}个高质量的问答对。
-            
-            要求:
-            1. 问题应该多样化,涵盖文本的不同方面
-            2. 答案应该准确、完整,直接来源于文本
-            3. 问题应该自然,像真实用户会问的问题
-            4. 避免过于简单或过于复杂的问题
-            5. 校验必须以完整的JSON数组格式输出
+        # Role
+        你是一位深耕 0-14 岁儿童教育与少儿出版行业 的资深知识萃取专家。你具备儿童心理学、教育学(如皮亚杰认知发展理论、维果茨基最近发展区)以及全球童书出版趋势的深厚洞察力。
 
-            文本内容:
-            {chunk}
-            
-            请以JSON数组格式输出,每个元素包含"question"和"answer"字段:
-            [
-                {{"question": "问题1", "answer": "答案1"}},
-                {{"question": "问题2", "answer": "答案2"}}
-            ]
+        # Task
+        请阅读文本内容:{chunk},生成 {count} 个高质量问答对。
+
+        # Target Audience
+        0-14 岁儿童(提问语气要亲切、好奇、具有代入感)。
+
+        # Requirements
+        提问视角(For Kids):
+        - 禁止说教。使用“如果你是...”、“你猜...”、“为什么...会这样”等激发好奇心的问句。
+        - 问题要能关联孩子的生活经验或想象力。
+        
+        回答要求(Double-Layer & Slim):
+        - 基础事实 + 深度启发:先用一句话讲清事实,再用一句话点破底层逻辑或引导实践。
+        - 字数铁律:每个答案严禁超过 50 字。
+
+        扩展维度:
+        - 好奇心钩子:为什么这个知识很酷?
+        - 生活实验室:你现在可以试着做什么?
+        - 情绪/逻辑种子:这背后的道理是什么?
+
+        # Output Standards
+        格式:必须以完整的 JSON 数组格式输出。
+        - 严禁:任何多余的开场白或解释文字。
+
+        # Output Format (JSON Only)
+        JSON格式如下:
+        [
+            {{
+                "question": "(面向孩子的好奇心提问)",
+                "answer": "(事实要点+深度启发,50字以内)"
+            }}
+        ]
         """
         
         logger.debug(f"开始为第 {chunk_index + 1} 块生成QA对")
@@ -111,20 +129,178 @@ class QAGenerateNode(BaseNode):
         Returns:
             QA对列表
         """
-        try:
-            # 尝试直接解析JSON
-            return json.loads(response)
-        except json.JSONDecodeError:
-            # 尝试提取JSON部分
-            import re
-            json_match = re.search(r'\[[\s\S]*\]', response)
-            if json_match:
-                try:
-                    return json.loads(json_match.group())
-                except json.JSONDecodeError:
-                    pass
-            logger.warning(f"无法解析QA响应为JSON: {response}")
+        import re
+        
+        if not response:
             return []
+        
+        # 清理响应文本
+        cleaned_response = response.strip()
+        
+        # 移除 BOM 标记
+        if cleaned_response.startswith('\ufeff'):
+            cleaned_response = cleaned_response[1:]
+        
+        # 1. 尝试直接解析JSON
+        try:
+            return json.loads(cleaned_response)
+        except json.JSONDecodeError as e:
+            logger.debug(f"直接解析失败: {str(e)}")
+            pass
+        
+        # 1.5. 尝试使用 raw_decode 解析(可以跳过前面的非JSON文本)
+        try:
+            decoder = json.JSONDecoder()
+            result, idx = decoder.raw_decode(cleaned_response)
+            if isinstance(result, list):
+                return result
+        except (json.JSONDecodeError, ValueError) as e:
+            logger.debug(f"raw_decode 解析失败: {str(e)}")
+            pass
+        
+        # 2. 尝试去除 markdown 代码块标记
+        # 匹配 ```json ... ``` 或 ``` ... ```
+        code_block_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
+        code_block_match = re.search(code_block_pattern, cleaned_response, re.DOTALL)
+        if code_block_match:
+            try:
+                json_content = code_block_match.group(1).strip()
+                return json.loads(json_content)
+            except json.JSONDecodeError:
+                pass
+        
+        # 3. 尝试提取第一个完整的 JSON 数组
+        # 使用括号匹配算法,正确处理嵌套的 [] 和 {}
+        bracket_count = 0
+        brace_count = 0
+        start_idx = -1
+        in_string = False
+        escape_next = False
+        
+        for i, char in enumerate(cleaned_response):
+            if escape_next:
+                escape_next = False
+                continue
+            
+            if char == '\\':
+                escape_next = True
+                continue
+            
+            if char == '"' and not escape_next:
+                in_string = not in_string
+                continue
+            
+            if in_string:
+                continue
+            
+            if char == '[':
+                if start_idx == -1:
+                    start_idx = i
+                bracket_count += 1
+            elif char == ']':
+                bracket_count -= 1
+                if bracket_count == 0 and brace_count == 0 and start_idx != -1:
+                    try:
+                        json_content = cleaned_response[start_idx:i+1]
+                        return json.loads(json_content)
+                    except json.JSONDecodeError:
+                        # 继续尝试下一个匹配
+                        start_idx = -1
+                        bracket_count = 0
+                        brace_count = 0
+            elif char == '{':
+                if start_idx != -1:
+                    brace_count += 1
+            elif char == '}':
+                if start_idx != -1:
+                    brace_count -= 1
+        
+        # 4. 尝试使用正则提取 JSON 数组(更宽松的方式)
+        json_array_pattern = r'\[\s*(?:\{[^}]*\}(?:\s*,\s*\{[^}]*\})*)?\s*\]'
+        json_match = re.search(json_array_pattern, cleaned_response, re.DOTALL)
+        if json_match:
+            try:
+                return json.loads(json_match.group())
+            except json.JSONDecodeError:
+                pass
+        
+        # 5. 尝试逐行查找 JSON 数组
+        lines = cleaned_response.split('\n')
+        json_lines = []
+        in_json = False
+        bracket_count = 0
+        
+        for line in lines:
+            stripped_line = line.strip()
+            if not stripped_line:
+                continue
+            
+            # 检查是否包含 JSON 数组的开始
+            if '[' in stripped_line and not in_json:
+                in_json = True
+                json_lines = [stripped_line]
+                bracket_count = stripped_line.count('[') - stripped_line.count(']')
+            elif in_json:
+                json_lines.append(stripped_line)
+                bracket_count += stripped_line.count('[') - stripped_line.count(']')
+                
+                if bracket_count == 0:
+                    try:
+                        json_content = '\n'.join(json_lines)
+                        return json.loads(json_content)
+                    except json.JSONDecodeError:
+                        in_json = False
+                        json_lines = []
+                        bracket_count = 0
+        
+        # 如果收集到了 JSON 行但还没闭合,尝试解析
+        if json_lines:
+            try:
+                json_content = '\n'.join(json_lines)
+                return json.loads(json_content)
+            except json.JSONDecodeError:
+                pass
+        
+        # 6. 最后尝试:查找所有可能的 JSON 对象并组合成数组
+        try:
+            # 查找所有 { ... } 模式的对象
+            json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned_response, re.DOTALL)
+            if json_objects:
+                parsed_objects = []
+                for obj_str in json_objects:
+                    try:
+                        parsed_obj = json.loads(obj_str)
+                        if isinstance(parsed_obj, dict) and 'question' in parsed_obj and 'answer' in parsed_obj:
+                            parsed_objects.append(parsed_obj)
+                    except json.JSONDecodeError:
+                        continue
+                if parsed_objects:
+                    logger.info(f"通过对象提取方式解析到 {len(parsed_objects)} 个QA对")
+                    return parsed_objects
+        except Exception as e:
+            logger.debug(f"对象提取方式失败: {str(e)}")
+        
+        # 所有方法都失败
+        # 记录更详细的错误信息用于调试
+        error_info = {
+            "response_length": len(cleaned_response),
+            "first_100_chars": repr(cleaned_response[:100]),
+            "last_100_chars": repr(cleaned_response[-100:]) if len(cleaned_response) > 100 else "",
+            "has_bracket": '[' in cleaned_response,
+            "has_brace": '{' in cleaned_response,
+        }
+        logger.warning(f"无法解析QA响应为JSON: {error_info}")
+        
+        # 尝试最后一次:如果响应看起来像 JSON 数组,尝试修复常见问题
+        if cleaned_response.startswith('[') and cleaned_response.endswith(']'):
+            try:
+                # 尝试修复常见的 JSON 问题:替换中文引号
+                fixed_response = cleaned_response.replace('“', '"').replace('”', '"').replace('‘', "'").replace('’', "'")
+                return json.loads(fixed_response)
+            except json.JSONDecodeError:
+                pass
+        
+        return []
     
     def execute(self, state: BaseState) -> Dict[str, Any]:
         """

+ 14 - 3
src/datasets/parser/nodes/ragflow_nodes.py

@@ -144,6 +144,10 @@ class RAGFlowDocumentUploadNode(BaseNode):
         """
         dataset_id = getattr(state, 'dataset_id', '')
         # file_path = getattr(state, 'pdf_path', '') or getattr(state, 'csv_path', '')
+        if getattr(state, 'csv_path', ''):
+            file_path = getattr(state, 'csv_path', '')
+        else:
+            file_path = getattr(state, 'pdf_path', '')
         file_path = getattr(state, 'csv_path', '') or getattr(state, 'pdf_path', '')
         
         if not dataset_id:
@@ -177,9 +181,16 @@ class RAGFlowDocumentParseNode(BaseNode):
         - dataset_id: 数据集ID
         - document_id: 文档ID
     """
-    
-    def __init__(self):
-        self.ragflow_service = RAGFlowService()
+
+    def __init__(self, api_key: Optional[str] = None):
+        """
+        初始化文档解析节点
+        
+        Args:
+            api_key: RAGFlow API密钥,默认使用 ragflow_settings 中的配置
+        """
+        self.api_key = api_key or ragflow_settings.ragflow_api_key
+        self.ragflow_service = RAGFlowService(api_key=self.api_key)
     
     @property
     def name(self) -> str:

+ 12 - 6
src/datasets/parser/states/parser_states.py

@@ -5,6 +5,7 @@
 """
 
 import io
+from PIL import Image
 from typing import List, Dict, Any, Optional
 from pydantic import Field, ConfigDict
 from src.datasets.parser.core.base import BaseState
@@ -109,6 +110,13 @@ class VectorizationMixin(BaseState):
     vectorized_count: int = Field(default=0, description="已向量化数量")
 
 
+class ParsedContent(BaseState):
+    """
+    解析内容状态
+    """
+    content: str = Field(default="", description="解析内容")
+    dimension_id: int = Field(default=0, description="维度ID")
+
 class DynamicDimensionState(BaseState):
     """
     动态维度解析状态
@@ -120,20 +128,18 @@ class DynamicDimensionState(BaseState):
     dimension_ids: List[int] = Field(default_factory=list, description="维度ID列表")
     dataset_id: str = Field(default="", description="数据集ID")
     dimension_prompt: str = Field(default="", description="维度提示词")
-    # dataset_name: str = Field(default="", description="数据集名称")
     document_id: str = Field(default="", description="文档ID")
     ragflow_id: str = Field(default="", description="RagFlow用户ID")
     rag_flow_api_key: str = Field(default="", description="RAGFlow API密钥")
-    # # RAGFlow 相关
-    # ragflow_api_url: str = Field(default="", description="RAGFlow API URL")
-    # rag_flow_api_key: str = Field(default="", description="RAGFlow API密钥")
     
     # 中间状态
     table_name: str = Field(default="", description="向量表名(由 TableNameGenerationNode 生成)")
     split_pages: List[Dict[str, Any]] = Field(default_factory=list, description="拆分后的页面列表")
     parsed_results: List[Dict[str, Any]] = Field(default_factory=list, description="解析结果列表")
-    
+    parsed_contents: List[ParsedContent] = Field(default_factory=list, description="解析内容列表")
+    book_image: Image.Image = Field(default=None, description="书本图片")
+
     # 输出 - 每个维度的结果
     dimension_results: Dict[int, Dict[str, Any]] = Field(default_factory=dict, description="每个维度的解析结果")
     total_vectorized_pages: int = Field(default=0, description="总向量化页面数")
-    is_complete: bool = Field(default=False, description="是否完成")
+    is_complete: bool = Field(default=False, description="是否完成")

+ 5 - 3
src/datasets/parser/workflow_nodes/__init__.py

@@ -1,11 +1,13 @@
 """
 工作流节点模块
 
-提供工作流特定的节点组件,如维度技能节点等。
+提供工作流特定的节点组件,如维度分页拆分节点、维度书本拆分节点等。
 """
 
-from src.datasets.parser.workflow_nodes.dimension_skill_node import DimensionSkillNode
+from src.datasets.parser.workflow_nodes.dimension_page_split_node import DimensionPageSplitNode
+from src.datasets.parser.workflow_nodes.dimension_book_split_node import DimensionBookSplitNode
 
 __all__ = [
-    "DimensionSkillNode",
+    "DimensionPageSplitNode",
+    "DimensionBookSplitNode",
 ]

+ 152 - 0
src/datasets/parser/workflow_nodes/dimension_book_split_node.py

@@ -0,0 +1,152 @@
+"""
+维度书本拆分节点
+"""
+
+from typing import Dict, Any
+from src.datasets.parser.core.base import BaseNode
+from src.datasets.parser.core.workflow_builder import WorkflowBuilder
+from src.datasets.parser.nodes import (
+    ImageParseNode, 
+    VectorizeNode, 
+    PromptRetrievalNode, 
+    TableNameGenerationNode, 
+    DimensionResultNode,
+    RAGFlowDocumentUploadNode,
+    RAGFlowChunkNode,
+    PictureStitchingNode
+)
+from src.datasets.parser.states.parser_states import DynamicDimensionState
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class DimensionBookSplitNode(BaseNode):
+    """
+    维度书本拆分节点
+    
+    单个维度的处理节点,作为子工作流的构建器和执行器,包含:
+    1. 获取维度提示词
+    2. 生成向量表名
+    3. 将分页图片拼接成完整书本图片
+    4. 使用提示词解析完整书本图片
+    5. 向量化入库
+    6. 记录维度结果
+    
+    与维度分页拆分节点的区别:
+    - 先执行图片拼接,将分页图片拼接成一张长图
+    - 使用完整书本图片进行解析,而不是分页图片
+    
+    每个维度都是独立的 LangGraph 节点,可被 Langfuse 追踪。
+    """
+    
+    def __init__(
+        self,
+        dimension_id: int,
+        model_name: str = "Qwen/Qwen3-VL-8B-Instruct",
+        max_workers: int = 5
+    ):
+        """
+        初始化维度书本拆分节点
+        
+        Args:
+            dimension_id: 维度ID
+            model_name: VL模型名称
+            max_workers: 并行处理的最大工作线程数(对于book_image模式,此参数不使用)
+        """
+        self.dimension_id = dimension_id
+        self.model_name = model_name
+        self.max_workers = max_workers
+    
+    @property
+    def name(self) -> str:
+        """节点名称,格式: book_split_dim_{id}"""
+        return f"book_split_dim_{self.dimension_id}"
+    
+    def _build_sub_workflow(self, state):
+        """
+        构建子工作流
+        
+        Args:
+            state: 状态
+            
+        Returns:
+            编译后的 LangGraph 工作流
+        """
+        logger.info(f"[维度书本拆分-{self.dimension_id}] 开始构建子工作流")
+        rag_flow_api_key = getattr(state, 'rag_flow_api_key', '')
+
+        # 创建工作流构建器
+        builder = WorkflowBuilder(DynamicDimensionState)
+        
+        # 创建节点
+        prompt_node = PromptRetrievalNode(self.dimension_id)
+        document_upload_node = RAGFlowDocumentUploadNode(api_key=rag_flow_api_key)
+        table_name_node = TableNameGenerationNode(self.dimension_id)
+        stitching_node = PictureStitchingNode()
+        parse_node = ImageParseNode(
+            model_name=self.model_name,
+            max_workers=self.max_workers,
+            use_book_image=True  # 标记使用book_image模式
+        )
+        chunk_node = RAGFlowChunkNode(api_key=rag_flow_api_key)
+        vectorize_node = VectorizeNode()
+        result_node = DimensionResultNode(self.dimension_id)
+        
+        # 添加节点
+        builder.add_nodes(
+            prompt_node,
+            document_upload_node,
+            table_name_node,
+            stitching_node,
+            parse_node,
+            chunk_node,
+            vectorize_node,
+            result_node
+        )
+        
+        # 设置边
+        builder.set_entry(prompt_node.name)
+        builder.add_edge(prompt_node.name, document_upload_node.name)
+        builder.add_edge(document_upload_node.name, table_name_node.name)
+        builder.add_edge(table_name_node.name, stitching_node.name)
+        builder.add_edge(stitching_node.name, parse_node.name)
+        builder.add_edge(parse_node.name, chunk_node.name)
+        builder.add_edge(chunk_node.name, vectorize_node.name)
+        builder.add_edge(vectorize_node.name, result_node.name)
+        builder.set_finish(result_node.name)
+        
+        # 构建并返回工作流
+        workflow = builder.build()
+        logger.info(f"[维度书本拆分-{self.dimension_id}] 子工作流构建完成")
+        return workflow
+    
+    def execute(self, state: DynamicDimensionState) -> Dict[str, Any]:
+        """
+        执行维度书本拆分
+        
+        构建并执行子工作流,包含以下步骤:
+        1. 获取提示词
+        2. 生成向量表名
+        3. 将分页图片拼接成完整书本图片
+        4. 解析完整书本图片
+        5. 向量化入库
+        6. 记录维度结果
+        """
+        logger.info(f"[维度书本拆分-{self.dimension_id}] 开始执行维度书本拆分")
+        
+        # 构建子工作流
+        workflow = self._build_sub_workflow(state)
+        
+        # 执行子工作流
+        result = workflow.invoke(state)
+        
+        # 处理结果
+        if isinstance(result, dict):
+            final_result = result
+        else:
+            final_result = result.dict() if hasattr(result, 'dict') else dict(result)
+        
+        logger.info(f"[维度书本拆分-{self.dimension_id}] 维度书本拆分执行完成")
+        
+        return final_result

+ 11 - 11
src/datasets/parser/workflow_nodes/dimension_skill_node.py → src/datasets/parser/workflow_nodes/dimension_page_split_node.py

@@ -1,5 +1,5 @@
 """
-维度技能节点
+维度分页拆分节点
 """
 
 from typing import Dict, Any
@@ -20,9 +20,9 @@ from src.common.logging_config import get_logger
 logger = get_logger(__name__)
 
 
-class DimensionSkillNode(BaseNode):
+class DimensionPageSplitNode(BaseNode):
     """
-    维度技能节点
+    维度分页拆分节点
     
     单个维度的处理节点,作为子工作流的构建器和执行器,包含:
     1. 获取维度提示词
@@ -41,7 +41,7 @@ class DimensionSkillNode(BaseNode):
         max_workers: int = 5
     ):
         """
-        初始化维度技能节点
+        初始化维度分页拆分节点
         
         Args:
             dimension_id: 维度ID
@@ -54,8 +54,8 @@ class DimensionSkillNode(BaseNode):
     
     @property
     def name(self) -> str:
-        """节点名称,格式: skill_dim_{id}"""
-        return f"skill_dim_{self.dimension_id}"
+        """节点名称,格式: page_split_dim_{id}"""
+        return f"page_split_dim_{self.dimension_id}"
     
     def _build_sub_workflow(self, state):
         """
@@ -67,7 +67,7 @@ class DimensionSkillNode(BaseNode):
         Returns:
             编译后的 LangGraph 工作流
         """
-        logger.info(f"[Skill-{self.dimension_id}] 开始构建子工作流")
+        logger.info(f"[维度分页拆分-{self.dimension_id}] 开始构建子工作流")
         rag_flow_api_key = getattr(state, 'rag_flow_api_key', '')
 
 
@@ -109,12 +109,12 @@ class DimensionSkillNode(BaseNode):
         
         # 构建并返回工作流
         workflow = builder.build()
-        logger.info(f"[Skill-{self.dimension_id}] 子工作流构建完成")
+        logger.info(f"[维度分页拆分-{self.dimension_id}] 子工作流构建完成")
         return workflow
     
     def execute(self, state: DynamicDimensionState) -> Dict[str, Any]:
         """
-        执行维度技能
+        执行维度分页拆分
         
         构建并执行子工作流,包含以下步骤:
         1. 获取提示词
@@ -123,7 +123,7 @@ class DimensionSkillNode(BaseNode):
         4. 向量化入库
         5. 记录维度结果
         """
-        logger.info(f"[Skill-{self.dimension_id}] 开始执行维度技能")
+        logger.info(f"[维度分页拆分-{self.dimension_id}] 开始执行维度分页拆分")
         
         # 构建子工作流
         workflow = self._build_sub_workflow(state)
@@ -137,6 +137,6 @@ class DimensionSkillNode(BaseNode):
         else:
             final_result = result.dict() if hasattr(result, 'dict') else dict(result)
         
-        logger.info(f"[Skill-{self.dimension_id}] 维度技能执行完成")
+        logger.info(f"[维度分页拆分-{self.dimension_id}] 维度分页拆分执行完成")
         
         return final_result

+ 6 - 3
src/datasets/parser/workflows/dynamic_dimension_workflow.py

@@ -18,7 +18,7 @@ from src.datasets.parser.nodes import (
     PDFSplitNode,
     CompleteNode
 )
-from src.datasets.parser.workflow_nodes import DimensionSkillNode
+from src.datasets.parser.workflow_nodes import DimensionPageSplitNode
 from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
 from src.common.logging_config import get_logger
 
@@ -72,6 +72,8 @@ class DynamicDimensionWorkflow:
         split_node = PDFSplitNode()
         complete_node = CompleteNode(message_template="动态多维度解析完成")
         
+        # 查询维度知识库对应的decomposition method
+
         # 使用 WorkflowBuilder 构建
         builder = WorkflowBuilder(DynamicDimensionState)
         
@@ -88,8 +90,9 @@ class DynamicDimensionWorkflow:
         prev_node = "pdf_split"
         
         for dim_id in dimension_ids:
-            # 为每个维度创建技能节点
-            skill_node = DimensionSkillNode(
+            
+            # 为每个维度创建分页拆分节点
+            skill_node = DimensionPageSplitNode(
                 dimension_id=dim_id,
                 model_name=self.model_name,
                 max_workers=self.max_workers

+ 7 - 4
src/datasets/parser/workflows/pdf_workflow.py

@@ -21,6 +21,7 @@ from src.datasets.parser.nodes import (
     CompleteNode,
 )
 from src.datasets.parser.nodes.ragflow_nodes import DatasetExistsCondition
+from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
 from src.common.logging_config import get_logger
 
 logger = get_logger(__name__)
@@ -58,16 +59,18 @@ class PDFParsingWorkflowV2:
     
     def _build_workflow(self):
         """构建工作流"""
+        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
+        ragflow_api_key = ragflow_user.get("api_key")
         # 创建节点实例
         dataset_node = RAGFlowDatasetNode(create_if_not_exists=True)
         dataset_condition = DatasetExistsCondition()
-        upload_node = RAGFlowDocumentUploadNode()
-        page_upload_node = RAGFlowDocumentUploadNode()
-        parse_doc_node = RAGFlowDocumentParseNode()
+        upload_node = RAGFlowDocumentUploadNode(api_key=ragflow_api_key)
+        page_upload_node = RAGFlowDocumentUploadNode(api_key=ragflow_api_key)
+        parse_doc_node = RAGFlowDocumentParseNode(api_key=ragflow_api_key)
         split_node = PDFSplitNode()
         image_parse_node = ImageParseNode(model_name=self.model_name)
         vectorize_node = VectorizeNode()
-        chunk_node = RAGFlowChunkNode(api_key="ragflow-jpVLczgZ7_WkXbve59p1TS_wm3BvNXrcyhmoBFCAjR0")
+        chunk_node = RAGFlowChunkNode(api_key=ragflow_api_key)
         complete_node = CompleteNode(message_template="PDF解析工作流完成")
         
         # 使用WorkflowBuilder构建

+ 7 - 3
src/datasets/parser/workflows/qa_workflow.py

@@ -10,6 +10,7 @@ from langfuse.langchain import CallbackHandler
 
 from src.datasets.parser.core.workflow_builder import WorkflowBuilder
 from src.datasets.parser.states.parser_states import QAParsingState
+from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
 from src.datasets.parser.nodes import (
     TextSplitNode,
     QAGenerateNode,
@@ -57,13 +58,16 @@ class QAParsingWorkflowV2:
     
     def _build_workflow(self):
         """构建工作流"""
+        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
+        ragflow_api_key = ragflow_user.get("api_key")
+
         # 创建节点实例
         ocr_node = PDFOCRNode(model_name=self.model_name)
         split_node = TextSplitNode()
         qa_node = QAGenerateNode(model_name=self.model_name)
         export_node = ExportCSVNode()
-        upload_node = RAGFlowDocumentUploadNode(target_field="document_id")
-        parse_node = RAGFlowDocumentParseNode()
+        upload_node = RAGFlowDocumentUploadNode(api_key=ragflow_api_key)
+        parse_node = RAGFlowDocumentParseNode(api_key=ragflow_api_key)
         complete_node = CompleteNode(message_template="QA解析工作流完成")
         
         # 使用WorkflowBuilder构建
@@ -116,7 +120,7 @@ class QAParsingWorkflowV2:
             包含最终状态的字典
         """
         logger.info(f"开始运行QA解析工作流: {pdf_path}")
-        
+
         initial_state = QAParsingState(
             pdf_path=pdf_path,
             dataset_id=dataset_id,

+ 2 - 2
src/model/multimodal_embedding.py

@@ -41,7 +41,7 @@ class Embedding:
         try:
             # 使用langchain_openai初始化OpenAI Embeddings模型
             embeddings = OpenAIEmbeddings(
-                model=self.model_name,
+                model=self.embedding_model_name,
                 base_url=self.base_url,
                 api_key=self.api_key
             )
@@ -64,7 +64,7 @@ class Embedding:
         try:
             # 使用langchain_openai初始化OpenAI Embeddings模型
             embeddings = OpenAIEmbeddings(
-                model=self.model_name,
+                model=self.embedding_model_name,
                 base_url=self.base_url,
                 api_key=self.api_key
             )

+ 188 - 0
src/utils/mysql/ARCHITECTURE.md

@@ -0,0 +1,188 @@
+# MySQL 数据库查询层架构设计
+
+## 架构概览
+
+```
+┌─────────────────────────────────────────────────────────┐
+│                     Service Layer                       │
+│  (业务服务层,如 PromptService, RagflowUserService)     │
+└──────────────────────┬──────────────────────────────────┘
+                       │
+┌──────────────────────▼──────────────────────────────────┐
+│                  Repository Layer                        │
+│  ┌──────────────────────────────────────────────────┐  │
+│  │  BaseRepository (抽象接口)                        │  │
+│  │  - find_by_id()                                   │  │
+│  │  - find_one() / find_all()                        │  │
+│  │  - insert() / update() / delete()                │  │
+│  │  - count() / paginate()                           │  │
+│  └──────────────────┬────────────────────────────────┘  │
+│                      │                                   │
+│  ┌───────────────────▼──────────────────────────────┐  │
+│  │  MySQLRepository (MySQL 实现)                    │  │
+│  │  - 实现所有 BaseRepository 方法                  │  │
+│  │  - 使用 QueryBuilder 构建 SQL                     │  │
+│  └───────────────────┬──────────────────────────────┘  │
+│                      │                                   │
+│  ┌───────────────────▼──────────────────────────────┐  │
+│  │  具体业务 Repository                              │  │
+│  │  - PromptDimensionRepository                     │  │
+│  │  - RagflowUserRepository                         │  │
+│  │  - 封装业务特定的查询方法                         │  │
+│  └──────────────────────────────────────────────────┘  │
+└──────────────────────┬──────────────────────────────────┘
+                       │
+┌──────────────────────▼──────────────────────────────────┐
+│                  Query Builder                          │
+│  ┌──────────────────────────────────────────────────┐  │
+│  │  QueryBuilder                                    │  │
+│  │  - where() / where_eq() / where_in()             │  │
+│  │  - order_by() / limit() / offset()               │  │
+│  │  - join() / group_by() / having()                │  │
+│  │  - build_select()/update()/delete()              │  │
+│  └──────────────────┬───────────────────────────────┘  │
+└──────────────────────┬──────────────────────────────────┘
+                       │
+┌──────────────────────▼──────────────────────────────────┐
+│              Connection Layer                           │
+│  ┌──────────────────────────────────────────────────┐  │
+│  │  MySQLConnection                                 │  │
+│  │  - execute() / fetch_one() / fetch_all()         │  │
+│  │  - bulk_insert()                                 │  │
+│  │  - 事务支持                                       │  │
+│  └──────────────────┬───────────────────────────────┘  │
+│                      │                                   │
+│  ┌───────────────────▼──────────────────────────────┐  │
+│  │  MySQLPool                                       │  │
+│  │  - 连接池管理                                     │  │
+│  │  - DBUtils PooledDB                              │  │
+│  └──────────────────────────────────────────────────┘  │
+└─────────────────────────────────────────────────────────┘
+```
+
+## 核心组件说明
+
+### 1. BaseRepository (抽象层)
+- **位置**: `src/utils/mysql/base_repository.py`
+- **作用**: 定义统一的数据库操作接口
+- **优势**: 
+  - 支持多种数据库实现(MySQL, PostgreSQL 等)
+  - 便于单元测试(可 Mock)
+  - 统一的 API 设计
+
+### 2. QueryBuilder (查询构建器)
+- **位置**: `src/utils/mysql/query_builder.py`
+- **作用**: 提供链式 API 构建 SQL 语句
+- **特性**:
+  - 类型安全的查询构建
+  - 防止 SQL 注入(参数化查询)
+  - 支持复杂查询(JOIN, GROUP BY 等)
+
+### 3. MySQLRepository (实现层)
+- **位置**: `src/utils/mysql/mysql_repository.py`
+- **作用**: 实现 BaseRepository 接口,提供 MySQL 具体实现
+- **功能**:
+  - 基础 CRUD 操作
+  - 分页查询
+  - 批量操作
+  - 自定义 SQL 执行
+
+### 4. 业务 Repository (业务层)
+- **位置**: `src/api/db/repositories/`
+- **作用**: 封装业务特定的数据库操作
+- **示例**:
+  - `PromptDimensionRepository`: 提示词维度相关操作
+  - `RagflowUserRepository`: RAGFlow 用户相关操作
+
+## 设计原则
+
+### 1. 单一职责原则
+- 每个 Repository 只负责一个表的操作
+- QueryBuilder 只负责 SQL 构建
+- MySQLConnection 只负责数据库连接和执行
+
+### 2. 开闭原则
+- BaseRepository 对扩展开放,对修改关闭
+- 可以轻松添加新的 Repository 实现
+
+### 3. 依赖倒置原则
+- Service 层依赖 Repository 抽象,而非具体实现
+- 便于替换数据库实现
+
+### 4. 接口隔离原则
+- BaseRepository 提供最小必要接口
+- 具体 Repository 可以扩展额外方法
+
+## 使用场景
+
+### 场景 1: 简单 CRUD
+```python
+repo = get_repository("prompt_dimensions")
+dimension = repo.find_by_id(1)
+```
+
+### 场景 2: 复杂查询
+```python
+repo = PromptDimensionRepository()
+dimensions = repo.get_dimensions_with_active_prompts()
+```
+
+### 场景 3: 分页查询
+```python
+result = repo.paginate(conditions={"status": 1}, page=1, page_size=10)
+```
+
+### 场景 4: 批量操作
+```python
+repo.insert_many([{...}, {...}, {...}])
+```
+
+## 迁移路径
+
+### 阶段 1: 保持兼容
+- 保留原有的 `MySQLConnection` 接口
+- 新代码使用 Repository 模式
+- 旧代码逐步迁移
+
+### 阶段 2: 创建业务 Repository
+- 为每个表创建专门的 Repository
+- 封装业务特定的查询逻辑
+
+### 阶段 3: 重构 Service 层
+- Service 层使用 Repository 替代直接 SQL
+- 提高代码可维护性
+
+## 性能考虑
+
+1. **连接池**: 使用 DBUtils 连接池,避免频繁创建连接
+2. **批量操作**: 支持批量插入,提高性能
+3. **查询优化**: QueryBuilder 生成的 SQL 可以进一步优化
+4. **索引建议**: 为常用查询字段建立索引
+
+## 扩展性
+
+### 支持其他数据库
+只需实现 BaseRepository 接口:
+
+```python
+class PostgreSQLRepository(BaseRepository):
+    """PostgreSQL Repository 实现"""
+    # 实现所有抽象方法
+```
+
+### 添加新功能
+- 在 MySQLRepository 中添加扩展方法
+- 在具体业务 Repository 中封装业务逻辑
+
+## 测试支持
+
+### Mock Repository
+```python
+from unittest.mock import Mock
+
+mock_repo = Mock(spec=BaseRepository)
+mock_repo.find_by_id.return_value = {"id": 1, "name": "test"}
+```
+
+### 集成测试
+使用测试数据库,创建真实的 Repository 实例进行测试。

+ 160 - 0
src/utils/mysql/README.md

@@ -0,0 +1,160 @@
+# MySQL 数据库查询层使用指南
+
+## 架构设计
+
+本模块提供了工程化的 MySQL 数据库查询层,采用 Repository 模式和查询构建器模式:
+
+```
+BaseRepository (抽象接口)
+    ↓
+MySQLRepository (MySQL 实现)
+    ↓
+具体业务 Repository (如 PromptDimensionRepository)
+```
+
+## 核心组件
+
+### 1. BaseRepository
+抽象基类,定义统一的 CRUD 操作接口。
+
+### 2. QueryBuilder
+链式查询构建器,支持:
+- WHERE 条件构建
+- ORDER BY 排序
+- LIMIT/OFFSET 分页
+- JOIN 关联查询
+- GROUP BY 分组
+
+### 3. MySQLRepository
+MySQL 具体实现,提供:
+- 基础 CRUD 操作
+- 分页查询
+- 批量操作
+- 自定义 SQL 执行
+
+## 使用方式
+
+### 方式一:直接使用 MySQLRepository
+
+```python
+from src.utils.mysql import get_repository
+
+# 获取 Repository 实例
+repo = get_repository("prompt_dimensions")
+
+# 根据 ID 查询
+dimension = repo.find_by_id(1)
+
+# 条件查询
+dimensions = repo.find_all({"status": 1}, order_by="created_at DESC")
+
+# 分页查询
+result = repo.paginate(conditions={"status": 1}, page=1, page_size=10)
+# 返回: {"total": 100, "items": [...], "page": 1, "page_size": 10, "total_pages": 10}
+
+# 插入(注意:insert 返回受影响的行数,而非新记录的主键 ID)
+affected_rows = repo.insert({"name": "新维度", "description": "描述"})
+
+# 更新
+repo.update_by_id(1, {"description": "新描述"})
+
+# 删除
+repo.delete_by_id(1)
+```
+
+### 方式二:创建自定义 Repository
+
+```python
+from src.utils.mysql import MySQLRepository
+
+class PromptDimensionRepository(MySQLRepository):
+    """提示词维度 Repository"""
+    
+    def __init__(self, db_client=None):
+        super().__init__("prompt_dimensions", db_client)
+    
+    def find_by_name(self, name: str):
+        """根据名称查找维度"""
+        return self.find_one({"name": name})
+    
+    def find_active_dimensions(self):
+        """查找所有激活的维度"""
+        return self.find_all({"is_active": 1}, order_by="created_at DESC")
+    
+    def get_dimensions_with_prompts(self):
+        """关联查询维度和提示词"""
+        builder = self._create_query_builder()
+        # 注意:_create_query_builder() 生成的 FROM 子句使用完整表名且不带别名,
+        # 因此这里必须用 prompt_dimensions 全名引用本表字段,不能使用 pd 别名
+        builder.select([
+            "prompt_dimensions.id", "prompt_dimensions.name", "prompt_dimensions.description",
+            "pv.content as prompt_content", "pv.version_number"
+        ])
+        builder.join(
+            "prompt_versions pv",
+            "prompt_dimensions.id = pv.dimension_id AND pv.is_active = 1",
+            "LEFT"
+        )
+        sql, params = builder.build_select()
+        return self.execute_query(sql, params)
+
+# 使用
+repo = PromptDimensionRepository()
+dimension = repo.find_by_name("知识实体与百科拆解")
+```
+
+### 方式三:使用 QueryBuilder 构建复杂查询
+
+```python
+from src.utils.mysql import QueryBuilder, get_global_mysql_client
+
+db = get_global_mysql_client()
+
+# 构建复杂查询
+builder = QueryBuilder("prompt_dimensions")
+builder.select(["id", "name", "description"])
+builder.where_eq("status", 1)
+builder.where_like("name", "%维度%")
+builder.where_in("id", [1, 2, 3, 4])
+builder.order_by("created_at", "DESC")
+builder.limit(10)
+builder.offset(0)
+
+sql, params = builder.build_select()
+results = db.fetch_all(sql, params)
+```
+
+## 迁移指南
+
+### 旧代码(直接使用 MySQLConnection)
+
+```python
+from src.utils.mysql import get_global_mysql_client
+
+db = get_global_mysql_client()
+sql = "SELECT * FROM prompt_dimensions WHERE id = %s"
+result = db.fetch_one(sql, [1])
+```
+
+### 新代码(使用 Repository)
+
+```python
+from src.utils.mysql import get_repository
+
+repo = get_repository("prompt_dimensions")
+result = repo.find_by_id(1)
+```
+
+## 优势
+
+1. **类型安全**:统一的接口定义,减少错误
+2. **代码复用**:通用 CRUD 操作无需重复编写
+3. **易于测试**:可以轻松 Mock BaseRepository
+4. **可扩展**:支持自定义 Repository 扩展
+5. **链式查询**:QueryBuilder 提供流畅的查询体验
+6. **向后兼容**:保留原有的 MySQLConnection 接口
+
+## 最佳实践
+
+1. **为每个表创建专门的 Repository**:封装业务逻辑
+2. **使用 QueryBuilder 处理复杂查询**:提高可读性
+3. **在 Service 层使用 Repository**:保持代码分层清晰
+4. **利用分页查询**:避免一次性加载大量数据

+ 23 - 1
src/utils/mysql/__init__.py

@@ -1,4 +1,7 @@
 from .mysql_conn import MySQLConnection
+from .mysql_repository import MySQLRepository
+from .base_repository import BaseRepository
+from .query_builder import QueryBuilder
 
 
 # 全局 MySQL 客户端实例
@@ -49,4 +52,23 @@ def close_global_mysql_client() -> None:
     global _global_mysql_client
     if _global_mysql_client is not None:
         _global_mysql_client.close()
-        _global_mysql_client = None
+        _global_mysql_client = None
+
+
+def get_repository(table_name: str, db_client: MySQLConnection = None) -> MySQLRepository:
+    """
+    获取指定表的 Repository 实例
+    
+    Args:
+        table_name: 表名
+        db_client: MySQL 连接客户端,如果为 None 则使用全局客户端
+        
+    Returns:
+        MySQLRepository 实例
+        
+    Example:
+        >>> repo = get_repository("prompt_dimensions")
+        >>> dimension = repo.find_by_id(1)
+        >>> dimensions = repo.find_all({"status": 1})
+    """
+    return MySQLRepository(table_name, db_client)

+ 198 - 0
src/utils/mysql/base_repository.py

@@ -0,0 +1,198 @@
+"""
+数据库 Repository 抽象基类
+
+定义统一的数据库操作接口,支持多种数据库实现。
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class BaseRepository(ABC):
+    """
+    数据库 Repository 抽象基类
+    
+    定义统一的 CRUD 操作接口,子类需要实现具体的数据库操作。
+    """
+    
+    def __init__(self, table_name: str):
+        """
+        初始化 Repository
+        
+        Args:
+            table_name: 表名
+        """
+        self.table_name = table_name
+    
+    @abstractmethod
+    def find_by_id(self, id: Union[int, str]) -> Optional[Dict[str, Any]]:
+        """
+        根据 ID 查询单条记录
+        
+        Args:
+            id: 主键ID
+            
+        Returns:
+            查询结果字典,不存在返回 None
+        """
+        pass
+    
+    @abstractmethod
+    def find_one(self, conditions: Dict[str, Any] = None, order_by: str = None) -> Optional[Dict[str, Any]]:
+        """
+        查询单条记录
+        
+        Args:
+            conditions: 查询条件字典,如 {"name": "value", "status": 1}
+            order_by: 排序字段,如 "created_at DESC"
+            
+        Returns:
+            查询结果字典,不存在返回 None
+        """
+        pass
+    
+    @abstractmethod
+    def find_all(
+        self, 
+        conditions: Dict[str, Any] = None,
+        order_by: str = None,
+        limit: int = None,
+        offset: int = None
+    ) -> List[Dict[str, Any]]:
+        """
+        查询多条记录
+        
+        Args:
+            conditions: 查询条件字典
+            order_by: 排序字段
+            limit: 限制返回数量
+            offset: 偏移量
+            
+        Returns:
+            查询结果列表
+        """
+        pass
+    
+    @abstractmethod
+    def count(self, conditions: Dict[str, Any] = None) -> int:
+        """
+        统计记录数量
+        
+        Args:
+            conditions: 查询条件字典
+            
+        Returns:
+            记录数量
+        """
+        pass
+    
+    @abstractmethod
+    def insert(self, data: Dict[str, Any]) -> int:
+        """
+        插入单条记录
+        
+        Args:
+            data: 要插入的数据字典
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def insert_many(self, data_list: List[Dict[str, Any]]) -> int:
+        """
+        批量插入记录
+        
+        Args:
+            data_list: 要插入的数据列表
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def update(self, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
+        """
+        更新记录
+        
+        Args:
+            conditions: 更新条件字典
+            data: 要更新的数据字典
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def update_by_id(self, id: Union[int, str], data: Dict[str, Any]) -> int:
+        """
+        根据 ID 更新记录
+        
+        Args:
+            id: 主键ID
+            data: 要更新的数据字典
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def delete(self, conditions: Dict[str, Any]) -> int:
+        """
+        删除记录
+        
+        Args:
+            conditions: 删除条件字典
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def delete_by_id(self, id: Union[int, str]) -> int:
+        """
+        根据 ID 删除记录
+        
+        Args:
+            id: 主键ID
+            
+        Returns:
+            受影响的行数
+        """
+        pass
+    
+    @abstractmethod
+    def execute_query(self, sql: str, params: Union[List, Dict] = None) -> List[Dict[str, Any]]:
+        """
+        执行自定义查询
+        
+        Args:
+            sql: SQL 查询语句
+            params: SQL 参数
+            
+        Returns:
+            查询结果列表
+        """
+        pass
+    
+    @abstractmethod
+    def execute_update(self, sql: str, params: Union[List, Dict] = None) -> int:
+        """
+        执行自定义更新语句
+        
+        Args:
+            sql: SQL 更新语句
+            params: SQL 参数
+            
+        Returns:
+            受影响的行数
+        """
+        pass

+ 306 - 0
src/utils/mysql/mysql_repository.py

@@ -0,0 +1,306 @@
+"""
+MySQL Repository 实现
+
+基于 MySQLConnection 实现 BaseRepository 接口,提供统一的数据库操作。
+"""
+
+from typing import Any, Dict, List, Optional, Union
+from .base_repository import BaseRepository
+from .query_builder import QueryBuilder
+from .mysql_conn import MySQLConnection
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class MySQLRepository(BaseRepository):
+    """
+    MySQL Repository 实现
+    
+    基于 MySQLConnection 和 QueryBuilder 提供统一的数据库操作接口。
+    """
+    
+    def __init__(self, table_name: str, db_client: Optional[MySQLConnection] = None):
+        """
+        初始化 MySQL Repository
+        
+        Args:
+            table_name: 表名
+            db_client: MySQL 连接客户端,如果为 None 则使用全局客户端
+        """
+        super().__init__(table_name)
+        if db_client is None:
+            from . import get_global_mysql_client
+            self._db = get_global_mysql_client()
+        else:
+            self._db = db_client
+    
+    def _create_query_builder(self) -> QueryBuilder:
+        """创建查询构建器"""
+        return QueryBuilder(self.table_name)
+    
+    def find_by_id(self, id: Union[int, str], id_field: str = "id") -> Optional[Dict[str, Any]]:
+        """
+        根据 ID 查询单条记录
+        
+        Args:
+            id: 主键ID
+            id_field: ID 字段名,默认为 "id"
+            
+        Returns:
+            查询结果字典,不存在返回 None
+        """
+        builder = self._create_query_builder()
+        sql, params = builder.where_eq(id_field, id).build_select()
+        return self._db.fetch_one(sql, params)
+    
+    def find_one(self, conditions: Dict[str, Any] = None, order_by: str = None) -> Optional[Dict[str, Any]]:
+        """
+        查询单条记录
+        
+        Args:
+            conditions: 查询条件字典,如 {"name": "value", "status": 1}
+            order_by: 排序字段,如 "created_at DESC"
+            
+        Returns:
+            查询结果字典,不存在返回 None
+        """
+        builder = self._create_query_builder()
+        if conditions:
+            builder.where_conditions(conditions)
+        if order_by:
+            # 解析 order_by 字符串,如 "created_at DESC"
+            parts = order_by.strip().split()
+            if len(parts) == 2:
+                builder.order_by(parts[0], parts[1])
+            else:
+                builder.order_by(parts[0])
+        builder.limit(1)
+        sql, params = builder.build_select()
+        return self._db.fetch_one(sql, params)
+    
+    def find_all(
+        self, 
+        conditions: Dict[str, Any] = None,
+        order_by: str = None,
+        limit: int = None,
+        offset: int = None
+    ) -> List[Dict[str, Any]]:
+        """
+        查询多条记录
+        
+        Args:
+            conditions: 查询条件字典
+            order_by: 排序字段
+            limit: 限制返回数量
+            offset: 偏移量
+            
+        Returns:
+            查询结果列表
+        """
+        builder = self._create_query_builder()
+        if conditions:
+            builder.where_conditions(conditions)
+        if order_by:
+            parts = order_by.strip().split()
+            if len(parts) == 2:
+                builder.order_by(parts[0], parts[1])
+            else:
+                builder.order_by(parts[0])
+        if limit is not None:
+            builder.limit(limit)
+        if offset is not None:
+            builder.offset(offset)
+        sql, params = builder.build_select()
+        return self._db.fetch_all(sql, params)
+    
+    def count(self, conditions: Dict[str, Any] = None) -> int:
+        """
+        统计记录数量
+        
+        Args:
+            conditions: 查询条件字典
+            
+        Returns:
+            记录数量
+        """
+        builder = self._create_query_builder()
+        if conditions:
+            builder.where_conditions(conditions)
+        sql, params = builder.build_count()
+        result = self._db.fetch_one(sql, params)
+        return result['count'] if result else 0
+    
+    def insert(self, data: Dict[str, Any]) -> int:
+        """
+        插入单条记录
+        
+        Args:
+            data: 要插入的数据字典
+            
+        Returns:
+            受影响的行数
+        """
+        builder = self._create_query_builder()
+        sql, params = builder.build_insert(data)
+        return self._db.execute(sql, params)
+    
+    def insert_many(self, data_list: List[Dict[str, Any]]) -> int:
+        """
+        批量插入记录
+        
+        Args:
+            data_list: 要插入的数据列表
+            
+        Returns:
+            受影响的行数
+        """
+        if not data_list:
+            return 0
+        builder = self._create_query_builder()
+        sql, params = builder.build_bulk_insert(data_list)
+        return self._db.bulk_insert(sql, params)
+    
+    def update(self, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
+        """
+        更新记录
+        
+        Args:
+            conditions: 更新条件字典
+            data: 要更新的数据字典
+            
+        Returns:
+            受影响的行数
+        """
+        builder = self._create_query_builder()
+        sql, params = builder.build_update(data, conditions)
+        return self._db.execute(sql, params)
+    
+    def update_by_id(self, id: Union[int, str], data: Dict[str, Any], id_field: str = "id") -> int:
+        """
+        根据 ID 更新记录
+        
+        Args:
+            id: 主键ID
+            data: 要更新的数据字典
+            id_field: ID 字段名,默认为 "id"
+            
+        Returns:
+            受影响的行数
+        """
+        return self.update({id_field: id}, data)
+    
+    def delete(self, conditions: Dict[str, Any]) -> int:
+        """
+        删除记录
+        
+        Args:
+            conditions: 删除条件字典
+            
+        Returns:
+            受影响的行数
+        """
+        builder = self._create_query_builder()
+        sql, params = builder.build_delete(conditions)
+        return self._db.execute(sql, params)
+    
+    def delete_by_id(self, id: Union[int, str], id_field: str = "id") -> int:
+        """
+        根据 ID 删除记录
+        
+        Args:
+            id: 主键ID
+            id_field: ID 字段名,默认为 "id"
+            
+        Returns:
+            受影响的行数
+        """
+        return self.delete({id_field: id})
+    
+    def execute_query(self, sql: str, params: Union[List, Dict] = None) -> List[Dict[str, Any]]:
+        """
+        执行自定义查询
+        
+        Args:
+            sql: SQL 查询语句
+            params: SQL 参数
+            
+        Returns:
+            查询结果列表
+        """
+        return self._db.fetch_all(sql, params)
+    
+    def execute_update(self, sql: str, params: Union[List, Dict] = None) -> int:
+        """
+        执行自定义更新语句
+        
+        Args:
+            sql: SQL 更新语句
+            params: SQL 参数
+            
+        Returns:
+            受影响的行数
+        """
+        return self._db.execute(sql, params)
+    
+    # ========== 扩展方法 ==========
+    
+    def exists(self, conditions: Dict[str, Any]) -> bool:
+        """
+        检查记录是否存在
+        
+        Args:
+            conditions: 查询条件字典
+            
+        Returns:
+            存在返回 True,否则返回 False
+        """
+        return self.count(conditions) > 0
+    
+    def find_by_ids(self, ids: List[Union[int, str]], id_field: str = "id") -> List[Dict[str, Any]]:
+        """
+        根据 ID 列表批量查询
+        
+        Args:
+            ids: ID 列表
+            id_field: ID 字段名,默认为 "id"
+            
+        Returns:
+            查询结果列表
+        """
+        if not ids:
+            return []
+        builder = self._create_query_builder()
+        sql, params = builder.where_in(id_field, ids).build_select()
+        return self._db.fetch_all(sql, params)
+    
+    def paginate(
+        self,
+        conditions: Dict[str, Any] = None,
+        order_by: str = None,
+        page: int = 1,
+        page_size: int = 10
+    ) -> Dict[str, Any]:
+        """
+        分页查询
+        
+        Args:
+            conditions: 查询条件字典
+            order_by: 排序字段
+            page: 页码(从1开始)
+            page_size: 每页数量
+            
+        Returns:
+            包含 total, items, page, page_size 的字典
+        """
+        total = self.count(conditions)
+        offset = (page - 1) * page_size
+        items = self.find_all(conditions=conditions, order_by=order_by, limit=page_size, offset=offset)
+        
+        return {
+            "total": total,
+            "items": items,
+            "page": page,
+            "page_size": page_size,
+            "total_pages": (total + page_size - 1) // page_size
+        }

+ 442 - 0
src/utils/mysql/query_builder.py

@@ -0,0 +1,442 @@
+"""
+SQL 查询构建器
+
+提供链式 API 构建 SQL 查询语句。
+"""
+
+from typing import Any, Dict, List, Optional, Union
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class QueryBuilder:
+    """
+    SQL 查询构建器
+    
+    提供链式 API 构建 SQL 查询语句,支持:
+    - WHERE 条件构建
+    - ORDER BY 排序
+    - LIMIT/OFFSET 分页
+    - JOIN 关联查询
+    """
+    
+    def __init__(self, table_name: str):
+        """
+        初始化查询构建器
+        
+        Args:
+            table_name: 表名
+        """
+        self.table_name = table_name
+        self._select_fields = ["*"]
+        self._where_conditions = []
+        self._where_params = []
+        self._order_by = None
+        self._limit_value = None
+        self._offset_value = None
+        self._joins = []
+        self._group_by = None
+        self._having = None
+    
+    def select(self, fields: Union[str, List[str]]) -> 'QueryBuilder':
+        """
+        指定查询字段
+        
+        Args:
+            fields: 字段名或字段列表,如 "id, name" 或 ["id", "name"]
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        if isinstance(fields, str):
+            self._select_fields = [f.strip() for f in fields.split(",")]
+        else:
+            self._select_fields = fields
+        return self
+    
+    def where(self, field: str, operator: str, value: Any) -> 'QueryBuilder':
+        """
+        添加 WHERE 条件
+        
+        Args:
+            field: 字段名
+            operator: 操作符,如 "=", ">", "<", "LIKE", "IN", "BETWEEN" 等
+            value: 条件值
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+            
+        Examples:
+            >>> builder.where("id", "=", 1)
+            >>> builder.where("age", ">", 18)
+            >>> builder.where("name", "LIKE", "%test%")
+            >>> builder.where("status", "IN", [1, 2, 3])
+        """
+        if operator.upper() == "IN":
+            if not isinstance(value, (list, tuple)):
+                raise ValueError("IN 操作符的值必须是列表或元组")
+            placeholders = ",".join(["%s"] * len(value))
+            self._where_conditions.append(f"{field} IN ({placeholders})")
+            self._where_params.extend(value)
+        elif operator.upper() == "BETWEEN":
+            if not isinstance(value, (list, tuple)) or len(value) != 2:
+                raise ValueError("BETWEEN 操作符的值必须是包含两个元素的列表或元组")
+            self._where_conditions.append(f"{field} BETWEEN %s AND %s")
+            self._where_params.extend(value)
+        elif operator.upper() == "IS NULL":
+            self._where_conditions.append(f"{field} IS NULL")
+        elif operator.upper() == "IS NOT NULL":
+            self._where_conditions.append(f"{field} IS NOT NULL")
+        else:
+            self._where_conditions.append(f"{field} {operator} %s")
+            self._where_params.append(value)
+        return self
+    
+    def where_eq(self, field: str, value: Any) -> 'QueryBuilder':
+        """等于条件(便捷方法)"""
+        return self.where(field, "=", value)
+    
+    def where_ne(self, field: str, value: Any) -> 'QueryBuilder':
+        """不等于条件(便捷方法)"""
+        return self.where(field, "!=", value)
+    
+    def where_gt(self, field: str, value: Any) -> 'QueryBuilder':
+        """大于条件(便捷方法)"""
+        return self.where(field, ">", value)
+    
+    def where_gte(self, field: str, value: Any) -> 'QueryBuilder':
+        """大于等于条件(便捷方法)"""
+        return self.where(field, ">=", value)
+    
+    def where_lt(self, field: str, value: Any) -> 'QueryBuilder':
+        """小于条件(便捷方法)"""
+        return self.where(field, "<", value)
+    
+    def where_lte(self, field: str, value: Any) -> 'QueryBuilder':
+        """小于等于条件(便捷方法)"""
+        return self.where(field, "<=", value)
+    
+    def where_like(self, field: str, value: str) -> 'QueryBuilder':
+        """LIKE 条件(便捷方法)"""
+        return self.where(field, "LIKE", value)
+    
+    def where_in(self, field: str, values: List[Any]) -> 'QueryBuilder':
+        """IN 条件(便捷方法)"""
+        return self.where(field, "IN", values)
+    
+    def where_between(self, field: str, start: Any, end: Any) -> 'QueryBuilder':
+        """BETWEEN 条件(便捷方法)"""
+        return self.where(field, "BETWEEN", [start, end])
+    
+    def where_conditions(self, conditions: Dict[str, Any]) -> 'QueryBuilder':
+        """
+        批量添加 WHERE 条件
+        
+        Args:
+            conditions: 条件字典,如 {"id": 1, "status": "active"}
+                       默认使用 = 操作符
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        for field, value in conditions.items():
+            self.where_eq(field, value)
+        return self
+    
+    def order_by(self, field: str, direction: str = "ASC") -> 'QueryBuilder':
+        """
+        添加排序
+        
+        Args:
+            field: 排序字段
+            direction: 排序方向,ASC 或 DESC
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        if direction.upper() not in ["ASC", "DESC"]:
+            raise ValueError("排序方向必须是 ASC 或 DESC")
+        self._order_by = f"{field} {direction.upper()}"
+        return self
+    
+    def limit(self, count: int) -> 'QueryBuilder':
+        """
+        限制返回数量
+        
+        Args:
+            count: 返回数量
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        self._limit_value = count
+        return self
+    
+    def offset(self, count: int) -> 'QueryBuilder':
+        """
+        设置偏移量
+        
+        Args:
+            count: 偏移量
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        self._offset_value = count
+        return self
+    
+    def join(self, table: str, on: str, join_type: str = "INNER") -> 'QueryBuilder':
+        """
+        添加 JOIN
+        
+        Args:
+            table: 关联表名
+            on: JOIN 条件,如 "table1.id = table2.foreign_id"
+            join_type: JOIN 类型,INNER, LEFT, RIGHT, FULL
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        join_type = join_type.upper()
+        if join_type not in ["INNER", "LEFT", "RIGHT", "FULL"]:
+            raise ValueError("JOIN 类型必须是 INNER, LEFT, RIGHT 或 FULL")
+        self._joins.append(f"{join_type} JOIN {table} ON {on}")
+        return self
+    
+    def group_by(self, field: str) -> 'QueryBuilder':
+        """
+        添加 GROUP BY
+        
+        Args:
+            field: 分组字段
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        self._group_by = field
+        return self
+    
+    def having(self, condition: str) -> 'QueryBuilder':
+        """
+        添加 HAVING 条件
+        
+        Args:
+            condition: HAVING 条件,如 "COUNT(*) > 10"
+            
+        Returns:
+            QueryBuilder 实例(支持链式调用)
+        """
+        self._having = condition
+        return self
+    
+    def build_select(self) -> tuple[str, List[Any]]:
+        """
+        构建 SELECT 查询语句
+        
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        # 构建 SELECT 字段
+        select_clause = ", ".join(self._select_fields)
+        
+        # 构建 FROM 子句
+        from_clause = f"FROM {self.table_name}"
+        
+        # 构建 JOIN 子句
+        join_clause = ""
+        if self._joins:
+            join_clause = " " + " ".join(self._joins)
+        
+        # 构建 WHERE 子句
+        where_clause = ""
+        if self._where_conditions:
+            where_clause = " WHERE " + " AND ".join(self._where_conditions)
+        
+        # 构建 GROUP BY 子句
+        group_by_clause = ""
+        if self._group_by:
+            group_by_clause = f" GROUP BY {self._group_by}"
+        
+        # 构建 HAVING 子句
+        having_clause = ""
+        if self._having:
+            having_clause = f" HAVING {self._having}"
+        
+        # 构建 ORDER BY 子句
+        order_by_clause = ""
+        if self._order_by:
+            order_by_clause = f" ORDER BY {self._order_by}"
+        
+        # 构建 LIMIT 子句
+        limit_clause = ""
+        if self._limit_value is not None:
+            limit_clause = f" LIMIT {self._limit_value}"
+            if self._offset_value is not None:
+                limit_clause += f" OFFSET {self._offset_value}"
+        
+        # 组合 SQL
+        sql = f"SELECT {select_clause} {from_clause}{join_clause}{where_clause}{group_by_clause}{having_clause}{order_by_clause}{limit_clause}"
+        
+        return sql.strip(), self._where_params
+    
+    def build_count(self) -> tuple[str, List[Any]]:
+        """
+        构建 COUNT 查询语句
+        
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        # 构建 FROM 子句
+        from_clause = f"FROM {self.table_name}"
+        
+        # 构建 JOIN 子句
+        join_clause = ""
+        if self._joins:
+            join_clause = " " + " ".join(self._joins)
+        
+        # 构建 WHERE 子句
+        where_clause = ""
+        if self._where_conditions:
+            where_clause = " WHERE " + " AND ".join(self._where_conditions)
+        
+        # 构建 GROUP BY 子句
+        group_by_clause = ""
+        if self._group_by:
+            group_by_clause = f" GROUP BY {self._group_by}"
+        
+        # 构建 HAVING 子句
+        having_clause = ""
+        if self._having:
+            having_clause = f" HAVING {self._having}"
+        
+        sql = f"SELECT COUNT(*) as count {from_clause}{join_clause}{where_clause}{group_by_clause}{having_clause}"
+        
+        return sql.strip(), self._where_params
+    
+    def build_insert(self, data: Dict[str, Any]) -> tuple[str, List[Any]]:
+        """
+        构建 INSERT 语句
+        
+        Args:
+            data: 要插入的数据字典
+            
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        if not data:
+            raise ValueError("插入数据不能为空")
+        
+        fields = list(data.keys())
+        placeholders = ", ".join(["%s"] * len(fields))
+        fields_str = ", ".join(fields)
+        
+        sql = f"INSERT INTO {self.table_name} ({fields_str}) VALUES ({placeholders})"
+        params = list(data.values())
+        
+        return sql, params
+    
+    def build_bulk_insert(self, data_list: List[Dict[str, Any]]) -> tuple[str, List[Any]]:
+        """
+        构建批量 INSERT 语句
+        
+        Args:
+            data_list: 要插入的数据列表
+            
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        if not data_list:
+            raise ValueError("插入数据列表不能为空")
+        
+        # 使用第一个字典的键作为字段名
+        fields = list(data_list[0].keys())
+        fields_str = ", ".join(fields)
+        
+        # 构建 VALUES 子句
+        values_list = []
+        params = []
+        for data in data_list:
+            if set(data.keys()) != set(fields):
+                raise ValueError("批量插入的数据字典字段必须一致")
+            placeholders = ", ".join(["%s"] * len(fields))
+            values_list.append(f"({placeholders})")
+            params.extend([data[field] for field in fields])
+        
+        values_str = ", ".join(values_list)
+        sql = f"INSERT INTO {self.table_name} ({fields_str}) VALUES {values_str}"
+        
+        return sql, params
+    
+    def build_update(self, data: Dict[str, Any], conditions: Dict[str, Any] = None) -> tuple[str, List[Any]]:
+        """
+        构建 UPDATE 语句
+        
+        Args:
+            data: 要更新的数据字典
+            conditions: 更新条件字典(如果提供,会添加到 WHERE 子句)
+            
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        if not data:
+            raise ValueError("更新数据不能为空")
+        
+        # 构建 SET 子句
+        set_clauses = []
+        params = []
+        for field, value in data.items():
+            set_clauses.append(f"{field} = %s")
+            params.append(value)
+        
+        set_clause = ", ".join(set_clauses)
+        
+        # 构建 WHERE 子句
+        where_clause = ""
+        if conditions:
+            where_conditions = []
+            for field, value in conditions.items():
+                where_conditions.append(f"{field} = %s")
+                params.append(value)
+            where_clause = " WHERE " + " AND ".join(where_conditions)
+        elif self._where_conditions:
+            where_clause = " WHERE " + " AND ".join(self._where_conditions)
+            params.extend(self._where_params)
+        
+        if not where_clause:
+            raise ValueError("UPDATE 语句必须包含 WHERE 条件")
+        
+        sql = f"UPDATE {self.table_name} SET {set_clause}{where_clause}"
+        
+        return sql, params
+    
+    def build_delete(self, conditions: Dict[str, Any] = None) -> tuple[str, List[Any]]:
+        """
+        构建 DELETE 语句
+        
+        Args:
+            conditions: 删除条件字典(如果提供,会添加到 WHERE 子句)
+            
+        Returns:
+            (SQL 语句, 参数列表) 元组
+        """
+        params = []
+        
+        # 构建 WHERE 子句
+        where_clause = ""
+        if conditions:
+            where_conditions = []
+            for field, value in conditions.items():
+                where_conditions.append(f"{field} = %s")
+                params.append(value)
+            where_clause = " WHERE " + " AND ".join(where_conditions)
+        elif self._where_conditions:
+            where_clause = " WHERE " + " AND ".join(self._where_conditions)
+            params.extend(self._where_params)
+        
+        if not where_clause:
+            raise ValueError("DELETE 语句必须包含 WHERE 条件")
+        
+        sql = f"DELETE FROM {self.table_name}{where_clause}"
+        
+        return sql, params

+ 54 - 41
src/utils/vector_db/elasticsearch_adapter.py

@@ -6,7 +6,7 @@ Elasticsearch 向量数据库适配器
 
 import re
 import threading
-from typing import Dict, Any, List, Optional, Tuple
+from typing import Dict, Any, List, Optional, Tuple, Union, Sequence
 from .base import VectorDBClient
 from .result_util import UnifiedSearchResult, build_unified_result
 from src.conf.settings import es_settings, vector_db_settings
@@ -14,6 +14,10 @@ from src.common.logging_config import get_logger
 
 logger = get_logger(__name__)
 
+# 模块加载时的标记,用于确认代码已重新加载
+import datetime
+logger.info(f"=== ElasticsearchAdapter 模块已加载 (ES 8.x API) - {datetime.datetime.now()} ===")
+
 
 class ElasticsearchAdapter(VectorDBClient):
     """
@@ -101,15 +105,13 @@ class ElasticsearchAdapter(VectorDBClient):
         matching_text = query.get("matching_text") or ""
         
         # match 的 query 不能为 None,否则 ES 报 VALUE_NULL
-        es_query = {
-            "query": {
-                "match": {match_field: matching_text}
-            },
-            "size": query.get("topn", 10),
-            "_source": output_fields
-        }
-        
-        result = self._client.search(index=index_name, body=es_query)
+        # ES 8.x 使用关键字参数而非 body
+        result = self._client.search(
+            index=index_name,
+            query={"match": {match_field: matching_text}},
+            size=query.get("topn", 10),
+            _source=output_fields
+        )
         return self._convert_result(result, output_fields)
     
     def vector_search(
@@ -127,14 +129,21 @@ class ElasticsearchAdapter(VectorDBClient):
         index_name = table_name
         vector_field = query.get("vector_field", "dense_vector")
         query_vector = query.get("query_vector", [])
+        # 确保 query_vector 是 Python 原生列表,兼容 numpy 数组等类型
+        if query_vector is not None and len(query_vector) > 0:
+            # 强制转换为 Python 原生 float 列表
+            query_vector = [float(x) for x in query_vector]
+        else:
+            query_vector = []
         topn = query.get("topn", 10)
         
         # 获取阈值
         knn_params = query.get("knn_params", {})
         threshold = float(knn_params.get("threshold", 0.0))
         
-        # 构建 ES knn 查询
-        es_query = {
+        # 构建 ES knn 查询(ES 8.x 使用关键字参数而非 body)
+        search_params = {
+            "index": index_name,
             "knn": {
                 "field": vector_field,
                 "query_vector": query_vector,
@@ -146,16 +155,16 @@ class ElasticsearchAdapter(VectorDBClient):
         
         # 如果有阈值,添加 min_score
         if threshold > 0:
-            es_query["min_score"] = threshold
+            search_params["min_score"] = threshold
         
-        result = self._client.search(index=index_name, body=es_query)
+        result = self._client.search(**search_params)
         return self._convert_result(result, output_fields)
     
     def hybrid_search(
         self,
-        table_name: str,
         output_fields: List[str],
         query: Dict[str, Any],
+        table_name: Optional[Union[str, Sequence[str]]] = None,
         database_name: Optional[str] = None
     ) -> Any:
         """
@@ -163,9 +172,17 @@ class ElasticsearchAdapter(VectorDBClient):
         
         使用 ES 的 bool 查询结合 knn 和 match。
         """
+        # index_name =  "ragbook_1_f3c87b89f82711f0b7450242ac180002,ragbook_2_16db7236f82911f093e20242ac180002"
         index_name = table_name
         vector_field = query.get("vector_field", "dense_vector")
         query_vector = query.get("query_vector", [])
+        # 确保 query_vector 是 Python 原生列表,兼容 numpy 数组等类型
+        if query_vector is not None and len(query_vector) > 0:
+            # 强制转换为 Python 原生 float 列表
+            query_vector = [float(x) for x in query_vector]
+            logger.info(f"混合搜索: 已转换 query_vector 类型, 长度={len(query_vector)}, 前3个元素={query_vector[:3]}")
+        else:
+            query_vector = []
         match_field = query.get("match_field", "content")
         matching_text = query.get("matching_text")
         if matching_text is not None and not isinstance(matching_text, str):
@@ -186,25 +203,24 @@ class ElasticsearchAdapter(VectorDBClient):
         else:
             text_clause = {"match_all": {"boost": 1.0}}
         
-        # 构建混合查询
-        es_query = {
-            "query": {
+        # 构建混合查询(ES 8.x 使用关键字参数而非 body)
+        result = self._client.search(
+            index=index_name,
+            query={
                 "bool": {
                     "should": [text_clause]
                 }
             },
-            "knn": {
+            knn={
                 "field": vector_field,
                 "query_vector": query_vector,
                 "k": topn,
                 "num_candidates": topn * 10,
                 "boost": 1.0
             },
-            "size": topn,
-            "_source": output_fields
-        }
-        
-        result = self._client.search(index=index_name, body=es_query)
+            size=topn,
+            _source=output_fields
+        )
         return self._convert_result(result, output_fields)
     
     def insert(
@@ -216,14 +232,14 @@ class ElasticsearchAdapter(VectorDBClient):
         """插入文档"""
         index_name = table_name
         
-        # 批量插入
+        # 批量插入(ES 8.x 使用 operations 而非 body)
         operations = []
         for doc in documents:
             operations.append({"index": {"_index": index_name}})
             operations.append(doc)
         
         if operations:
-            result = self._client.bulk(body=operations, refresh=True)
+            result = self._client.bulk(operations=operations, refresh=True)
             return result
         return None
     
@@ -264,10 +280,11 @@ class ElasticsearchAdapter(VectorDBClient):
             if field in ['id', '_id']:
                 # 使用update API直接更新指定id的文档
                 try:
+                    # ES 8.x 使用 doc 而非 body
                     result = self._client.update(
                         index=index_name,
                         id=value,
-                        body={"doc": data}
+                        doc=data
                     )
                     logger.info(f"update 执行: index={index_name}, id={value}, data={data}")
                     return result
@@ -280,18 +297,14 @@ class ElasticsearchAdapter(VectorDBClient):
             # 复杂条件,使用query_string
             es_query_clause = {"query_string": {"query": cond or "*"}}
         
-        # 构建update_by_query请求
-        es_query = {
-            "query": es_query_clause,
-            "script": {
-                "source": "; ".join([f"ctx._source.{k} = params.{k}" for k in data.keys()]),
-                "params": data
-            }
-        }
-        
+        # 构建update_by_query请求(ES 8.x 使用关键字参数而非 body)
         result = self._client.update_by_query(
             index=index_name,
-            body=es_query,
+            query=es_query_clause,
+            script={
+                "source": "; ".join([f"ctx._source.{k} = params.{k}" for k in data.keys()]),
+                "params": data
+            },
             conflicts="proceed"
         )
         # 打印实际执行的查询
@@ -359,11 +372,11 @@ class ElasticsearchAdapter(VectorDBClient):
         mappings: Dict[str, Any],
         settings: Optional[Dict[str, Any]] = None
     ):
-        """创建索引"""
-        body = {"mappings": mappings}
+        """创建索引(ES 8.x 使用关键字参数而非 body)"""
+        create_params = {"index": index_name, "mappings": mappings}
         if settings:
-            body["settings"] = settings
-        return self._client.indices.create(index=index_name, body=body, ignore=400)
+            create_params["settings"] = settings
+        return self._client.indices.create(**create_params, ignore=400)
     
     def delete_index(self, index_name: str):
         """删除索引"""