search_infinity.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Infinity搜索API服务
  2. from fastapi import FastAPI, HTTPException
  3. from src.api.db.services.vector_search_service import VectorSearchService
  4. from src.utils.vector_db import get_vector_db_client
  5. from src.common.result import Result
  6. from src.utils.async_utils import run_in_threadpool
  7. from src.api.db.models import SearchRequest
  8. from src.conf.settings import vector_db_settings
  9. # 创建FastAPI应用
  10. app = FastAPI(
  11. title="Vector Search API",
  12. description="向量数据库搜索API服务(支持 Infinity / Elasticsearch)",
  13. version="2.0.0"
  14. )
  15. # 1. 普通搜索接口
  16. @app.post("/text")
  17. async def search(request: SearchRequest):
  18. """
  19. 普通搜索接口
  20. - **table_name**: 表名
  21. - **output_fields**: 要返回的字段列表
  22. - **query**: 查询条件,包含field、query和topn字段
  23. - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
  24. """
  25. try:
  26. search_service = VectorSearchService(client=get_vector_db_client())
  27. result = await run_in_threadpool(search_service.search, request)
  28. return Result.success(data=result, message="搜索成功")
  29. except Exception as e:
  30. return Result.error(code=500, message=f"搜索失败: {str(e)}")
  31. # 2. 向量搜索接口
  32. @app.post("/vector")
  33. async def vector_search(request: SearchRequest):
  34. """
  35. 向量搜索接口
  36. - **table_name**: 表名
  37. - **output_fields**: 要返回的字段列表
  38. - **query**: 查询条件,包含vector_field、query_vector和topn字段
  39. - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
  40. """
  41. try:
  42. search_service = VectorSearchService(client=get_vector_db_client())
  43. result = await run_in_threadpool(search_service.vector_search, request)
  44. return Result.success(data=result, message="向量搜索成功")
  45. except Exception as e:
  46. return Result.error(code=500, message=f"向量搜索失败: {str(e)}")
  47. # 3. 混合搜索接口
  48. @app.post("/hybrid")
  49. async def hybrid_search(request: SearchRequest):
  50. """
  51. 混合搜索接口
  52. - **table_name**: 表名
  53. - **output_fields**: 要返回的字段列表
  54. - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
  55. - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
  56. """
  57. try:
  58. search_service = VectorSearchService(client=get_vector_db_client())
  59. result = await run_in_threadpool(search_service.hybrid_search, request)
  60. return Result.success(data=result, message="混合搜索成功")
  61. except Exception as e:
  62. return Result.error(code=500, message=f"混合搜索失败: {str(e)}")
  63. # 4. 问答对检索
  64. @app.post("/question")
  65. async def question_search(request: SearchRequest):
  66. """
  67. 问答对检索接口
  68. - **output_fields**: 要返回的字段列表
  69. - **query**: 查询条件,包含vector_field、query_vector、field、query、topn和fusion_weight字段
  70. - **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
  71. """
  72. try:
  73. output_fields = ["content_with_weight"]
  74. search_service = VectorSearchService(
  75. client=get_vector_db_client(database="ragflow_db"),
  76. table_name="ragflow_f3abf26bf80c11f0953d0242ac180002",
  77. vector_field="q_1024_vec",
  78. output_fields=output_fields
  79. )
  80. result = await run_in_threadpool(search_service.hybrid_search, request)
  81. return Result.success(data=result, message="问答对检索成功")
  82. except Exception as e:
  83. return Result.error(code=500, message=f"问答对检索失败: {str(e)}")