|
|
@@ -29,6 +29,12 @@ def get_ragflow_search_service():
|
|
|
logger.info("Search service initialized")
|
|
|
return _search_service
|
|
|
|
|
|
+def get_multi_search_service():
|
|
|
+ global _multi_search_service
|
|
|
+ _multi_search_service = VectorSearchService(client=get_vector_db_client())
|
|
|
+ logger.info("Multi search service initialized")
|
|
|
+ return _multi_search_service
|
|
|
+
|
|
|
|
|
|
def _get_question_table_name():
|
|
|
"""获取问答对表名"""
|
|
|
@@ -64,7 +70,7 @@ async def search(request: SearchRequest):
|
|
|
- **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
|
|
|
"""
|
|
|
try:
|
|
|
- search_service = VectorSearchService(client=get_vector_db_client())
|
|
|
+ search_service = get_multi_search_service()
|
|
|
result = await run_in_threadpool(search_service.search, request)
|
|
|
return Result.success(data=result, message="搜索成功")
|
|
|
except Exception as e:
|
|
|
@@ -82,7 +88,7 @@ async def vector_search(request: SearchRequest):
|
|
|
- **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
|
|
|
"""
|
|
|
try:
|
|
|
- search_service = VectorSearchService(client=get_vector_db_client())
|
|
|
+ search_service = get_multi_search_service()
|
|
|
result = await run_in_threadpool(search_service.vector_search, request)
|
|
|
return Result.success(data=result, message="向量搜索成功")
|
|
|
except Exception as e:
|
|
|
@@ -100,7 +106,7 @@ async def hybrid_search(request: SearchRequest):
|
|
|
- **database_name**: 数据库名称(可选,默认使用客户端配置的数据库)
|
|
|
"""
|
|
|
try:
|
|
|
- search_service = VectorSearchService(client=get_vector_db_client())
|
|
|
+ search_service = get_multi_search_service()
|
|
|
result = await run_in_threadpool(search_service.hybrid_search, request)
|
|
|
return Result.success(data=result, message="混合搜索成功")
|
|
|
except Exception as e:
|
|
|
@@ -125,7 +131,7 @@ async def question_search(request: SearchRequest):
|
|
|
database_ids_list = [db.strip() for db in request.database_ids_str.split(',')]
|
|
|
request.database_ids = database_ids_list
|
|
|
|
|
|
- if not request.kb_ids:
|
|
|
+ if not request.kb_id:
|
|
|
kb_ids_list = [kb.strip() for kb in request.kb_ids_str.split(',')]
|
|
|
request.kb_id = kb_ids_list
|
|
|
|