vector_search_service.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """
  2. 向量数据库搜索服务
  3. 提供统一的搜索接口,自动根据配置选择底层数据库实现。
  4. """
  5. from typing import Dict, Any, List, Optional
  6. from src.conf.settings import vector_db_settings
  7. from src.utils.vector_db import get_vector_db_client, VectorDBClient
  8. from src.utils.file.image_util import image_util
  9. from src.model.multimodal_embedding import get_embedding_model
  10. from src.utils.infinity.result_util import convert_to_json
  11. from src.api.db.models import SearchRequest
  12. from src.common.logging_config import get_logger
  13. logger = get_logger(__name__)
  14. class VectorSearchService:
  15. """
  16. 向量数据库搜索服务
  17. 自动根据 VECTOR_DB_TYPE 配置选择对应的数据库实现。
  18. 提供与原 InfinitySearchService 兼容的接口。
  19. """
  20. default_output_fields = [
  21. "file_name",
  22. "page_number",
  23. "content",
  24. "image_path",
  25. "dataset_id",
  26. "document_id"
  27. ]
  28. def __init__(
  29. self,
  30. client: VectorDBClient = None,
  31. database: str = None,
  32. vector_field: str = None,
  33. match_field: str = None,
  34. match_type: str = None,
  35. table_name: str = None,
  36. output_fields: List[str] = None
  37. ):
  38. """
  39. 初始化搜索服务
  40. Args:
  41. client: 向量数据库客户端(可选,不传则自动创建)
  42. database: 数据库名称
  43. vector_field: 向量字段名
  44. match_field: 匹配字段名
  45. match_type: 匹配类型
  46. table_name: 表名
  47. output_fields: 输出字段列表
  48. """
  49. self._client = client or get_vector_db_client(database=database)
  50. self.output_fields = output_fields or self.default_output_fields
  51. self.vector_field = vector_field or "dense_vector_1024"
  52. self.match_field = match_field or "content"
  53. self.match_type = match_type or "cosine"
  54. self.table_name = table_name or vector_db_settings.infinity_table_name
  55. @property
  56. def client(self) -> VectorDBClient:
  57. """获取底层客户端"""
  58. return self._client
  59. def search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  60. """
  61. 执行全文搜索
  62. Args:
  63. request: 搜索请求
  64. Returns:
  65. 搜索结果列表
  66. """
  67. try:
  68. search_query = self._convert_search_request_to_search_query(request)
  69. result = self._client.search(self.table_name, self.output_fields, search_query)
  70. result_dict = result.to_result()
  71. return convert_to_json(result_dict)
  72. except Exception as e:
  73. logger.error(f"搜索失败: {str(e)}")
  74. raise Exception(f"搜索失败: {str(e)}")
  75. def vector_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  76. """
  77. 执行向量搜索
  78. Args:
  79. request: 搜索请求
  80. Returns:
  81. 搜索结果列表
  82. """
  83. try:
  84. search_query = self._convert_search_request_to_search_query(request)
  85. result = self._client.vector_search(self.table_name, self.output_fields, search_query)
  86. result_dict = result.to_result()
  87. return convert_to_json(result_dict)
  88. except Exception as e:
  89. logger.error(f"向量搜索失败: {str(e)}")
  90. raise Exception(f"向量搜索失败: {str(e)}")
  91. def hybrid_search(self, request: SearchRequest) -> List[Dict[str, Any]]:
  92. """
  93. 执行混合搜索
  94. Args:
  95. request: 搜索请求
  96. Returns:
  97. 搜索结果列表
  98. """
  99. try:
  100. search_query = self._convert_search_request_to_search_query(request)
  101. result = self._client.hybrid_search(self.table_name, self.output_fields, search_query)
  102. result_dict = result.to_result()
  103. return convert_to_json(result_dict)
  104. except Exception as e:
  105. logger.error(f"混合搜索失败: {str(e)}")
  106. raise Exception(f"混合搜索失败: {str(e)}")
  107. def _convert_search_request_to_search_query(self, request: SearchRequest) -> dict:
  108. """
  109. 将 SearchRequest 转换为搜索查询参数
  110. """
  111. try:
  112. query_vector = None
  113. if request.score_threshold is not None:
  114. if request.image_url is not None:
  115. image = image_util._url_to_image(request.image_url)
  116. query_vector = get_embedding_model().get_multimodal_embedding(request.matching_text, image)
  117. else:
  118. query_vector = get_embedding_model().get_text_embedding(request.matching_text)
  119. search_query = {
  120. "match_field": self.match_field,
  121. "matching_text": request.matching_text,
  122. "vector_field": self.vector_field,
  123. "query_vector": query_vector,
  124. "topn": request.topn,
  125. "knn_params": {
  126. "ef": str(request.topn * 10),
  127. "threshold": str(request.score_threshold) if request.score_threshold else "0"
  128. }
  129. }
  130. return search_query
  131. except Exception as e:
  132. logger.error(f"参数转换失败: {str(e)}")
  133. raise Exception(f"参数转换失败: {str(e)}")
# Backward-compatibility alias: `SearchService` refers to the same class as
# `VectorSearchService`.
SearchService = VectorSearchService