vector_search_service.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. """
  2. 向量数据库搜索服务
  3. 提供统一的搜索接口,自动根据配置选择底层数据库实现。
  4. """
  5. from typing import Dict, Any, List, Optional
  6. from src.conf.settings import vector_db_settings, model_settings
  7. from src.utils.vector_db import get_vector_db_client, VectorDBClient
  8. from src.utils.file.image_util import image_util
  9. from src.model.multimodal_embedding import get_embedding_model
  10. from src.api.db.models import SearchRequest
  11. from src.common.logging_config import get_logger
# Module-level logger; the search methods below log failures through it
# before re-raising.
logger = get_logger(__name__)
  13. class VectorSearchService:
  14. """
  15. 向量数据库搜索服务
  16. 自动根据 VECTOR_DB_TYPE 配置选择对应的数据库实现。
  17. 提供与原 InfinitySearchService 兼容的接口。
  18. """
  19. default_output_fields = [
  20. "file_name",
  21. "page_number",
  22. "content",
  23. "image_path",
  24. "dataset_id",
  25. "document_id"
  26. ]
  27. def __init__(
  28. self,
  29. client: VectorDBClient = None,
  30. database: str = None,
  31. vector_field: str = None,
  32. match_field: str = None,
  33. match_type: str = None,
  34. table_name: str = None,
  35. output_fields: List[str] = None
  36. ):
  37. """
  38. 初始化搜索服务
  39. Args:
  40. client: 向量数据库客户端(可选,不传则自动创建)
  41. database: 数据库名称
  42. vector_field: 向量字段名
  43. match_field: 匹配字段名
  44. match_type: 匹配类型
  45. table_name: 表名
  46. output_fields: 输出字段列表
  47. """
  48. self._client = client or get_vector_db_client(database=database)
  49. self.output_fields = output_fields or self.default_output_fields
  50. self.vector_field = vector_field or "dense_vector_1024"
  51. self.match_field = match_field or "content"
  52. self.match_type = match_type or "cosine"
  53. self.table_name = table_name or None
  54. @property
  55. def client(self) -> VectorDBClient:
  56. """获取底层客户端"""
  57. return self._client
  58. def search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  59. """
  60. 执行全文搜索
  61. Args:
  62. request: 搜索请求
  63. Returns:
  64. 搜索结果列表
  65. """
  66. try:
  67. search_query = self._convert_search_request_to_search_query(request)
  68. result = self._client.search(self.table_name, self.output_fields, search_query)
  69. return result.to_result()
  70. except Exception as e:
  71. logger.error(f"搜索失败: {str(e)}")
  72. raise Exception(f"搜索失败: {str(e)}")
  73. def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  74. """
  75. 执行向量搜索
  76. Args:
  77. request: 搜索请求
  78. Returns:
  79. 搜索结果列表
  80. """
  81. try:
  82. search_query = self._convert_search_request_to_search_query(request)
  83. result = self._client.vector_search(self.table_name, self.output_fields, search_query)
  84. return result.to_result()
  85. except Exception as e:
  86. logger.error(f"向量搜索失败: {str(e)}")
  87. raise Exception(f"向量搜索失败: {str(e)}")
  88. def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  89. """
  90. 执行混合搜索
  91. Args:
  92. request: 搜索请求
  93. Returns:
  94. 搜索结果列表
  95. """
  96. try:
  97. search_query = self._convert_search_request_to_search_query(request)
  98. result = self._client.hybrid_search(table_name=request.database_ids, output_fields=self.output_fields, query=search_query)
  99. return result.to_result()
  100. except Exception as e:
  101. logger.error(f"混合搜索失败: {str(e)}")
  102. raise Exception(f"混合搜索失败: {str(e)}")
  103. def _convert_search_request_to_search_query(self, request: SearchRequest) -> dict:
  104. """
  105. 将 SearchRequest 转换为搜索查询参数
  106. """
  107. try:
  108. query_vector = None
  109. if request.score_threshold is not None:
  110. if request.image_url is not None:
  111. image = image_util._url_to_image(request.image_url)
  112. query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image)
  113. else:
  114. query_vector = get_embedding_model().get_text_embedding(request.matching_text)
  115. search_query = {
  116. "match_field": self.match_field,
  117. "matching_text": request.matching_text,
  118. "vector_field": self.vector_field,
  119. "query_vector": query_vector,
  120. "topn": request.topk,
  121. "kb_id": request.kb_id,
  122. "knn_params": {
  123. "ef": str(request.topk * 10),
  124. "threshold": str(request.score_threshold) if request.score_threshold else "0"
  125. }
  126. }
  127. return search_query
  128. except Exception as e:
  129. logger.error(f"参数转换失败: {str(e)}")
  130. raise Exception(f"参数转换失败: {str(e)}")
# Backward-compatible alias: code that still refers to ``SearchService``
# gets ``VectorSearchService``.
SearchService = VectorSearchService