# search.py
"""
Search manager for the Infinity vector database.
"""
  4. from typing import List, Dict, Any, Optional
  5. from services.utils.infinity.base import InfinityConnection
  6. class InfinitySearchManager:
  7. """
  8. Infinity向量数据库搜索管理器
  9. 负责处理各种搜索操作
  10. """
  11. def __init__(self, infinity_connection: Optional[InfinityConnection] = None):
  12. """
  13. 初始化搜索管理器
  14. Args:
  15. infinity_connection: Infinity连接实例,可选
  16. """
  17. self.infinity_conn = infinity_connection or InfinityConnection()
  18. def search(self, index_name: str, query: Dict[str, Any], size: int = 10) -> Dict[str, Any]:
  19. """
  20. 全文检索
  21. Args:
  22. index_name: 索引名称
  23. query: 查询条件
  24. size: 返回结果数量
  25. Returns:
  26. Dict: 搜索结果
  27. """
  28. try:
  29. path = f"/api/collections/{index_name}/search"
  30. response = self.infinity_conn._make_request("POST", path, {
  31. "query": query,
  32. "limit": size
  33. })
  34. if "error" not in response:
  35. return {
  36. "hits": {
  37. "total": response.get("total", 0),
  38. "hits": [{
  39. "_source": doc
  40. } for doc in response.get("documents", [])]
  41. }
  42. }
  43. return {"hits": {"total": 0, "hits": []}}
  44. except Exception as e:
  45. print(f"Infinity搜索失败: {e}")
  46. return {"hits": {"total": 0, "hits": []}}
  47. def vector_search(self, index_name: str, vector_field: str, vector: List[float],
  48. size: int = 10, filter: Dict[str, Any] = None) -> Dict[str, Any]:
  49. """
  50. 向量检索
  51. Args:
  52. index_name: 索引名称
  53. vector_field: 向量字段名
  54. vector: 检索向量
  55. size: 返回结果数量
  56. filter: 过滤条件,可选
  57. Returns:
  58. Dict: 搜索结果
  59. """
  60. try:
  61. path = f"/api/collections/{index_name}/search"
  62. search_query = {
  63. "vector": {
  64. "field": vector_field,
  65. "query": vector,
  66. "limit": size
  67. }
  68. }
  69. if filter:
  70. search_query["filter"] = filter
  71. response = self.infinity_conn._make_request("POST", path, search_query)
  72. if "error" not in response:
  73. return {
  74. "hits": {
  75. "total": response.get("total", 0),
  76. "hits": [{
  77. "_source": doc
  78. } for doc in response.get("documents", [])]
  79. }
  80. }
  81. return {"hits": {"total": 0, "hits": []}}
  82. except Exception as e:
  83. print(f"Infinity向量检索失败: {e}")
  84. return {"hits": {"total": 0, "hits": []}}
  85. def hybrid_search(self, index_name: str, text_query: str, vector_field: str, vector: List[float],
  86. size: int = 10, text_weight: float = 0.5, vector_weight: float = 0.5) -> Dict[str, Any]:
  87. """
  88. 混合检索:文本检索 + 向量检索
  89. Args:
  90. index_name: 索引名称
  91. text_query: 文本查询
  92. vector_field: 向量字段名
  93. vector: 检索向量
  94. size: 返回结果数量
  95. text_weight: 文本检索权重
  96. vector_weight: 向量检索权重
  97. Returns:
  98. Dict: 搜索结果
  99. """
  100. try:
  101. path = f"/api/collections/{index_name}/search"
  102. search_query = {
  103. "hybrid": {
  104. "text": {
  105. "query": text_query,
  106. "fields": ["text"],
  107. "weight": text_weight
  108. },
  109. "vector": {
  110. "field": vector_field,
  111. "query": vector,
  112. "weight": vector_weight
  113. },
  114. "limit": size
  115. }
  116. }
  117. response = self.infinity_conn._make_request("POST", path, search_query)
  118. if "error" not in response:
  119. return {
  120. "hits": {
  121. "total": response.get("total", 0),
  122. "hits": [{
  123. "_source": doc
  124. } for doc in response.get("documents", [])]
  125. }
  126. }
  127. return {"hits": {"total": 0, "hits": []}}
  128. except Exception as e:
  129. print(f"Infinity混合检索失败: {e}")
  130. return {"hits": {"total": 0, "hits": []}}
  131. def match_search(self, index_name: str, field: str, value: str, size: int = 10) -> Dict[str, Any]:
  132. """
  133. 简单匹配搜索
  134. Args:
  135. index_name: 索引名称
  136. field: 字段名
  137. value: 匹配值
  138. size: 返回结果数量
  139. Returns:
  140. Dict: 搜索结果
  141. """
  142. query = {
  143. "match": {
  144. field: value
  145. }
  146. }
  147. return self.search(index_name, query, size=size)
  148. def match_all(self, index_name: str, size: int = 10) -> Dict[str, Any]:
  149. """
  150. 匹配所有文档
  151. Args:
  152. index_name: 索引名称
  153. size: 返回结果数量
  154. Returns:
  155. Dict: 搜索结果
  156. """
  157. query = {
  158. "match_all": {}
  159. }
  160. return self.search(index_name, query, size=size)