Selaa lähdekoodia

增加kb_id查询

yingge 3 kuukautta sitten
vanhempi
commit
62397e7c90

+ 4 - 1
src/api/dataset/services/dify_knowledge_service.py

@@ -39,7 +39,9 @@ class DifyKnowledgeService:
         try:
             if request.knowledge_id:
                 # table_name = f"{vector_db_settings.infinity_dataset_prefix}{request.knowledge_id}"
-                table_name = self.table_name
+                # table_name = self.table_name
+                # knowledge_id为 ragflow_1_2,将其拆分为table_name=ragflow_1, kb_id=2
+                table_name, kb_id = request.knowledge_id.split('_')
             else:
                 # 抛出异常
                 raise Exception("knowledge_id不能为空")
@@ -79,6 +81,7 @@ class DifyKnowledgeService:
                 "vector_field": self.vector_field,
                 "query_vector": query_vector,
                 "topn": retrieval_setting.top_k,
+                "kb_id": kb_id,
                 "knn_params": {
                     "ef": str(retrieval_setting.top_k * 10),
                     "threshold": str(retrieval_setting.score_threshold)

+ 15 - 2
src/api/db/models/search_request_models.py

@@ -46,7 +46,12 @@ class SearchRequest(BaseModel):
         le=1.0,
         description="相似度分数阈值"
     )
-    
+    kb_id: Optional[str] = Field(
+        default=None,
+        description="知识库id",
+        examples=["kb_id1", "kb_id2"]
+    )
+
     @field_validator('image_url')
     @classmethod
     def validate_image_url(cls, v: Optional[str]) -> Optional[str]:
@@ -66,4 +71,12 @@ class SearchRequest(BaseModel):
     def model_post_init(self, __context) -> None:
         """模型初始化后验证:确保至少提供一个搜索条件"""
         if self.image_url is None and self.matching_text is None:
-            raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
+            raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
+
+    @field_validator('kb_id')
+    @classmethod
+    def validate_kb_id(cls, v: Optional[str]) -> Optional[str]:
+        """验证知识库id"""
+        if v is not None and v.strip() == "":
+            return None
+        return v

+ 1 - 0
src/api/db/services/vector_search_service.py

@@ -139,6 +139,7 @@ class VectorSearchService:
                 "vector_field": self.vector_field,
                 "query_vector": query_vector,
                 "topn": request.topk,
+                "kb_id": request.kb_id,
                 "knn_params": {
                     "ef": str(request.topk * 10),
                     "threshold": str(request.score_threshold) if request.score_threshold else "0"

+ 91 - 1
src/api/sdk/dataset_manage.py

@@ -314,12 +314,102 @@ async def get_task_status(task_id: str):
     获取任务状态
     
     - **task_id**: 任务ID
+    
+    返回内容包含:
+    - 任务基本信息(id, name, status, position等)
+    - 时间信息(created_at, started_at, completed_at)
+    - 工作流执行过程详情(workflow_details)
+        - 处理进度信息
+        - 维度处理结果
+        - 统计信息
+        - 其他执行状态
+    - 错误信息(如果任务失败)
     """
     task_queue = get_task_queue()
     status = task_queue.get_task_status(task_id)
     if not status:
         return Result.error(code=404, message=f"任务不存在: {task_id}")
-    return Result.success(data=status, message="获取任务状态成功")
+    
+    # 提取工作流执行过程详情
+    workflow_details = None
+    if status.get("result"):
+        result = status["result"]
+        if isinstance(result, dict):
+            # 提取维度结果摘要(避免返回过大的数据)
+            dimension_results_summary = {}
+            if result.get("dimension_results"):
+                for dim_id, dim_result in result.get("dimension_results", {}).items():
+                    dimension_results_summary[dim_id] = {
+                        "success": dim_result.get("success", False),
+                        "error": dim_result.get("error"),
+                        "processed_pages": dim_result.get("processed_pages", 0),
+                        "vectorized_pages": dim_result.get("vectorized_pages", 0),
+                        "table_name": dim_result.get("table_name", ""),
+                    }
+            
+            # 定义需要排除的键
+            excluded_keys = {
+                "processed_pages", "vectorized_pages", "total_pages", "is_complete",
+                "dimension_results", "dimensions_processed", "dimensions_success", 
+                "total_vectorized_pages", "qa_count", "qa_pairs", "text_chunks",
+                "parsed_images", "total_images", "dataset_id", "document_id",
+                "page_dataset_id", "page_document_id", "success", "split_pages",
+                "parsed_results", "vectorized_results"
+            }
+            
+            # 提取其他状态信息(排除已处理的字段和大列表)
+            other_info = {}
+            for k, v in result.items():
+                if k not in excluded_keys:
+                    # 只包含小型的列表和字典,或非列表/字典类型
+                    if not isinstance(v, (list, dict)):
+                        other_info[k] = v
+                    elif isinstance(v, (list, dict)) and len(str(v)) < 500:
+                        other_info[k] = v
+            
+            workflow_details = {
+                # 通用进度信息
+                "processed_pages": result.get("processed_pages", 0),
+                "vectorized_pages": result.get("vectorized_pages", 0),
+                "total_pages": result.get("total_pages", 0),
+                "is_complete": result.get("is_complete", False),
+                
+                # 维度相关结果(动态多维度解析)
+                "dimension_results": dimension_results_summary,
+                "dimensions_processed": result.get("dimensions_processed", 0),
+                "dimensions_success": result.get("dimensions_success", 0),
+                "total_vectorized_pages": result.get("total_vectorized_pages", 0),
+                
+                # QA相关结果(只返回统计信息,不返回完整列表)
+                "qa_count": result.get("qa_count", 0),
+                "qa_pairs_count": len(result.get("qa_pairs", [])),
+                "text_chunks_count": len(result.get("text_chunks", [])),
+                
+                # 图片解析相关
+                "parsed_images_count": len(result.get("parsed_images", [])),
+                "total_images": result.get("total_images", 0),
+                
+                # 数据集和文档信息
+                "dataset_id": result.get("dataset_id", ""),
+                "document_id": result.get("document_id", ""),
+                "page_dataset_id": result.get("page_dataset_id", ""),
+                "page_document_id": result.get("page_document_id", ""),
+                
+                # 成功状态
+                "success": result.get("success", False),
+            }
+            
+            # 如果有其他信息,添加到详情中
+            if other_info:
+                workflow_details["other_info"] = other_info
+    
+    # 构建响应数据
+    response_data = {
+        **status,
+        "workflow_details": workflow_details
+    }
+    
+    return Result.success(data=response_data, message="获取任务状态成功")
 
 
 @app.get("/v2/queue")

+ 29 - 0
src/api/sdk/search_infinity.py

@@ -93,6 +93,35 @@ async def question_search(request: SearchRequest):
         if not ragflow_id:
             return Result.error(code=500, message="未找到ragflow_id")
         question_table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
+        request.kb_id = "2b0ac35df80e11f096160242ac180002"
+        request.database_ids = [question_table_name]
+        result = await run_in_threadpool(search_service.hybrid_search, request)
+        return Result.success(data=result, message="问答对检索成功")
+    except Exception as e:
+        return Result.error(code=500, message=f"问答对检索失败: {str(e)}")
+
+@app.post("/resource")
+async def question_search(request: SearchRequest):
+    """
+    问答对检索接口
+
+    - **output_fields**: 要返回的字段列表
+    - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
+    - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
+    """
+    try:
+        output_fields = ["content_with_weight"]
+        search_service = VectorSearchService(
+            client=get_vector_db_client(database="ragflow_db"), 
+            vector_field="q_1024_vec",
+            output_fields=output_fields
+        )
+        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
+        ragflow_id = ragflow_user.get("ragflow_id")
+        if not ragflow_id:
+            return Result.error(code=500, message="未找到ragflow_id")
+        question_table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
+        request.kb_id = "f59936fbfb1f11f095280242ac180002"
         request.database_ids = [question_table_name]
         result = await run_in_threadpool(search_service.hybrid_search, request)
         return Result.success(data=result, message="问答对检索成功")

+ 3 - 3
src/datasets/parser/nodes/image_parse_node.py

@@ -12,7 +12,7 @@ from src.model.qwen_vl import QWenVLParser
 from src.conf.settings import model_settings
 from src.common.logging_config import get_logger
 from src.utils.async_utils import ThreadPoolManager
-from src.utils.markdown_utils import parse_markdown_json
+from src.utils.json_utils import parse_json_response
 
 logger = get_logger(__name__)
 
@@ -109,7 +109,7 @@ class ImageParseNode(BaseNode):
         try:
             parser = QWenVLParser(self.model_name)
             result = parser.parse_image(image, page_number, prompt)
-            
+            parsed_content = parse_json_response(result, expected_type=dict)
             # 处理markdown格式的JSON标签
             # parsed_content = parse_markdown_json(result)
             # if parsed_content:
@@ -119,7 +119,7 @@ class ImageParseNode(BaseNode):
             #     logger.warning(f"解析JSON内容失败,保留原始内容")
             
             logger.debug(f"第 {page_number} 页解析完成")
-            return result
+            return parsed_content
         except Exception as e:
             logger.error(f"解析第 {page_number} 页时出错: {str(e)}")
             return {

+ 24 - 208
src/datasets/parser/nodes/qa_generate_node.py

@@ -13,6 +13,7 @@ from src.model.qwen_vl import QWenVLParser
 from src.conf.settings import model_settings
 from src.common.logging_config import get_logger
 from src.utils.async_utils import ThreadPoolManager
+from src.utils.json_utils import parse_qa_response
 
 logger = get_logger(__name__)
 
@@ -69,39 +70,36 @@ class QAGenerateNode(BaseNode):
         """
         prompt = f"""
         # Role
-        你是一位深耕 0-14 岁儿童教育与少儿出版行业 的资深知识萃取专家。你具备儿童心理学、教育学(如皮亚杰认知发展理论、维果茨基最近发展区)以及全球童书出版趋势的深厚洞察力。
-
+        你是一位拥有“魔法大脑”的儿童教育专家。你深知皮亚杰的认知阶段理论,能把深奥的行业文本转化为 0-14 岁孩子爱听、爱想、爱做的趣味问答。
         # Task
-        请阅读文本内容:{chunk},生成 {count} 个高质量问答对。
-
-        # Target Audience
-        0-14 岁儿童(提问语气要亲切、好奇、具有代入感)。
+        阅读文本 {chunk},从中萃取 {count} 个问答对。请严格遵守以下**“提问配方”**:
+        【40% 内容捕快】:针对文本中的基础事实提问。
+            - 目标:让孩子像玩侦探游戏一样找回文中的关键信息。
+        【30% 脑洞探险家】:基于文本底层逻辑进行高阶提问。
+            - 目标:深挖“为什么”、分析“优劣势”或“实际怎么做”,启发孩子思考事物的本质。
+        【30% 知识桥梁家】:跳出文本,向科普、艺术、教育或社会素养进行跨界延伸。
+            - 目标:将书本知识与现实世界(STEAM、审美、SEL)连接。
 
-        # Requirements
-        提问视角(For Kids):
-        - 禁止说教。使用“如果你是...”、“你猜...”、“为什么...会这样”等激发好奇心的问句。
-        - 问题要能关联孩子的生活经验或想象力。
-        
-        回答要求(Double-Layer & Slim):
-        - 基础事实 + 深度启发:先用一句话讲清事实,再用一句话点破底层逻辑或引导实践。
-        - 字数铁律:每个答案严禁超过 50 字。
+        # Rules for Interaction
+        提问视角(For Kids Only):
+        - 使用“猜猜看”、“如果你是...”、“你有没有发现...”等语气。
+        - 禁令:严禁出现“根据文中描述”、“请分析...”等成人化、考试化的术语。
 
-        扩展维度:
-        - 好奇心钩子:为什么这个知识很酷?
-        - 生活实验室:你现在可以试着做什么?
-        - 情绪/逻辑种子:这背后的道理是什么?
+        # Answer炼金术(Double-Layer & Slim):
+        - 结构:1句事实要点 + 1句思维火花(点破逻辑、情绪种子或生活实验建议)。
+        - 字数:每个回答必须在 50 字以内,语言生动、干脆。
 
         # Output Standards
-        格式:必须以完整的 JSON 数组格式输出
-        - 严禁:任何多余的开场白或解释文字
+        - 格式:仅输出完整的 JSON 数组
+        - 严禁:任何多余的开场白、分类标题或总结
 
         # Output Format (JSON Only)
         JSON格式如下:
         [
-            {{
-                "question": "(面向孩子的好奇心提问)",
-                "answer": "(事实要点+深度启发,50字以内)"
-            }}
+          {{
+            "question": "(面向孩子的好奇心提问)",
+            "answer": "(事实+启发,50字以内)"
+          }}
         ]
         """
         
@@ -111,196 +109,14 @@ class QAGenerateNode(BaseNode):
             parser = QWenVLParser(self.model_name)
             result = parser.chat(prompt)
             
-            # 解析JSON响应
-            qa_pairs = self._parse_qa_response(result)
+            # 解析JSON响应(使用通用工具函数)
+            qa_pairs = parse_qa_response(result)
             logger.debug(f"第 {chunk_index + 1} 块生成 {len(qa_pairs)} 个QA对")
             return qa_pairs
         except Exception as e:
             logger.error(f"第 {chunk_index + 1} 块QA生成失败: {str(e)}")
             return []
     
-    def _parse_qa_response(self, response: str) -> List[Dict[str, str]]:
-        """
-        解析QA响应
-        
-        Args:
-            response: 模型响应文本
-            
-        Returns:
-            QA对列表
-        """
-        import re
-        
-        if not response:
-            return []
-        
-        # 清理响应文本
-        cleaned_response = response.strip()
-        
-        # 移除 BOM 标记
-        if cleaned_response.startswith('\ufeff'):
-            cleaned_response = cleaned_response[1:]
-        
-        # 1. 尝试直接解析JSON
-        try:
-            return json.loads(cleaned_response)
-        except json.JSONDecodeError as e:
-            logger.debug(f"直接解析失败: {str(e)}")
-            pass
-        
-        # 1.5. 尝试使用 raw_decode 解析(可以跳过前面的非JSON文本)
-        try:
-            decoder = json.JSONDecoder()
-            result, idx = decoder.raw_decode(cleaned_response)
-            if isinstance(result, list):
-                return result
-        except (json.JSONDecodeError, ValueError) as e:
-            logger.debug(f"raw_decode 解析失败: {str(e)}")
-            pass
-        
-        # 2. 尝试去除 markdown 代码块标记
-        # 匹配 ```json ... ``` 或 ``` ... ```
-        code_block_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
-        code_block_match = re.search(code_block_pattern, cleaned_response, re.DOTALL)
-        if code_block_match:
-            try:
-                json_content = code_block_match.group(1).strip()
-                return json.loads(json_content)
-            except json.JSONDecodeError:
-                pass
-        
-        # 3. 尝试提取第一个完整的 JSON 数组
-        # 使用括号匹配算法,正确处理嵌套的 [] 和 {}
-        bracket_count = 0
-        brace_count = 0
-        start_idx = -1
-        in_string = False
-        escape_next = False
-        
-        for i, char in enumerate(cleaned_response):
-            if escape_next:
-                escape_next = False
-                continue
-            
-            if char == '\\':
-                escape_next = True
-                continue
-            
-            if char == '"' and not escape_next:
-                in_string = not in_string
-                continue
-            
-            if in_string:
-                continue
-            
-            if char == '[':
-                if start_idx == -1:
-                    start_idx = i
-                bracket_count += 1
-            elif char == ']':
-                bracket_count -= 1
-                if bracket_count == 0 and brace_count == 0 and start_idx != -1:
-                    try:
-                        json_content = cleaned_response[start_idx:i+1]
-                        return json.loads(json_content)
-                    except json.JSONDecodeError:
-                        # 继续尝试下一个匹配
-                        start_idx = -1
-                        bracket_count = 0
-                        brace_count = 0
-            elif char == '{':
-                if start_idx != -1:
-                    brace_count += 1
-            elif char == '}':
-                if start_idx != -1:
-                    brace_count -= 1
-        
-        # 4. 尝试使用正则提取 JSON 数组(更宽松的方式)
-        json_array_pattern = r'\[\s*(?:\{[^}]*\}(?:\s*,\s*\{[^}]*\})*)?\s*\]'
-        json_match = re.search(json_array_pattern, cleaned_response, re.DOTALL)
-        if json_match:
-            try:
-                return json.loads(json_match.group())
-            except json.JSONDecodeError:
-                pass
-        
-        # 5. 尝试逐行查找 JSON 数组
-        lines = cleaned_response.split('\n')
-        json_lines = []
-        in_json = False
-        bracket_count = 0
-        
-        for line in lines:
-            stripped_line = line.strip()
-            if not stripped_line:
-                continue
-            
-            # 检查是否包含 JSON 数组的开始
-            if '[' in stripped_line and not in_json:
-                in_json = True
-                json_lines = [stripped_line]
-                bracket_count = stripped_line.count('[') - stripped_line.count(']')
-            elif in_json:
-                json_lines.append(stripped_line)
-                bracket_count += stripped_line.count('[') - stripped_line.count(']')
-                
-                if bracket_count == 0:
-                    try:
-                        json_content = '\n'.join(json_lines)
-                        return json.loads(json_content)
-                    except json.JSONDecodeError:
-                        in_json = False
-                        json_lines = []
-                        bracket_count = 0
-        
-        # 如果收集到了 JSON 行但还没闭合,尝试解析
-        if json_lines:
-            try:
-                json_content = '\n'.join(json_lines)
-                return json.loads(json_content)
-            except json.JSONDecodeError:
-                pass
-        
-        # 6. 最后尝试:查找所有可能的 JSON 对象并组合成数组
-        try:
-            # 查找所有 { ... } 模式的对象
-            json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned_response, re.DOTALL)
-            if json_objects:
-                parsed_objects = []
-                for obj_str in json_objects:
-                    try:
-                        parsed_obj = json.loads(obj_str)
-                        if isinstance(parsed_obj, dict) and 'question' in parsed_obj and 'answer' in parsed_obj:
-                            parsed_objects.append(parsed_obj)
-                    except json.JSONDecodeError:
-                        continue
-                if parsed_objects:
-                    logger.info(f"通过对象提取方式解析到 {len(parsed_objects)} 个QA对")
-                    return parsed_objects
-        except Exception as e:
-            logger.debug(f"对象提取方式失败: {str(e)}")
-        
-        # 所有方法都失败
-        # 记录更详细的错误信息用于调试
-        error_info = {
-            "response_length": len(cleaned_response),
-            "first_100_chars": repr(cleaned_response[:100]),
-            "last_100_chars": repr(cleaned_response[-100:]) if len(cleaned_response) > 100 else "",
-            "has_bracket": '[' in cleaned_response,
-            "has_brace": '{' in cleaned_response,
-        }
-        logger.warning(f"无法解析QA响应为JSON: {error_info}")
-        
-        # 尝试最后一次:如果响应看起来像 JSON 数组,尝试修复常见问题
-        if cleaned_response.startswith('[') and cleaned_response.endswith(']'):
-            try:
-                # 尝试修复常见的 JSON 问题:替换中文引号
-                fixed_response = cleaned_response.replace('"', '"').replace('"', '"').replace(''', "'").replace(''', "'")
-                return json.loads(fixed_response)
-            except json.JSONDecodeError:
-                pass
-        
-        return []
     
     def execute(self, state: BaseState) -> Dict[str, Any]:
         """

+ 1 - 1
src/datasets/parser/workflows/qa_workflow.py

@@ -64,7 +64,7 @@ class QAParsingWorkflowV2:
         # 创建节点实例
         ocr_node = PDFOCRNode(model_name=self.model_name)
         split_node = TextSplitNode()
-        qa_node = QAGenerateNode(model_name=self.model_name)
+        qa_node = QAGenerateNode(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507")
         export_node = ExportCSVNode()
         upload_node = RAGFlowDocumentUploadNode(api_key=ragflow_api_key)
         parse_node = RAGFlowDocumentParseNode(api_key=ragflow_api_key)

+ 267 - 0
src/utils/json_utils.py

@@ -0,0 +1,267 @@
+"""
+JSON 解析工具
+
+提供强大的 JSON 响应解析功能,支持多种格式的 JSON 内容提取。
+"""
+
+import json
+import re
+from typing import Any, Dict, List, Optional
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+def parse_json_response(response: str, expected_type: Optional[type] = None) -> Any:
+    """
+    解析 JSON 响应文本
+    
+    支持多种格式的 JSON 内容提取:
+    1. 直接 JSON 解析
+    2. raw_decode 解析(跳过前面的非JSON文本)
+    3. Markdown 代码块包裹的 JSON
+    4. 括号匹配提取 JSON 数组
+    5. 正则表达式提取
+    6. 逐行查找 JSON
+    7. 提取独立 JSON 对象并组合
+    
+    Args:
+        response: 模型响应文本
+        expected_type: 期望的返回类型(list 或 dict),如果指定则验证返回类型
+        
+    Returns:
+        解析后的 JSON 对象(通常是列表或字典),解析失败返回空列表或空字典
+        
+    Examples:
+        >>> # 解析 QA 响应(返回列表)
+        >>> qa_pairs = parse_json_response(response_text)
+        >>> 
+        >>> # 解析配置响应(返回字典)
+        >>> config = parse_json_response(response_text, expected_type=dict)
+    """
+    if not response:
+        return [] if expected_type == list else {} if expected_type == dict else None
+    
+    # 清理响应文本
+    cleaned_response = response.strip()
+    
+    # 移除 BOM 标记
+    if cleaned_response.startswith('\ufeff'):
+        cleaned_response = cleaned_response[1:]
+    
+    # 1. 尝试直接解析JSON
+    try:
+        result = json.loads(cleaned_response)
+        if expected_type is None or isinstance(result, expected_type):
+            return result
+    except json.JSONDecodeError as e:
+        logger.debug(f"直接解析失败: {str(e)}")
+        pass
+    
+    # 1.5. 尝试使用 raw_decode 解析(可以跳过前面的非JSON文本)
+    try:
+        decoder = json.JSONDecoder()
+        result, idx = decoder.raw_decode(cleaned_response)
+        if expected_type is None or isinstance(result, expected_type):
+            return result
+    except (json.JSONDecodeError, ValueError) as e:
+        logger.debug(f"raw_decode 解析失败: {str(e)}")
+        pass
+    
+    # 2. 尝试去除 markdown 代码块标记
+    # 匹配 ```json ... ``` 或 ``` ... ```
+    code_block_pattern = r'```(?:json)?\s*\n?(.*?)\n?```'
+    code_block_match = re.search(code_block_pattern, cleaned_response, re.DOTALL)
+    if code_block_match:
+        try:
+            json_content = code_block_match.group(1).strip()
+            result = json.loads(json_content)
+            if expected_type is None or isinstance(result, expected_type):
+                return result
+        except json.JSONDecodeError:
+            pass
+    
+    # 3. 尝试提取第一个完整的 JSON 数组
+    # 使用括号匹配算法,正确处理嵌套的 [] 和 {}
+    bracket_count = 0
+    brace_count = 0
+    start_idx = -1
+    in_string = False
+    escape_next = False
+    
+    for i, char in enumerate(cleaned_response):
+        if escape_next:
+            escape_next = False
+            continue
+        
+        if char == '\\':
+            escape_next = True
+            continue
+        
+        if char == '"' and not escape_next:
+            in_string = not in_string
+            continue
+        
+        if in_string:
+            continue
+        
+        if char == '[':
+            if start_idx == -1:
+                start_idx = i
+            bracket_count += 1
+        elif char == ']':
+            bracket_count -= 1
+            if bracket_count == 0 and brace_count == 0 and start_idx != -1:
+                try:
+                    json_content = cleaned_response[start_idx:i+1]
+                    result = json.loads(json_content)
+                    if expected_type is None or isinstance(result, expected_type):
+                        return result
+                except json.JSONDecodeError:
+                    # 继续尝试下一个匹配
+                    start_idx = -1
+                    bracket_count = 0
+                    brace_count = 0
+        elif char == '{':
+            if start_idx != -1:
+                brace_count += 1
+        elif char == '}':
+            if start_idx != -1:
+                brace_count -= 1
+    
+    # 4. 尝试使用正则提取 JSON 数组(更宽松的方式)
+    json_array_pattern = r'\[\s*(?:\{[^}]*\}(?:\s*,\s*\{[^}]*\})*)?\s*\]'
+    json_match = re.search(json_array_pattern, cleaned_response, re.DOTALL)
+    if json_match:
+        try:
+            result = json.loads(json_match.group())
+            if expected_type is None or isinstance(result, expected_type):
+                return result
+        except json.JSONDecodeError:
+            pass
+    
+    # 5. 尝试逐行查找 JSON 数组
+    lines = cleaned_response.split('\n')
+    json_lines = []
+    in_json = False
+    bracket_count = 0
+    
+    for line in lines:
+        stripped_line = line.strip()
+        if not stripped_line:
+            continue
+        
+        # 检查是否包含 JSON 数组的开始
+        if '[' in stripped_line and not in_json:
+            in_json = True
+            json_lines = [stripped_line]
+            bracket_count = stripped_line.count('[') - stripped_line.count(']')
+        elif in_json:
+            json_lines.append(stripped_line)
+            bracket_count += stripped_line.count('[') - stripped_line.count(']')
+            
+            if bracket_count == 0:
+                try:
+                    json_content = '\n'.join(json_lines)
+                    result = json.loads(json_content)
+                    if expected_type is None or isinstance(result, expected_type):
+                        return result
+                except json.JSONDecodeError:
+                    in_json = False
+                    json_lines = []
+                    bracket_count = 0
+    
+    # 如果收集到了 JSON 行但还没闭合,尝试解析
+    if json_lines:
+        try:
+            json_content = '\n'.join(json_lines)
+            result = json.loads(json_content)
+            if expected_type is None or isinstance(result, expected_type):
+                return result
+        except json.JSONDecodeError:
+            pass
+    
+    # 6. 最后尝试:查找所有可能的 JSON 对象并组合成数组
+    try:
+        # 查找所有 { ... } 模式的对象
+        json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned_response, re.DOTALL)
+        if json_objects:
+            parsed_objects = []
+            for obj_str in json_objects:
+                try:
+                    parsed_obj = json.loads(obj_str)
+                    # 如果是 QA 格式(包含 question 和 answer),则添加
+                    if isinstance(parsed_obj, dict) and 'question' in parsed_obj and 'answer' in parsed_obj:
+                        parsed_objects.append(parsed_obj)
+                    elif expected_type is None:
+                        # 如果没有指定期望类型,也添加
+                        parsed_objects.append(parsed_obj)
+                except json.JSONDecodeError:
+                    continue
+            if parsed_objects:
+                logger.info(f"通过对象提取方式解析到 {len(parsed_objects)} 个对象")
+                return parsed_objects
+    except Exception as e:
+        logger.debug(f"对象提取方式失败: {str(e)}")
+    
+    # 所有方法都失败
+    # 记录更详细的错误信息用于调试
+    error_info = {
+        "response_length": len(cleaned_response),
+        "first_100_chars": repr(cleaned_response[:100]),
+        "last_100_chars": repr(cleaned_response[-100:]) if len(cleaned_response) > 100 else "",
+        "has_bracket": '[' in cleaned_response,
+        "has_brace": '{' in cleaned_response,
+    }
+    logger.warning(f"无法解析JSON响应: {error_info}")
+    
+    # 尝试最后一次:如果响应看起来像 JSON 数组,尝试修复常见问题
+    if cleaned_response.startswith('[') and cleaned_response.endswith(']'):
+        try:
+            # 尝试修复常见的 JSON 问题:替换中文引号
+            fixed_response = cleaned_response.replace('"', '"').replace('"', '"').replace(''', "'").replace(''', "'")
+            result = json.loads(fixed_response)
+            if expected_type is None or isinstance(result, expected_type):
+                return result
+        except json.JSONDecodeError:
+            pass
+    
+    # 根据期望类型返回默认值
+    if expected_type == list:
+        return []
+    elif expected_type == dict:
+        return {}
+    return None
+
+
+def parse_qa_response(response: str) -> List[Dict[str, str]]:
+    """
+    解析 QA 响应文本(便捷函数)
+    
+    专门用于解析 QA 问答对响应,返回格式化的 QA 列表。
+    
+    Args:
+        response: 模型响应文本,应包含 JSON 格式的 QA 对列表
+        
+    Returns:
+        QA 对列表,每个元素包含 "question" 和 "answer" 字段
+        
+    Examples:
+        >>> response = '[{"question": "问题1", "answer": "答案1"}]'
+        >>> qa_pairs = parse_qa_response(response)
+        >>> # 返回: [{"question": "问题1", "answer": "答案1"}]
+    """
+    result = parse_json_response(response, expected_type=list)
+    if result is None:
+        return []
+    
+    # 验证并过滤有效的 QA 对
+    valid_qa_pairs = []
+    for item in result:
+        if isinstance(item, dict) and 'question' in item and 'answer' in item:
+            valid_qa_pairs.append({
+                "question": str(item.get("question", "")),
+                "answer": str(item.get("answer", ""))
+            })
+    
+    return valid_qa_pairs

+ 88 - 3
src/utils/vector_db/elasticsearch_adapter.py

@@ -88,6 +88,32 @@ class ElasticsearchAdapter(VectorDBClient):
             return f"{prefix}_{table_name}"
         return table_name
     
+    def _build_kb_id_filter(self, kb_id: Union[str, List[str]]) -> Dict[str, Any]:
+        """
+        构建 kb_id 过滤条件
+        
+        支持单个 kb_id 或 kb_id 数组的精准匹配。
+        
+        Args:
+            kb_id: 单个 kb_id 字符串或 kb_id 列表
+            
+        Returns:
+            ES 查询条件字典(term 或 terms)
+        """
+        if isinstance(kb_id, list):
+            # 数组:使用 terms 查询匹配多个值
+            if len(kb_id) == 0:
+                raise ValueError("kb_id 数组不能为空")
+            if len(kb_id) == 1:
+                # 单个元素,使用 term 查询
+                return {"term": {"kb_id": kb_id[0]}}
+            else:
+                # 多个元素,使用 terms 查询
+                return {"terms": {"kb_id": kb_id}}
+        else:
+            # 单个值:使用 term 查询
+            return {"term": {"kb_id": kb_id}}
+    
     def search(
         self,
         table_name: str,
@@ -99,16 +125,40 @@ class ElasticsearchAdapter(VectorDBClient):
         全文搜索
         
         将 Infinity 风格的查询转换为 ES 查询。
+        
+        Args:
+            table_name: 表名(索引名)
+            output_fields: 输出字段列表
+            query: 查询参数,必须包含 kb_id 字段
+            database_name: 数据库名(可选)
         """
+        # 验证 kb_id 必填
+        kb_id = query.get("kb_id")
+        if not kb_id:
+            raise ValueError("kb_id 是必填参数")
+        
         index_name = table_name
         match_field = query.get("match_field", "content")
         matching_text = query.get("matching_text") or ""
         
+        # 构建 kb_id 过滤条件(支持单个值或数组)
+        kb_id_filter = self._build_kb_id_filter(kb_id)
+        
+        # 构建查询:必须匹配 kb_id,并且匹配文本内容
+        es_query = {
+            "bool": {
+                "must": [
+                    kb_id_filter,  # kb_id 精准匹配(支持单个或数组)
+                    {"match": {match_field: matching_text}}  # 文本匹配
+                ]
+            }
+        }
+        
         # match 的 query 不能为 None,否则 ES 报 VALUE_NULL
         # ES 8.x 使用关键字参数而非 body
         result = self._client.search(
             index=index_name,
-            query={"match": {match_field: matching_text}},
+            query=es_query,
             size=query.get("topn", 10),
             _source=output_fields
         )
@@ -125,7 +175,18 @@ class ElasticsearchAdapter(VectorDBClient):
         向量搜索
         
         使用 ES 的 knn 查询进行向量检索。
+        
+        Args:
+            table_name: 表名(索引名)
+            output_fields: 输出字段列表
+            query: 查询参数,必须包含 kb_id 字段
+            database_name: 数据库名(可选)
         """
+        # 验证 kb_id 必填
+        kb_id = query.get("kb_id")
+        if not kb_id:
+            raise ValueError("kb_id 是必填参数")
+        
         index_name = table_name
         vector_field = query.get("vector_field", "dense_vector")
         query_vector = query.get("query_vector", [])
@@ -141,14 +202,19 @@ class ElasticsearchAdapter(VectorDBClient):
         knn_params = query.get("knn_params", {})
         threshold = float(knn_params.get("threshold", 0.0))
         
+        # 构建 kb_id 过滤条件(支持单个值或数组)
+        kb_id_filter = self._build_kb_id_filter(kb_id)
+        
         # 构建 ES knn 查询(ES 8.x 使用关键字参数而非 body)
+        # 添加 filter 来精准匹配 kb_id
         search_params = {
             "index": index_name,
             "knn": {
                 "field": vector_field,
                 "query_vector": query_vector,
                 "k": topn,
-                "num_candidates": topn * 10
+                "num_candidates": topn * 10,
+                "filter": kb_id_filter  # kb_id 精准匹配(支持单个或数组)
             },
             "_source": output_fields
         }
@@ -171,7 +237,21 @@ class ElasticsearchAdapter(VectorDBClient):
         混合搜索(向量 + 全文)
         
         使用 ES 的 bool 查询结合 knn 和 match。
+        
+        Args:
+            output_fields: 输出字段列表
+            query: 查询参数,必须包含 kb_id 字段
+            table_name: 表名(索引名),可以是字符串或字符串列表
+            database_name: 数据库名(可选)
         """
+        # 验证 kb_id 必填
+        kb_id = query.get("kb_id")
+        if not kb_id:
+            raise ValueError("kb_id 是必填参数")
+        
+        # 构建 kb_id 过滤条件(支持单个值或数组)
+        kb_id_filter = self._build_kb_id_filter(kb_id)
+        
         # index_name =  "ragbook_1_f3c87b89f82711f0b7450242ac180002,ragbook_2_16db7236f82911f093e20242ac180002"
         index_name = table_name
         vector_field = query.get("vector_field", "dense_vector")
@@ -204,10 +284,14 @@ class ElasticsearchAdapter(VectorDBClient):
             text_clause = {"match_all": {"boost": 1.0}}
         
         # 构建混合查询(ES 8.x 使用关键字参数而非 body)
+        # 在 query 的 bool.must 中添加 kb_id 精准匹配,在 knn 的 filter 中也添加
         result = self._client.search(
             index=index_name,
             query={
                 "bool": {
+                    "must": [
+                        kb_id_filter,  # kb_id 精准匹配(支持单个或数组)
+                    ],
                     "should": [text_clause]
                 }
             },
@@ -216,7 +300,8 @@ class ElasticsearchAdapter(VectorDBClient):
                 "query_vector": query_vector,
                 "k": topn,
                 "num_candidates": topn * 10,
-                "boost": 1.0
+                "boost": 1.0,
+                "filter": kb_id_filter  # kb_id 精准匹配(支持单个或数组)
             },
             size=topn,
             _source=output_fields