dify_knowledge_service.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import re
  2. from src.conf.settings import vector_db_settings
  3. from src.utils.infinity import InfinityClient
  4. from src.utils.file.image_util import image_util
  5. from src.model.multimodal_embedding import get_embedding_model
  6. from src.utils.infinity.result_util import convert_to_langchain_docs
  7. from src.api.dataset.models.dify_models import RetrievalRequest
  8. from src.conf.settings import vector_db_settings
  9. class DifyKnowledgeService:
  10. def __init__(self, infinity_client: InfinityClient, vector_field: str = None, match_field: str = None,
  11. match_type: str = None, table_name: str = None):
  12. self.infinity_client = infinity_client
  13. # 输出字段
  14. self.output_fields = [
  15. "file_name",
  16. "page_number",
  17. "content",
  18. "image_path",
  19. "dataset_id",
  20. "document_id",
  21. "_similarity"
  22. ]
  23. self.vector_field = vector_field or "dense_vector_1024"
  24. self.match_field = match_field or "content"
  25. self.match_type = match_type or "cosine"
  26. self.table_name = table_name or vector_db_settings.infinity_table_name
  27. def dify_database_search(self, request: RetrievalRequest):
  28. """
  29. 执行Dify数据库搜索
  30. Args:
  31. retrievalRequest: 搜索查询参数
  32. Returns:
  33. 搜索结果,转换为基本类型以便序列化
  34. """
  35. try:
  36. if request.knowledge_id:
  37. # table_name = f"{vector_db_settings.infinity_dataset_prefix}{request.knowledge_id}"
  38. table_name = self.table_name
  39. else:
  40. # 抛出异常
  41. raise Exception("knowledge_id不能为空")
  42. # 解析格式如: matching_text:点点,match_image:http://xxx 或 matching_text:点点,match_image:http://xxx
  43. # 支持中文和英文的逗号、冒号
  44. input_image = None
  45. input_text = None
  46. query_str = request.query
  47. # 将中文逗号替换为英文逗号,用于分割
  48. query_str_normalized = re.sub(r'[,]', ',', query_str)
  49. # 按逗号分割为多个键值对
  50. pairs = query_str_normalized.split(',')
  51. for pair in pairs:
  52. # 将中文冒号替换为英文冒号,用于分割键值
  53. pair_normalized = re.sub(r'[:]', ':', pair, count=1)
  54. if ':' in pair_normalized:
  55. # 只分割第一个冒号,避免URL中的冒号被分割
  56. key, value = pair_normalized.split(':', 1)
  57. key = key.strip()
  58. value = value.strip()
  59. if key == 'match_image':
  60. input_image = value
  61. elif key == 'matching_text':
  62. input_text = value
  63. # 如果没有解析出任何参数,将整个query作为input_text
  64. if input_image is None and input_text is None:
  65. input_text = request.query
  66. retrieval_setting = request.retrieval_setting
  67. # 1.处理image_url为image: Image.Image
  68. image = image_util._url_to_image(input_image)
  69. # 多模态向量
  70. query_vector = get_embedding_model().get_multimodal_embedding(text=input_text, image=image)
  71. # 构建搜索查询
  72. search_query = {
  73. "vector_field": self.vector_field,
  74. "query_vector": query_vector,
  75. "topn": retrieval_setting.top_k,
  76. "knn_params": {
  77. "ef": str(retrieval_setting.top_k * 10),
  78. "threshold": str(retrieval_setting.score_threshold)
  79. }
  80. }
  81. # 执行搜索
  82. result = self.infinity_client.vector_search(table_name, self.output_fields, search_query)
  83. # 将结果转换为基本类型,处理可能的复杂类型
  84. result_dict = result.to_result()
  85. # 递归转换所有复杂类型为基本类型
  86. return convert_to_langchain_docs(result_dict)
  87. except Exception as e:
  88. raise Exception(f"搜索失败: {str(e)}")