|
@@ -11,6 +11,7 @@ from .base import VectorDBClient
|
|
|
from .result_util import UnifiedSearchResult, build_unified_result
|
|
from .result_util import UnifiedSearchResult, build_unified_result
|
|
|
from src.conf.settings import es_settings, vector_db_settings
|
|
from src.conf.settings import es_settings, vector_db_settings
|
|
|
from src.common.logging_config import get_logger
|
|
from src.common.logging_config import get_logger
|
|
|
|
|
+from src.utils.chinese_text_processor import chinese_processor
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
@@ -214,14 +215,15 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
"query_vector": query_vector,
|
|
"query_vector": query_vector,
|
|
|
"k": topn,
|
|
"k": topn,
|
|
|
"num_candidates": topn * 10,
|
|
"num_candidates": topn * 10,
|
|
|
|
|
+ "similarity": threshold,
|
|
|
"filter": kb_id_filter # kb_id 精准匹配(支持单个或数组)
|
|
"filter": kb_id_filter # kb_id 精准匹配(支持单个或数组)
|
|
|
},
|
|
},
|
|
|
"_source": output_fields
|
|
"_source": output_fields
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- # 如果有阈值,添加 min_score
|
|
|
|
|
- if threshold > 0:
|
|
|
|
|
- search_params["min_score"] = threshold
|
|
|
|
|
|
|
+ # # 如果有阈值,添加 min_score
|
|
|
|
|
+ # if threshold > 0:
|
|
|
|
|
+ # search_params["min_score"] = threshold
|
|
|
|
|
|
|
|
result = self._client.search(**search_params)
|
|
result = self._client.search(**search_params)
|
|
|
return self._convert_result(result, output_fields)
|
|
return self._convert_result(result, output_fields)
|
|
@@ -252,17 +254,23 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
# 构建 kb_id 过滤条件(支持单个值或数组)
|
|
# 构建 kb_id 过滤条件(支持单个值或数组)
|
|
|
kb_id_filter = self._build_kb_id_filter(kb_id)
|
|
kb_id_filter = self._build_kb_id_filter(kb_id)
|
|
|
|
|
|
|
|
- # index_name = "ragbook_1_f3c87b89f82711f0b7450242ac180002,ragbook_2_16db7236f82911f093e20242ac180002"
|
|
|
|
|
index_name = table_name
|
|
index_name = table_name
|
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
|
query_vector = query.get("query_vector", [])
|
|
query_vector = query.get("query_vector", [])
|
|
|
# 确保 query_vector 是 Python 原生列表,兼容 numpy 数组等类型
|
|
# 确保 query_vector 是 Python 原生列表,兼容 numpy 数组等类型
|
|
|
- if query_vector is not None and len(query_vector) > 0:
|
|
|
|
|
- # 强制转换为 Python 原生 float 列表
|
|
|
|
|
- query_vector = [float(x) for x in query_vector]
|
|
|
|
|
- logger.info(f"混合搜索: 已转换 query_vector 类型, 长度={len(query_vector)}, 前3个元素={query_vector[:3]}")
|
|
|
|
|
- else:
|
|
|
|
|
- query_vector = []
|
|
|
|
|
|
|
+ # if query_vector is not None and len(query_vector) > 0:
|
|
|
|
|
+ # # 强制转换为 Python 原生 float 列表
|
|
|
|
|
+ # query_vector = [float(x) for x in query_vector]
|
|
|
|
|
+ # logger.info(f"混合搜索: 已转换 query_vector 类型, 长度={len(query_vector)}, 前3个元素={query_vector[:3]}")
|
|
|
|
|
+ # else:
|
|
|
|
|
+ # query_vector = []
|
|
|
|
|
+
|
|
|
|
|
+ # 将query_vector的值写入D:\project\work\graph_rag_server\logs\app.log
|
|
|
|
|
+ logger.info(f"混合搜索: query_vector={query_vector}")
|
|
|
|
|
+ # 获取阈值
|
|
|
|
|
+ knn_params = query.get("knn_params", {})
|
|
|
|
|
+ threshold = float(knn_params.get("threshold", 0.0))
|
|
|
|
|
+
|
|
|
match_field = query.get("match_field", "content")
|
|
match_field = query.get("match_field", "content")
|
|
|
matching_text = query.get("matching_text")
|
|
matching_text = query.get("matching_text")
|
|
|
if matching_text is not None and not isinstance(matching_text, str):
|
|
if matching_text is not None and not isinstance(matching_text, str):
|
|
@@ -275,24 +283,21 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
|
|
|
|
|
# match 的 query 不能为 None,否则 ES 报 VALUE_NULL。无文案时仅用向量(match_all)
|
|
# match 的 query 不能为 None,否则 ES 报 VALUE_NULL。无文案时仅用向量(match_all)
|
|
|
if matching_text:
|
|
if matching_text:
|
|
|
- text_clause = {
|
|
|
|
|
- "match": {
|
|
|
|
|
- match_field: {"query": matching_text, "boost": 1.0}
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ # 处理中文文本
|
|
|
|
|
+ text_clause = self._build_chinese_query(match_field, matching_text)
|
|
|
else:
|
|
else:
|
|
|
- text_clause = {"match_all": {"boost": 1.0}}
|
|
|
|
|
-
|
|
|
|
|
|
|
+ text_clause = {"match_all": {"boost": 2.0}}
|
|
|
# 构建混合查询(ES 8.x 使用关键字参数而非 body)
|
|
# 构建混合查询(ES 8.x 使用关键字参数而非 body)
|
|
|
# 在 query 的 bool.must 中添加 kb_id 精准匹配,在 knn 的 filter 中也添加
|
|
# 在 query 的 bool.must 中添加 kb_id 精准匹配,在 knn 的 filter 中也添加
|
|
|
|
|
+ # 打印所有查询条件
|
|
|
result = self._client.search(
|
|
result = self._client.search(
|
|
|
index=index_name,
|
|
index=index_name,
|
|
|
query={
|
|
query={
|
|
|
"bool": {
|
|
"bool": {
|
|
|
- "must": [
|
|
|
|
|
|
|
+ "filter": [
|
|
|
kb_id_filter, # kb_id 精准匹配(支持单个或数组)
|
|
kb_id_filter, # kb_id 精准匹配(支持单个或数组)
|
|
|
],
|
|
],
|
|
|
- "should": [text_clause]
|
|
|
|
|
|
|
+ "must": [text_clause]
|
|
|
}
|
|
}
|
|
|
},
|
|
},
|
|
|
knn={
|
|
knn={
|
|
@@ -300,7 +305,8 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
"query_vector": query_vector,
|
|
"query_vector": query_vector,
|
|
|
"k": topn,
|
|
"k": topn,
|
|
|
"num_candidates": topn * 10,
|
|
"num_candidates": topn * 10,
|
|
|
- "boost": 1.0,
|
|
|
|
|
|
|
+ "similarity": threshold,
|
|
|
|
|
+ "boost": 8.0,
|
|
|
"filter": kb_id_filter # kb_id 精准匹配(支持单个或数组)
|
|
"filter": kb_id_filter # kb_id 精准匹配(支持单个或数组)
|
|
|
},
|
|
},
|
|
|
size=topn,
|
|
size=topn,
|
|
@@ -328,6 +334,82 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
return result
|
|
return result
|
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _build_chinese_query(self, field: str, text: str) -> Dict[str, Any]:
    """Build an Elasticsearch ``bool`` query for Chinese text on one field.

    The text is passed through ``chinese_processor.process_text``; the
    ``words`` and ``word_weights`` entries of its result drive the
    per-word clauses.

    Args:
        field: Name of the field to query.
        text: Raw query text.

    Returns:
        A ``bool`` query whose ``should`` list contains, in order:
        a ``term`` clause and a ``match`` clause for every word (boosted
        by 2.0x and 1.0x the word's weight, defaulting to a weight of 1.0),
        a ``match_phrase`` clause over the space-joined words when there
        is more than one word (boost 3.0), and a ``match`` clause on the
        raw text (boost 1.5). ``minimum_should_match`` is set to 1 so at
        least one clause has to match.
    """
    analysis = chinese_processor.process_text(text)
    logger.info(f"Chinese text processed: {analysis}")

    tokens = analysis.get('words', [])
    weights = analysis.get('word_weights', {})

    clauses = []
    for token in tokens:
        token_weight = weights.get(token, 1.0)
        # Exact (term) hits are boosted twice as much as analyzed (match) hits.
        clauses.append(
            {"term": {field: {"value": token, "boost": 2.0 * token_weight}}}
        )
        clauses.append(
            {"match": {field: {"query": token, "boost": 1.0 * token_weight}}}
        )

    # A phrase clause rewards documents that keep the words in order.
    if len(tokens) > 1:
        joined = " ".join(tokens)
        clauses.append(
            {"match_phrase": {field: {"query": joined, "boost": 3.0}}}
        )

    # Always fall back to matching the unprocessed text.
    clauses.append({"match": {field: {"query": text, "boost": 1.5}}})

    return {"bool": {"should": clauses, "minimum_should_match": 1}}
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
def _parse_cond_for_term(cond: str) -> Optional[Tuple[str, str]]:
|
|
def _parse_cond_for_term(cond: str) -> Optional[Tuple[str, str]]:
|
|
|
"""
|
|
"""
|