|
|
@@ -5,8 +5,9 @@ from src.api.db.services.vector_search_service import VectorSearchService
|
|
|
from src.utils.vector_db import get_vector_db_client
|
|
|
from src.common.result import Result
|
|
|
from src.utils.async_utils import run_in_threadpool
|
|
|
+from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service
|
|
|
from src.api.db.models import SearchRequest
|
|
|
-from src.conf.settings import vector_db_settings
|
|
|
+from src.conf.settings import ragflow_settings
|
|
|
|
|
|
# 创建FastAPI应用
|
|
|
|
|
|
@@ -84,10 +85,15 @@ async def question_search(request: SearchRequest):
|
|
|
output_fields = ["content_with_weight"]
|
|
|
search_service = VectorSearchService(
|
|
|
client=get_vector_db_client(database="ragflow_db"),
|
|
|
- table_name="ragflow_f3abf26bf80c11f0953d0242ac180002",
|
|
|
vector_field="q_1024_vec",
|
|
|
output_fields=output_fields
|
|
|
)
|
|
|
+ ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(2)
|
|
|
+ ragflow_id = ragflow_user.get("ragflow_id")
|
|
|
+ if not ragflow_id:
|
|
|
+ return Result.error(code=500, message="未找到ragflow_id")
|
|
|
+ question_table_name = f"{ragflow_settings.ragflow_dataset_prefix}_{ragflow_id}"
|
|
|
+ request.database_ids = [question_table_name]
|
|
|
result = await run_in_threadpool(search_service.hybrid_search, request)
|
|
|
return Result.success(data=result, message="问答对检索成功")
|
|
|
except Exception as e:
|