# test_search.py - manual smoke test: Infinity vector search followed by Jina reranking.

  1. import time
  2. from langchain.chat_models import init_chat_model
  3. from model.multimodal_embedding import Embedding
  4. from conf.settings import model_settings, ragflow_settings, tag_search_settings
  5. from utils.infinity import get_client
  6. from model.jina_rerank import JinaRerank
  7. from langchain_core.documents import Document
  8. from utils.infinity.result_util import convert_to_langchain_docs
  9. # 初始化多模态嵌入模型
  10. embedding_model = Embedding(
  11. model_name=model_settings.embedding_model_name,
  12. api_key=model_settings.api_key
  13. )
  14. llm = init_chat_model(
  15. model_provider=model_settings.model_provider,
  16. model=model_settings.model_name,
  17. base_url=model_settings.base_url,
  18. api_key=model_settings.api_key,
  19. temperature=0.3
  20. )
  21. compressor = JinaRerank(
  22. base_url=model_settings.base_url,
  23. jina_api_key=model_settings.api_key,
  24. model=model_settings.rank_model_name, # 建议先确认硅基后台此模型 ID 是否正确
  25. top_n=5
  26. )
  27. output_fields = [
  28. "important_keywords",
  29. "content",
  30. "kb_id",
  31. "doc_id"
  32. ]
  33. search_query = {
  34. "matching_text": """
  35. 莉莉兰中都有哪些小虫虫
  36. """,
  37. "query_vector": [],
  38. "vector_field": "q_1024_vec",
  39. "match_field": "content",
  40. "topn": 5
  41. }
  42. def main():
  43. infinity_client = get_client(database="ragflow_db")
  44. # 问题向量化
  45. # 记录开始时间
  46. start_time = time.time()
  47. embedding = embedding_model.get_text_embedding(search_query["matching_text"])
  48. embedding_time = time.time() - start_time
  49. print(f"向量化耗时: {embedding_time:.4f} 秒")
  50. search_query["query_vector"] = embedding
  51. # TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_52275b36f03611f0a5340242c0a85002"
  52. TAG_TABLE_NAME="ragflow_92162247e93e11f084830242ac1d0002_6d2e0990f28b11f0b5200242c0a85002"
  53. results = infinity_client.vector_search(TAG_TABLE_NAME, output_fields, search_query).to_result()
  54. candidate_docs = convert_to_langchain_docs(results)
  55. # print(candidate_docs)
  56. # # 4. 直接调用重排序逻辑
  57. reranked_docs = compressor.compress_documents(
  58. documents=candidate_docs,
  59. query=search_query["matching_text"],
  60. top_n=3
  61. )
  62. # # print(reranked_docs)
  63. # # 5. 查看结果
  64. for i, doc in enumerate(reranked_docs):
  65. print(f"排名 {i+1}: 分数 {doc.metadata['relevance_score']:.4f}")
  66. print(f"内容: {doc.page_content}")
  67. print(f"标签: {doc.metadata['important_keywords']}\n")
  68. # 记录结束时间并计算执行时间
  69. end_time = time.time()
  70. execution_time = end_time - start_time
  71. print(f"执行时间: {execution_time:.4f} 秒")
# Disabled scratch code for AsymmetricEncryption key loading / decryption, kept for reference.
# from utils.asymmetric_encryption import AsymmetricEncryption
# def main2():
#     # passwd = "zhangqi@lelequ.net"
#     # loaded_public_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\public_key.pem")
#     loaded_private_pem = AsymmetricEncryption.load_key_from_file(r"D:\project\work\graph_rag_server\private_key.pem")
#     # encrypted = AsymmetricEncryption.encrypt(passwd, loaded_public_pem)
#     # print(encrypted)
#     decrypted = AsymmetricEncryption.decrypt(ragflow_settings.ragflow_user_name, loaded_private_pem)
#     print(decrypted)
#     # assert decrypted2 == test_message, "使用加载的密钥解密失败!"
#     # print("✓ 使用加载的密钥加密解密测试通过!")
  83. if __name__ == "__main__":
  84. main()