infinity_search_service.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from typing import Dict, Any, List
  2. from conf.config import VectorDBConfig
  3. from utils.infinity import InfinityClient
  4. from utils.file.image_util import image_util
  5. from model.multimodal_embedding import get_embedding_model
  6. def convert_to_basic_types(obj: Any) -> Any:
  7. """
  8. 递归将对象转换为基本类型,以便Pydantic能够序列化
  9. 特殊处理:当字典中的子项包含相同长度的数组时,将其转换为数组对象结构
  10. 例如:{"a": [1,2], "b": [3,4]} -> [{"a":1, "b":3}, {"a":2, "b":4}]
  11. Args:
  12. obj: 要转换的对象
  13. Returns:
  14. 转换后的基本类型对象
  15. """
  16. if obj is None:
  17. return None
  18. elif isinstance(obj, (str, int, float, bool)):
  19. return obj
  20. elif isinstance(obj, dict):
  21. # 先递归转换所有值
  22. converted = {k: convert_to_basic_types(v) for k, v in obj.items()}
  23. # 检查是否需要转换为数组对象结构
  24. # 条件:所有值都是列表,且长度一致,且长度大于0
  25. values = list(converted.values())
  26. if all(isinstance(v, list) for v in values):
  27. lengths = [len(v) for v in values]
  28. if len(set(lengths)) == 1 and lengths[0] > 0:
  29. # 转换为数组对象结构
  30. result = []
  31. keys = list(converted.keys())
  32. for i in range(lengths[0]):
  33. item = {}
  34. for key in keys:
  35. # 处理数组中可能存在的None值
  36. if i < len(converted[key]):
  37. item[key] = converted[key][i]
  38. else:
  39. item[key] = None
  40. result.append(item)
  41. return result
  42. return converted
  43. elif isinstance(obj, (list, tuple)):
  44. return [convert_to_basic_types(item) for item in obj]
  45. else:
  46. # 对于其他类型,尝试将其转换为字符串或字典
  47. try:
  48. return dict(obj)
  49. except:
  50. return str(obj)
  51. class InfinitySearchService:
  52. def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None, match_type: str = None, table_name: str = None):
  53. self.infinity_client = infinity_client
  54. # 输出字段
  55. self.output_fields = [
  56. "file_name",
  57. "page_number",
  58. "content",
  59. "image_path",
  60. "dataset_id",
  61. "document_id"
  62. ]
  63. self.vector_field = vector_field or "dense_vector_1024"
  64. self.match_field = match_field or "content"
  65. self.match_type = match_type or "cosine"
  66. self.table_name = table_name or VectorDBConfig.get_infinity_table_name()
  67. def search(self, search_query: Dict[str, Any]) -> Dict[str, Any]:
  68. """
  69. 执行Infinity数据库搜索
  70. Args:
  71. search_query: 搜索查询参数
  72. Returns:
  73. 搜索结果,转换为基本类型以便序列化
  74. """
  75. try:
  76. # 执行搜索
  77. result = self.infinity_client.search(self.table_name, self.output_fields, search_query)
  78. # 将结果转换为基本类型,处理可能的复杂类型
  79. result_dict = result.to_result()
  80. # 递归转换所有复杂类型为基本类型
  81. return convert_to_basic_types(result_dict)
  82. except Exception as e:
  83. raise Exception(f"搜索失败: {str(e)}")
  84. def vector_search(self, search_query: Dict[str, Any]):
  85. """
  86. 执行Infinity数据库向量检索
  87. Args:
  88. search_query: 向量检索查询参数
  89. Returns:
  90. 向量检索结果,转换为基本类型以便序列化
  91. """
  92. try:
  93. # 1.处理image_url为image: Image.Image
  94. image = image_util._url_to_image(search_query["image_url"])
  95. # 2.将图片进行向量化
  96. query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
  97. search_query["vector_field"] = self.vector_field
  98. search_query["query_vector"] = query_vector
  99. # 执行向量检索
  100. result = self.infinity_client.vector_search(self.table_name, self.output_fields, search_query)
  101. # 将结果转换为基本类型,处理可能的复杂类型
  102. result_dict = result.to_result()
  103. # 递归转换所有复杂类型为基本类型
  104. return convert_to_basic_types(result_dict)
  105. except Exception as e:
  106. raise Exception(f"向量检索失败: {str(e)}")
  107. def hybrid_search(self, search_query: Dict[str, Any]):
  108. """
  109. 执行Infinity数据库混合检索
  110. Args:
  111. search_query: 混合检索查询参数
  112. Returns:
  113. 混合检索结果,转换为基本类型以便序列化
  114. """
  115. try:
  116. # 1.处理image_url为image: Image.Image
  117. image = image_util._url_to_image(search_query["image_url"])
  118. # 2.将图片进行向量化
  119. query_vector = get_embedding_model().get_multimodal_embedding(search_query["matching_text"], image)
  120. search_query["vector_field"] = self.vector_field
  121. search_query["query_vector"] = query_vector
  122. search_query["match_field"] = self.match_field
  123. # 执行混合检索
  124. result = self.infinity_client.hybrid_search(self.table_name, self.output_fields, search_query)
  125. # 将结果转换为基本类型,处理可能的复杂类型
  126. result_dict = result.to_result()
  127. # 递归转换所有复杂类型为基本类型
  128. return convert_to_basic_types(result_dict)
  129. except Exception as e:
  130. raise Exception(f"混合检索失败: {str(e)}")