Explorar o código

查询效率优化

yingge committed 3 months ago
parent
commit
f0b5960c74

+ 1 - 9
src/api/db/models/search_request_models.py

@@ -46,7 +46,7 @@ class SearchRequest(BaseModel):
         le=1.0,
         description="相似度分数阈值"
     )
-    kb_id: Optional[str] = Field(
+    kb_id: List[str] = Field(
         default=None,
         description="知识库id",
         examples=["kb_id1", "kb_id2"]
@@ -72,11 +72,3 @@ class SearchRequest(BaseModel):
         """模型初始化后验证:确保至少提供一个搜索条件"""
         if self.image_url is None and self.matching_text is None:
             raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
-
-    @field_validator('kb_id')
-    @classmethod
-    def validate_kb_id(cls, v: Optional[str]) -> Optional[str]:
-        """验证知识库id"""
-        if v is not None and v.strip() == "":
-            return None
-        return v

+ 0 - 1
src/api/db/services/vector_search_service.py

@@ -132,7 +132,6 @@ class VectorSearchService:
                     query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image)
                 else:
                     query_vector = get_embedding_model().get_text_embedding(request.matching_text)
-            
             search_query = {
                 "match_field": self.match_field,
                 "matching_text": request.matching_text,

+ 63 - 29
src/api/sdk/search_infinity.py

@@ -8,6 +8,41 @@ from src.utils.async_utils import run_in_threadpool
 from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
 from src.api.db.models import SearchRequest
 from src.conf.settings import ragflow_settings
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+# 全局搜索服务实例
+_search_service = None
+
+
def get_search_service():
    """Return the process-wide VectorSearchService, creating it on first use.

    Lazily constructed singleton so the vector-DB client is built once per
    process rather than once per request.
    """
    global _search_service
    if _search_service is not None:
        return _search_service
    _search_service = VectorSearchService(
        client=get_vector_db_client(database="ragflow_db"),
        vector_field="q_1024_vec",
        match_field="content_sm_ltks",
        output_fields=["content_with_weight"],
    )
    logger.info("Search service initialized")
    return _search_service
+
+
def _get_question_table_name():
    """Resolve the Q&A-pair table name for the configured ragflow user.

    Returns:
        str: table name of the form ``<dataset_prefix>_<ragflow_id>``.

    Raises:
        ValueError: if the user record has no ragflow_id.
    """
    try:
        # NOTE(review): user id 2 is hard-coded — confirm it should not come
        # from configuration.
        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
        ragflow_id = ragflow_user.get("ragflow_id")
        if not ragflow_id:
            # Raise a specific exception type instead of bare Exception so
            # callers can distinguish this configuration error.
            raise ValueError("未找到ragflow_id")
        table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
        logger.info(f"Question table name: {table_name}")
        return table_name
    except Exception as e:
        # logger.exception records the stack trace before re-raising.
        logger.exception(f"Error getting question table name: {str(e)}")
        raise
 
 # 创建FastAPI应用
 
@@ -82,48 +117,47 @@ async def question_search(request: SearchRequest):
     - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
     """
     try:
-        output_fields = ["content_with_weight"]
-        search_service = VectorSearchService(
-            client=get_vector_db_client(database="ragflow_db"), 
-            vector_field="q_1024_vec",
-            output_fields=output_fields
-        )
-        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
-        ragflow_id = ragflow_user.get("ragflow_id")
-        if not ragflow_id:
-            return Result.error(code=500, message="未找到ragflow_id")
-        question_table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
-        request.kb_id = "2b0ac35df80e11f096160242ac180002"
-        request.database_ids = [question_table_name]
+        logger.info(f"Received question search request: {request.matching_text[:100]}...")
+        
+        # 获取搜索服务实例
+        search_service = get_search_service()
+        
+        # 执行搜索
         result = await run_in_threadpool(search_service.hybrid_search, request)
+        
+        logger.info(f"Question search completed successfully, found {len(result) if result else 0} results")
         return Result.success(data=result, message="问答对检索成功")
     except Exception as e:
+        logger.error(f"问答对检索失败: {str(e)}")
         return Result.error(code=500, message=f"问答对检索失败: {str(e)}")
 
 @app.post("/resource")
-async def question_search(request: SearchRequest):
+async def resource_search(request: SearchRequest):
     """
-    问答对检索接口
+    资源检索接口
 
     - **output_fields**: 要返回的字段列表
     - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
     - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
     """
     try:
-        output_fields = ["content_with_weight"]
-        search_service = VectorSearchService(
-            client=get_vector_db_client(database="ragflow_db"), 
-            vector_field="q_1024_vec",
-            output_fields=output_fields
-        )
-        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
-        ragflow_id = ragflow_user.get("ragflow_id")
-        if not ragflow_id:
-            return Result.error(code=500, message="未找到ragflow_id")
-        question_table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
+        logger.info(f"Received resource search request: {request.matching_text[:100]}...")
+        
+        # 获取搜索服务实例
+        search_service = get_search_service()
+        
+        # 获取资源表名
+        resource_table_name = await run_in_threadpool(_get_question_table_name)
+        
+        # 设置请求参数
         request.kb_id = "f59936fbfb1f11f095280242ac180002"
-        request.database_ids = [question_table_name]
+        request.database_ids = [resource_table_name]
+        
+        # 执行搜索
         result = await run_in_threadpool(search_service.hybrid_search, request)
-        return Result.success(data=result, message="问答对检索成功")
+        
+        logger.info(f"Resource search completed successfully, found {len(result) if result else 0} results")
+        return Result.success(data=result, message="资源检索成功")
     except Exception as e:
-        return Result.error(code=500, message=f"问答对检索失败: {str(e)}")
+        logger.error(f"资源检索失败: {str(e)}")
+        return Result.error(code=500, message=f"资源检索失败: {str(e)}")

+ 4 - 3
src/common/logging_config.py

@@ -19,8 +19,7 @@ DEFAULT_LOG_LEVEL = logging.INFO
 
 # 日志文件配置(可选)
 LOG_DIR = Path("logs")
-LOG_FILE = LOG_DIR / "app.log"
-
+LOG_FILE = r"D:\project\work\graph_rag_server\logs\app.log"
 
 def setup_logging(
     level: int = DEFAULT_LOG_LEVEL,
@@ -54,6 +53,8 @@ def setup_logging(
     # 添加文件处理器(如果需要)
     if log_to_file:
         file_path = log_file or LOG_FILE
+        # 转换为Path对象
+        file_path = Path(file_path)
         # 确保日志目录存在
         file_path.parent.mkdir(parents=True, exist_ok=True)
         
@@ -81,4 +82,4 @@ def get_logger(name: str) -> logging.Logger:
 
 
 # 在模块导入时自动配置日志
-setup_logging()
+setup_logging(log_to_file=True)

+ 25 - 21
src/model/multimodal_embedding.py

@@ -26,6 +26,14 @@ class Embedding:
         self.base_url = model_settings.base_url
         self.api_key = api_key or model_settings.api_key
         self.dashscope_api_key =  model_settings.dashscope_api_key
+        # 初始化OpenAI Embeddings客户端(只初始化一次)
+        self._embeddings_client = OpenAIEmbeddings(
+            model=self.embedding_model_name,
+            base_url=self.base_url,
+            api_key=self.api_key
+        )
+        # 初始化MultiModal Embedding客户端(只初始化一次)
+        self._multimodal_client = MultiModalEmbedding
     
     @observe(name="text_embedding", as_type="embedding")
     def get_text_embedding(self, text: str) -> List[float]:
@@ -39,13 +47,8 @@ class Embedding:
             List[float]: 文本的embedding向量
         """
         try:
-            # 使用langchain_openai初始化OpenAI Embeddings模型
-            embeddings = OpenAIEmbeddings(
-                model=self.embedding_model_name,
-                base_url=self.base_url,
-                api_key=self.api_key
-            )
-            embedding = embeddings.embed_query(text)
+            # 使用预先初始化的客户端
+            embedding = self._embeddings_client.embed_query(text)
             return embedding
         except Exception as e:
             raise Exception(f"文本embedding生成失败: {str(e)}")
@@ -62,13 +65,8 @@ class Embedding:
             List[List[float]]: 文本列表的embedding向量列表
         """
         try:
-            # 使用langchain_openai初始化OpenAI Embeddings模型
-            embeddings = OpenAIEmbeddings(
-                model=self.embedding_model_name,
-                base_url=self.base_url,
-                api_key=self.api_key
-            )
-            embeddings = embeddings.embed_documents(texts)
+            # 使用预先初始化的客户端
+            embeddings = self._embeddings_client.embed_documents(texts)
             return embeddings
         except Exception as e:
             raise Exception(f"多个文本embedding生成失败: {str(e)}")
@@ -100,7 +98,7 @@ class Embedding:
                     "image": f"data:image/png;base64,{image_base64}"
                 }
             ]
-            response = MultiModalEmbedding.call(
+            response = self._multimodal_client.call(
                 model=self.multi_embedding_model_name,
                 api_key=self.dashscope_api_key,
                 input=item
@@ -138,7 +136,7 @@ class Embedding:
             if text is not None and text.strip() != "":
                 item.append({'text': text})
         
-            response = MultiModalEmbedding.call(
+            response = self._multimodal_client.call(
                 model=self.multi_embedding_model_name,
                 api_key=self.dashscope_api_key,
                 input=item
@@ -151,7 +149,10 @@ class Embedding:
         except Exception as e:
             raise Exception(f"多模态embedding生成失败: {str(e)}")
 
-# 全局单例
+# 全局单例实例
+_embedding_model_instance = None
+
+
 def get_embedding_model() -> Embedding:
     """
     获取全局单例的Embedding模型
@@ -159,7 +160,10 @@ def get_embedding_model() -> Embedding:
     Returns:
         Embedding: 全局单例的Embedding模型
     """
-    return Embedding(
-        model_name=model_settings.embedding_model_name,
-        api_key=model_settings.api_key
-    )
+    global _embedding_model_instance
+    if _embedding_model_instance is None:
+        _embedding_model_instance = Embedding(
+            model_name=model_settings.embedding_model_name,
+            api_key=model_settings.api_key
+        )
+    return _embedding_model_instance

+ 213 - 0
src/utils/chinese_text_processor.py

@@ -0,0 +1,213 @@
+"""
+中文文本处理工具
+
+提供中文文本标准化、分词、权重计算和同义词扩展功能。
+"""
+
+import re
+import string
+from typing import List, Dict, Tuple, Any
+import jieba
+from src.common.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
class ChineseTextProcessor:
    """
    Chinese text processor.

    Provides:
    1. text normalization
    2. Chinese word segmentation (jieba)
    3. stopword filtering
    4. term-weight calculation
    5. synonym expansion (currently disabled)
    """

    def __init__(self):
        """Initialize the processor and load the stopword set."""
        self.stopwords = self._load_stopwords()

    def _load_stopwords(self) -> set:
        """
        Build the stopword set (basic Chinese + English function words).

        Returns:
            set: stopword set
        """
        stopwords = {
            # Chinese stopwords (function words, single common characters)
            '的', '了', '是', '在', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也',
            '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这', '个',
            '得', '地', '天', '子', '日', '月', '年', '时', '分', '秒', '中', '国', '民',
            '大', '小', '多', '少', '下', '左', '右', '前', '后', '里', '外', '内',
            '高', '低', '长', '短', '宽', '窄', '厚', '薄', '远', '近', '早', '晚', '今', '明',
            '昨', '周', '春', '夏', '秋', '冬', '东', '南', '西', '北',
            # English stopwords
            'a', 'an', 'the', 'and', 'or', 'but', 'if', 'because', 'as', 'what', 'which', 'this',
            'that', 'these', 'those', 'then', 'just', 'so', 'than', 'such', 'both', 'through',
            'about', 'for', 'is', 'of', 'while', 'during', 'to', 'from', 'in', 'on'
        }
        return stopwords

    def normalize_text(self, text: str) -> str:
        """
        Normalize raw text for search.

        Steps: strip URLs, convert full-width characters to half-width,
        insert a space between CJK and Latin runs, replace punctuation with
        spaces, lowercase, and collapse whitespace.

        Args:
            text: raw text

        Returns:
            str: normalized text
        """
        if not text:
            return ""

        # 1. Remove URLs
        text = re.sub(r'https?://\S+', '', text)
        text = re.sub(r'www\.\S+', '', text)

        # 2. Full-width -> half-width. This MUST run before punctuation
        # stripping: full-width letters/digits (U+FF21..U+FF5A etc.) are not
        # matched by the [a-zA-Z0-9] keep-set below and would otherwise be
        # destroyed before they could be converted.
        def _full_to_half(s: str) -> str:
            out = []
            for ch in s:
                code = ord(ch)
                if code == 0x3000:  # ideographic (full-width) space
                    out.append(' ')
                elif 0xFF01 <= code <= 0xFF5E:  # full-width ASCII block
                    out.append(chr(code - 0xFEE0))
                else:
                    out.append(ch)
            return ''.join(out)

        text = _full_to_half(text)

        # 3. Insert a space between Latin and CJK runs (both directions)
        text = re.sub(r'([a-zA-Z])([\u4e00-\u9fa5])', r'\1 \2', text)
        text = re.sub(r'([\u4e00-\u9fa5])([a-zA-Z])', r'\1 \2', text)

        # 4. Replace everything except CJK, ASCII alphanumerics and whitespace
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s]', ' ', text)

        # 5. Lowercase
        text = text.lower()

        # 6. Collapse whitespace runs
        text = re.sub(r'\s+', ' ', text).strip()

        return text

    def split(self, text: str) -> List[str]:
        """
        Segment normalized text into words.

        Args:
            text: normalized text

        Returns:
            List[str]: filtered token list
        """
        if not text:
            return []

        # jieba segmentation
        words = jieba.cut(text)

        # Drop stopwords, empty tokens, pure digits and whitespace tokens
        filtered_words = []
        for word in words:
            if word in self.stopwords:
                continue
            if len(word) < 1:
                continue
            if word.isdigit():
                continue
            if word.isspace():
                continue
            filtered_words.append(word)

        return filtered_words

    def weights(self, words: List[str]) -> Dict[str, float]:
        """
        Compute per-term weights.

        Weight = 0.7 * normalized frequency + 0.3 * length factor
        (longer terms are considered more informative, capped at length 4).

        Args:
            words: token list

        Returns:
            Dict[str, float]: term -> weight (rounded to 4 decimals)
        """
        if not words:
            return {}

        # Term frequencies
        word_freq = {}
        for word in words:
            word_freq[word] = word_freq.get(word, 0) + 1

        total_freq = sum(word_freq.values())
        weights = {}
        for word, freq in word_freq.items():
            # Length factor: longer terms weigh more, capped at 1.0
            length_factor = min(len(word) / 4, 1.0)
            # Frequency factor, normalized over all tokens
            freq_factor = freq / total_freq
            weight = (freq_factor * 0.7) + (length_factor * 0.3)
            weights[word] = round(weight, 4)

        return weights

    def lookup(self, word: str) -> List[str]:
        """
        Look up synonyms (feature disabled).

        Args:
            word: source term

        Returns:
            List[str]: always an empty list (synonym expansion is disabled)
        """
        return []

    def process_text(self, text: str) -> Dict[str, Any]:
        """
        Full processing pipeline: normalize -> segment -> weight.

        Args:
            text: raw text

        Returns:
            Dict[str, Any]: original_text, normalized_text, words, word_weights
        """
        normalized_text = self.normalize_text(text)
        words = self.split(normalized_text)
        word_weights = self.weights(words)

        return {
            'original_text': text,
            'normalized_text': normalized_text,
            'words': words,
            'word_weights': word_weights
        }


# Module-level shared instance
chinese_processor = ChineseTextProcessor()

+ 102 - 20
src/utils/vector_db/elasticsearch_adapter.py

@@ -11,6 +11,7 @@ from .base import VectorDBClient
 from .result_util import UnifiedSearchResult, build_unified_result
 from src.conf.settings import es_settings, vector_db_settings
 from src.common.logging_config import get_logger
+from src.utils.chinese_text_processor import chinese_processor
 
 logger = get_logger(__name__)
 
@@ -214,14 +215,15 @@ class ElasticsearchAdapter(VectorDBClient):
                 "query_vector": query_vector,
                 "k": topn,
                 "num_candidates": topn * 10,
+                "similarity": threshold,
                 "filter": kb_id_filter  # kb_id 精准匹配(支持单个或数组)
             },
             "_source": output_fields
         }
         
-        # 如果有阈值,添加 min_score
-        if threshold > 0:
-            search_params["min_score"] = threshold
+        # # 如果有阈值,添加 min_score
+        # if threshold > 0:
+        #     search_params["min_score"] = threshold
         
         result = self._client.search(**search_params)
         return self._convert_result(result, output_fields)
@@ -252,17 +254,23 @@ class ElasticsearchAdapter(VectorDBClient):
         # 构建 kb_id 过滤条件(支持单个值或数组)
         kb_id_filter = self._build_kb_id_filter(kb_id)
         
-        # index_name =  "ragbook_1_f3c87b89f82711f0b7450242ac180002,ragbook_2_16db7236f82911f093e20242ac180002"
         index_name = table_name
         vector_field = query.get("vector_field", "dense_vector")
         query_vector = query.get("query_vector", [])
         # 确保 query_vector 是 Python 原生列表,兼容 numpy 数组等类型
-        if query_vector is not None and len(query_vector) > 0:
-            # 强制转换为 Python 原生 float 列表
-            query_vector = [float(x) for x in query_vector]
-            logger.info(f"混合搜索: 已转换 query_vector 类型, 长度={len(query_vector)}, 前3个元素={query_vector[:3]}")
-        else:
-            query_vector = []
+        # if query_vector is not None and len(query_vector) > 0:
+        #     # 强制转换为 Python 原生 float 列表
+        #     query_vector = [float(x) for x in query_vector]
+        #     logger.info(f"混合搜索: 已转换 query_vector 类型, 长度={len(query_vector)}, 前3个元素={query_vector[:3]}")
+        # else:
+        #     query_vector = []
+
+        # Log the query_vector (written to the configured application log) for debugging
+        logger.info(f"混合搜索: query_vector={query_vector}")
+        # 获取阈值
+        knn_params = query.get("knn_params", {})
+        threshold = float(knn_params.get("threshold", 0.0))
+
         match_field = query.get("match_field", "content")
         matching_text = query.get("matching_text")
         if matching_text is not None and not isinstance(matching_text, str):
@@ -275,24 +283,21 @@ class ElasticsearchAdapter(VectorDBClient):
         
         # match 的 query 不能为 None,否则 ES 报 VALUE_NULL。无文案时仅用向量(match_all)
         if matching_text:
-            text_clause = {
-                "match": {
-                    match_field: {"query": matching_text, "boost": 1.0}
-                }
-            }
+            # 处理中文文本
+            text_clause = self._build_chinese_query(match_field, matching_text)
         else:
-            text_clause = {"match_all": {"boost": 1.0}}
-        
+            text_clause = {"match_all": {"boost": 2.0}}
         # 构建混合查询(ES 8.x 使用关键字参数而非 body)
         # 在 query 的 bool.must 中添加 kb_id 精准匹配,在 knn 的 filter 中也添加
+        # 打印所有查询条件
         result = self._client.search(
             index=index_name,
             query={
                 "bool": {
-                    "must": [
+                    "filter": [
                         kb_id_filter,  # kb_id 精准匹配(支持单个或数组)
                     ],
-                    "should": [text_clause]
+                    "must": [text_clause]
                 }
             },
             knn={
@@ -300,7 +305,8 @@ class ElasticsearchAdapter(VectorDBClient):
                 "query_vector": query_vector,
                 "k": topn,
                 "num_candidates": topn * 10,
-                "boost": 1.0,
+                "similarity": threshold,
+                "boost": 8.0,
                 "filter": kb_id_filter  # kb_id 精准匹配(支持单个或数组)
             },
             size=topn,
@@ -328,6 +334,82 @@ class ElasticsearchAdapter(VectorDBClient):
             return result
         return None
     
+    def _build_chinese_query(self, field: str, text: str) -> Dict[str, Any]:
+        """
+        构建中文查询
+        
+        Args:
+            field: 查询字段
+            text: 原始文本
+            
+        Returns:
+            Dict[str, Any]: ES查询DSL
+        """
+        # 处理中文文本
+        processed_result = chinese_processor.process_text(text)
+        
+        logger.info(f"Chinese text processed: {processed_result}")
+        
+        # 获取处理结果
+        words = processed_result.get('words', [])
+        word_weights = processed_result.get('word_weights', {})
+        
+        # 构建布尔查询
+        bool_query = {
+            "bool": {
+                "should": []
+            }
+        }
+        
+        # 添加原始词查询
+        for word in words:
+            weight = word_weights.get(word, 1.0)
+            # 添加 term 查询(精确匹配)
+            bool_query["bool"]["should"].append({
+                "term": {
+                    field: {
+                        "value": word,
+                        "boost": 2.0 * weight
+                    }
+                }
+            })
+            # 添加 match 查询(模糊匹配)
+            bool_query["bool"]["should"].append({
+                "match": {
+                    field: {
+                        "query": word,
+                        "boost": 1.0 * weight
+                    }
+                }
+            })
+        
+        # 添加短语查询(保留词序)
+        if len(words) > 1:
+            phrase_query = " ".join(words)
+            bool_query["bool"]["should"].append({
+                "match_phrase": {
+                    field: {
+                        "query": phrase_query,
+                        "boost": 3.0
+                    }
+                }
+            })
+        
+        # 添加原始文本匹配
+        bool_query["bool"]["should"].append({
+            "match": {
+                field: {
+                    "query": text,
+                    "boost": 1.5
+                }
+            }
+        })
+        
+        # 设置 minimum_should_match,确保至少匹配一个条件
+        bool_query["bool"]["minimum_should_match"] = 1
+        
+        return bool_query
+    
     @staticmethod
     def _parse_cond_for_term(cond: str) -> Optional[Tuple[str, str]]:
         """