search.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. """
  2. Elasticsearch 搜索管理
  3. """
import logging
from typing import Any, Dict, List, Optional

from services.utils.es.base import ESConnection
  6. class SearchManager:
  7. """
  8. Elasticsearch 搜索管理器
  9. 负责:
  10. - 全文检索
  11. - 向量相似度检索(k-NN)
  12. - 混合检索(文本+向量)
  13. - 高亮显示
  14. """
  15. def __init__(self, es_connection: Optional[ESConnection] = None):
  16. """
  17. 初始化搜索管理器
  18. Args:
  19. es_connection: ES 连接实例,可选
  20. """
  21. self.es_conn = es_connection or ESConnection()
  22. self.es = self.es_conn.get_client()
  23. def search(self, index_name: str, query: Dict[str, Any], size: int = 10, from_: int = 0,
  24. fields: List[str] = None, highlight: Dict[str, Any] = None) -> Dict[str, Any]:
  25. """
  26. 搜索文档
  27. Args:
  28. index_name: 索引名称
  29. query: 查询条件
  30. size: 返回结果数量
  31. from_: 起始位置
  32. fields: 要返回的字段列表,可选
  33. highlight: 高亮配置,可选
  34. Returns:
  35. Dict: 搜索结果
  36. """
  37. try:
  38. body = {
  39. "query": query,
  40. "size": size,
  41. "from": from_
  42. }
  43. if fields:
  44. body["_source"] = fields
  45. if highlight:
  46. body["highlight"] = highlight
  47. result = self.es.search(index=index_name, body=body)
  48. return result
  49. except Exception as e:
  50. print(f"搜索失败: {e}")
  51. return {"hits": {"total": 0, "hits": []}}
  52. def hybrid_search(self, index_name: str, text_query: str, vector_field: str, vector: List[float],
  53. size: int = 10, from_: int = 0, fields: List[str] = None,
  54. text_weight: float = 0.5, vector_weight: float = 0.5) -> Dict[str, Any]:
  55. """
  56. 混合检索:向量相似度检索 + 全文检索
  57. Args:
  58. index_name: 索引名称
  59. text_query: 文本查询语句
  60. vector_field: 向量字段名
  61. vector: 检索向量
  62. size: 返回结果数量
  63. from_: 起始位置
  64. fields: 要返回的字段列表,可选
  65. text_weight: 文本检索权重
  66. vector_weight: 向量检索权重
  67. Returns:
  68. Dict: 搜索结果
  69. """
  70. try:
  71. # 构建混合检索查询
  72. query = {
  73. "bool": {
  74. "should": [
  75. {
  76. "query_string": {
  77. "query": text_query,
  78. "default_operator": "OR",
  79. "boost": text_weight
  80. }
  81. },
  82. {
  83. "script_score": {
  84. "query": {
  85. "match_all": {}
  86. },
  87. "script": {
  88. "source": "cosineSimilarity(params.query_vector, doc[params.vector_field]) + 1.0",
  89. "params": {
  90. "query_vector": vector,
  91. "vector_field": vector_field
  92. }
  93. },
  94. "boost": vector_weight
  95. }
  96. }
  97. ]
  98. }
  99. }
  100. body = {
  101. "query": query,
  102. "size": size,
  103. "from": from_
  104. }
  105. if fields:
  106. body["_source"] = fields
  107. result = self.es.search(index=index_name, body=body)
  108. return result
  109. except Exception as e:
  110. print(f"混合检索失败: {e}")
  111. return {"hits": {"total": 0, "hits": []}}
  112. def knn_search(self, index_name: str, vector_field: str, vector: List[float],
  113. k: int = 10, filter_query: Dict[str, Any] = None) -> Dict[str, Any]:
  114. """
  115. 向量相似度检索(k-NN)
  116. Args:
  117. index_name: 索引名称
  118. vector_field: 向量字段名
  119. vector: 检索向量
  120. k: 返回结果数量
  121. filter_query: 过滤条件,可选
  122. Returns:
  123. Dict: 搜索结果
  124. """
  125. try:
  126. knn = {
  127. "field": vector_field,
  128. "query_vector": vector,
  129. "k": k,
  130. "num_candidates": k * 10
  131. }
  132. if filter_query:
  133. knn["filter"] = filter_query
  134. body = {
  135. "knn": knn
  136. }
  137. result = self.es.search(index=index_name, body=body)
  138. return result
  139. except Exception as e:
  140. print(f"向量检索失败: {e}")
  141. return {"hits": {"total": 0, "hits": []}}
  142. def match_search(self, index_name: str, field: str, value: str, size: int = 10,
  143. fields: List[str] = None) -> Dict[str, Any]:
  144. """
  145. 简单匹配搜索
  146. Args:
  147. index_name: 索引名称
  148. field: 字段名
  149. value: 匹配值
  150. size: 返回结果数量
  151. fields: 要返回的字段列表,可选
  152. Returns:
  153. Dict: 搜索结果
  154. """
  155. query = {
  156. "match": {
  157. field: value
  158. }
  159. }
  160. return self.search(index_name, query, size=size, fields=fields)
  161. def match_all(self, index_name: str, size: int = 10, fields: List[str] = None) -> Dict[str, Any]:
  162. """
  163. 匹配所有文档
  164. Args:
  165. index_name: 索引名称
  166. size: 返回结果数量
  167. fields: 要返回的字段列表,可选
  168. Returns:
  169. Dict: 搜索结果
  170. """
  171. query = {
  172. "match_all": {}
  173. }
  174. return self.search(index_name, query, size=size, fields=fields)