""" 向量数据库搜索服务 提供统一的搜索接口,自动根据配置选择底层数据库实现。 """ from typing import Dict, Any, List, Optional from src.conf.settings import vector_db_settings from src.utils.vector_db import get_vector_db_client, VectorDBClient from src.utils.file.image_util import image_util from src.model.multimodal_embedding import get_embedding_model from src.utils.infinity.result_util import convert_to_json from src.api.db.models import SearchRequest from src.common.logging_config import get_logger logger = get_logger(__name__) class VectorSearchService: """ 向量数据库搜索服务 自动根据 VECTOR_DB_TYPE 配置选择对应的数据库实现。 提供与原 InfinitySearchService 兼容的接口。 """ default_output_fields = [ "file_name", "page_number", "content", "image_path", "dataset_id", "document_id" ] def __init__( self, client: VectorDBClient = None, database: str = None, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None, output_fields: List[str] = None ): """ 初始化搜索服务 Args: client: 向量数据库客户端(可选,不传则自动创建) database: 数据库名称 vector_field: 向量字段名 match_field: 匹配字段名 match_type: 匹配类型 table_name: 表名 output_fields: 输出字段列表 """ self._client = client or get_vector_db_client(database=database) self.output_fields = output_fields or self.default_output_fields self.vector_field = vector_field or "dense_vector_1024" self.match_field = match_field or "content" self.match_type = match_type or "cosine" self.table_name = table_name or vector_db_settings.infinity_table_name @property def client(self) -> VectorDBClient: """获取底层客户端""" return self._client def search(self, request: SearchRequest) -> List[Dict[str, Any]]: """ 执行全文搜索 Args: request: 搜索请求 Returns: 搜索结果列表 """ try: search_query = self._convert_search_request_to_search_query(request) result = self._client.search(self.table_name, self.output_fields, search_query) result_dict = result.to_result() return convert_to_json(result_dict) except Exception as e: logger.error(f"搜索失败: {str(e)}") raise Exception(f"搜索失败: {str(e)}") def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]: """ 执行向量搜索 Args: request: 搜索请求 Returns: 搜索结果列表 """ try: search_query = self._convert_search_request_to_search_query(request) result = self._client.vector_search(self.table_name, self.output_fields, search_query) result_dict = result.to_result() return convert_to_json(result_dict) except Exception as e: logger.error(f"向量搜索失败: {str(e)}") raise Exception(f"向量搜索失败: {str(e)}") def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]: """ 执行混合搜索 Args: request: 搜索请求 Returns: 搜索结果列表 """ try: search_query = self._convert_search_request_to_search_query(request) result = self._client.hybrid_search(self.table_name, self.output_fields, search_query) result_dict = result.to_result() return convert_to_json(result_dict) except Exception as e: logger.error(f"混合搜索失败: {str(e)}") raise Exception(f"混合搜索失败: {str(e)}") def _convert_search_request_to_search_query(self, request: SearchRequest) -> dict: """ 将 SearchRequest 转换为搜索查询参数 """ try: query_vector = None if request.score_threshold is not None: if request.image_url is not None: image = image_util._url_to_image(request.image_url) query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image) else: query_vector = get_embedding_model().get_text_embedding(request.matching_text) search_query = { "match_field": self.match_field, "matching_text": request.matching_text, "vector_field": self.vector_field, "query_vector": query_vector, "topn": request.topn, "knn_params": { "ef": str(request.topn * 10), "threshold": str(request.score_threshold) if request.score_threshold else "0" } } return search_query except Exception as e: logger.error(f"参数转换失败: {str(e)}") raise Exception(f"参数转换失败: {str(e)}") # 向后兼容别名 SearchService = VectorSearchService