| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- """
- 向量数据库搜索服务
- 提供统一的搜索接口,自动根据配置选择底层数据库实现。
- """
- from typing import Dict, Any, List, Optional
- from src.conf.settings import vector_db_settings
- from src.utils.vector_db import get_vector_db_client, VectorDBClient
- from src.utils.file.image_util import image_util
- from src.model.multimodal_embedding import get_embedding_model
- from src.utils.infinity.result_util import convert_to_json
- from src.api.db.models import SearchRequest
- from src.common.logging_config import get_logger
- logger = get_logger(__name__)
class VectorSearchService:
    """Search service facade over the configured vector database.

    The backing client comes from ``get_vector_db_client`` unless one is
    injected (per the original notes, the implementation is selected by the
    ``VECTOR_DB_TYPE`` setting). The public interface is kept compatible
    with the former ``InfinitySearchService``.
    """

    # Columns requested from the backend when the caller does not specify any.
    default_output_fields = [
        "file_name",
        "page_number",
        "content",
        "image_path",
        "dataset_id",
        "document_id",
    ]

    def __init__(
        self,
        client: Optional[VectorDBClient] = None,
        database: Optional[str] = None,
        vector_field: Optional[str] = None,
        match_field: Optional[str] = None,
        match_type: Optional[str] = None,
        table_name: Optional[str] = None,
        output_fields: Optional[List[str]] = None,
    ):
        """Initialise the search service.

        Args:
            client: Vector database client. When ``None`` a client is created
                through ``get_vector_db_client(database=database)``.
            database: Database name, only used when a client has to be created.
            vector_field: Name of the vector column
                (defaults to ``"dense_vector_1024"``).
            match_field: Name of the text column to match against
                (defaults to ``"content"``).
            match_type: Match type (defaults to ``"cosine"``).
            table_name: Table to query (defaults to
                ``vector_db_settings.infinity_table_name``).
            output_fields: Columns to return (defaults to a copy of
                ``default_output_fields``).
        """
        # Compare against None explicitly so that an injected client which
        # happens to be falsy is not silently replaced by a fresh one.
        if client is None:
            client = get_vector_db_client(database=database)
        self._client = client

        # Copy the class-level default so that mutating one instance's field
        # list can never leak into other instances or into the class default.
        self.output_fields = output_fields or list(self.default_output_fields)
        self.vector_field = vector_field or "dense_vector_1024"
        self.match_field = match_field or "content"
        self.match_type = match_type or "cosine"
        self.table_name = table_name or vector_db_settings.infinity_table_name

    @property
    def client(self) -> VectorDBClient:
        """Return the underlying vector database client."""
        return self._client

    def search(self, request: SearchRequest) -> List[Dict[str, Any]]:
        """Run a full-text search through the client's ``search`` method.

        Args:
            request: The search request.

        Returns:
            The search results converted by ``convert_to_json``.

        Raises:
            Exception: If building the query or running the search fails; the
                original error is attached as ``__cause__``.
        """
        return self._run_search("search", "搜索失败", request)

    def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
        """Run a vector search through the client's ``vector_search`` method.

        Args:
            request: The search request.

        Returns:
            The search results converted by ``convert_to_json``.

        Raises:
            Exception: If building the query or running the search fails; the
                original error is attached as ``__cause__``.
        """
        return self._run_search("vector_search", "向量搜索失败", request)

    def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
        """Run a hybrid search through the client's ``hybrid_search`` method.

        Args:
            request: The search request.

        Returns:
            The search results converted by ``convert_to_json``.

        Raises:
            Exception: If building the query or running the search fails; the
                original error is attached as ``__cause__``.
        """
        return self._run_search("hybrid_search", "混合搜索失败", request)

    def _run_search(
        self,
        client_method: str,
        error_label: str,
        request: SearchRequest,
    ) -> List[Dict[str, Any]]:
        """Shared pipeline: build the query, call the client, convert the result.

        Args:
            client_method: Name of the client method to invoke.
            error_label: Prefix used for the logged and re-raised error message.
            request: The search request.
        """
        try:
            search_query = self._convert_search_request_to_search_query(request)
            # Resolved inside the try so that a client lacking the method is
            # reported through the same wrapped error as any other failure.
            backend_call = getattr(self._client, client_method)
            result = backend_call(self.table_name, self.output_fields, search_query)
            return convert_to_json(result.to_result())
        except Exception as e:
            message = f"{error_label}: {str(e)}"
            logger.error(message)
            raise Exception(message) from e

    def _convert_search_request_to_search_query(
        self, request: SearchRequest
    ) -> Dict[str, Any]:
        """Translate a ``SearchRequest`` into the query dict passed to the client.

        Reads ``matching_text``, ``image_url``, ``score_threshold`` and ``topn``
        from the request.

        Raises:
            Exception: If embedding or query construction fails; the original
                error is attached as ``__cause__``.
        """
        try:
            query_vector = None
            # NOTE(review): the embedding is only computed when a score
            # threshold is supplied; otherwise ``query_vector`` stays None.
            # Confirm this gating is intended for vector/hybrid searches.
            if request.score_threshold is not None:
                if request.image_url is not None:
                    image = image_util._url_to_image(request.image_url)
                    query_vector = get_embedding_model().get_multimodal_embedding(
                        request.matching_text, image
                    )
                else:
                    query_vector = get_embedding_model().get_text_embedding(
                        request.matching_text
                    )

            # A missing or zero threshold is sent as the string "0".
            if request.score_threshold:
                threshold = str(request.score_threshold)
            else:
                threshold = "0"

            return {
                "match_field": self.match_field,
                "matching_text": request.matching_text,
                "vector_field": self.vector_field,
                "query_vector": query_vector,
                "topn": request.topn,
                "knn_params": {
                    "ef": str(request.topn * 10),
                    "threshold": threshold,
                },
            }
        except Exception as e:
            message = f"参数转换失败: {str(e)}"
            logger.error(message)
            raise Exception(message) from e
# Backward-compatible alias: exposes the service under the name `SearchService` as well.
SearchService = VectorSearchService
|