|
@@ -4,9 +4,11 @@ Elasticsearch 向量数据库适配器
|
|
|
实现 VectorDBClient 接口,提供与 Infinity 兼容的 ES 操作。
|
|
实现 VectorDBClient 接口,提供与 Infinity 兼容的 ES 操作。
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
|
|
+import re
|
|
|
import threading
|
|
import threading
|
|
|
-from typing import Dict, Any, List, Optional
|
|
|
|
|
|
|
+from typing import Dict, Any, List, Optional, Tuple
|
|
|
from .base import VectorDBClient
|
|
from .base import VectorDBClient
|
|
|
|
|
+from .result_util import UnifiedSearchResult, build_unified_result
|
|
|
from src.conf.settings import es_settings, vector_db_settings
|
|
from src.conf.settings import es_settings, vector_db_settings
|
|
|
from src.common.logging_config import get_logger
|
|
from src.common.logging_config import get_logger
|
|
|
|
|
|
|
@@ -94,14 +96,14 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
|
|
|
|
|
将 Infinity 风格的查询转换为 ES 查询。
|
|
将 Infinity 风格的查询转换为 ES 查询。
|
|
|
"""
|
|
"""
|
|
|
- index_name = self._get_index_name(table_name, database_name)
|
|
|
|
|
|
|
+ index_name = table_name
|
|
|
|
|
+ match_field = query.get("match_field", "content")
|
|
|
|
|
+ matching_text = query.get("matching_text") or ""
|
|
|
|
|
|
|
|
- # 构建 ES 查询
|
|
|
|
|
|
|
+ # match 的 query 不能为 None,否则 ES 报 VALUE_NULL
|
|
|
es_query = {
|
|
es_query = {
|
|
|
"query": {
|
|
"query": {
|
|
|
- "match": {
|
|
|
|
|
- query.get("match_field", "content"): query.get("matching_text", "")
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "match": {match_field: matching_text}
|
|
|
},
|
|
},
|
|
|
"size": query.get("topn", 10),
|
|
"size": query.get("topn", 10),
|
|
|
"_source": output_fields
|
|
"_source": output_fields
|
|
@@ -122,7 +124,7 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
|
|
|
|
|
使用 ES 的 knn 查询进行向量检索。
|
|
使用 ES 的 knn 查询进行向量检索。
|
|
|
"""
|
|
"""
|
|
|
- index_name = self._get_index_name(table_name, database_name)
|
|
|
|
|
|
|
+ index_name = table_name
|
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
|
query_vector = query.get("query_vector", [])
|
|
query_vector = query.get("query_vector", [])
|
|
|
topn = query.get("topn", 10)
|
|
topn = query.get("topn", 10)
|
|
@@ -161,27 +163,34 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
|
|
|
|
|
使用 ES 的 bool 查询结合 knn 和 match。
|
|
使用 ES 的 bool 查询结合 knn 和 match。
|
|
|
"""
|
|
"""
|
|
|
- index_name = self._get_index_name(table_name, database_name)
|
|
|
|
|
|
|
+ index_name = table_name
|
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
vector_field = query.get("vector_field", "dense_vector")
|
|
|
query_vector = query.get("query_vector", [])
|
|
query_vector = query.get("query_vector", [])
|
|
|
match_field = query.get("match_field", "content")
|
|
match_field = query.get("match_field", "content")
|
|
|
- matching_text = query.get("matching_text", "")
|
|
|
|
|
|
|
+ matching_text = query.get("matching_text")
|
|
|
|
|
+ if matching_text is not None and not isinstance(matching_text, str):
|
|
|
|
|
+ matching_text = str(matching_text)
|
|
|
|
|
+ elif matching_text is None:
|
|
|
|
|
+ matching_text = ""
|
|
|
|
|
+ else:
|
|
|
|
|
+ matching_text = matching_text.strip()
|
|
|
topn = query.get("topn", 10)
|
|
topn = query.get("topn", 10)
|
|
|
|
|
|
|
|
|
|
+ # match 的 query 不能为 None,否则 ES 报 VALUE_NULL。无文案时仅用向量(match_all)
|
|
|
|
|
+ if matching_text:
|
|
|
|
|
+ text_clause = {
|
|
|
|
|
+ "match": {
|
|
|
|
|
+ match_field: {"query": matching_text, "boost": 1.0}
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ else:
|
|
|
|
|
+ text_clause = {"match_all": {"boost": 1.0}}
|
|
|
|
|
+
|
|
|
# 构建混合查询
|
|
# 构建混合查询
|
|
|
es_query = {
|
|
es_query = {
|
|
|
"query": {
|
|
"query": {
|
|
|
"bool": {
|
|
"bool": {
|
|
|
- "should": [
|
|
|
|
|
- {
|
|
|
|
|
- "match": {
|
|
|
|
|
- match_field: {
|
|
|
|
|
- "query": matching_text,
|
|
|
|
|
- "boost": 1.0
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- ]
|
|
|
|
|
|
|
+ "should": [text_clause]
|
|
|
}
|
|
}
|
|
|
},
|
|
},
|
|
|
"knn": {
|
|
"knn": {
|
|
@@ -205,7 +214,7 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
database_name: Optional[str] = None
|
|
database_name: Optional[str] = None
|
|
|
) -> Any:
|
|
) -> Any:
|
|
|
"""插入文档"""
|
|
"""插入文档"""
|
|
|
- index_name = self._get_index_name(table_name, database_name)
|
|
|
|
|
|
|
+ index_name = table_name
|
|
|
|
|
|
|
|
# 批量插入
|
|
# 批量插入
|
|
|
operations = []
|
|
operations = []
|
|
@@ -218,6 +227,20 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
return result
|
|
return result
|
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
@staticmethod
def _parse_cond_for_term(cond: str) -> Optional[Tuple[str, str]]:
    """Parse a simple equality condition into a ``(field, value)`` pair.

    Only the two forms ``field = 'value'`` and ``field = "value"`` are
    recognised, where ``field`` consists of word characters. The closing
    quote must be the same character as the opening quote, so the value
    may contain the *other* quote character (``title = "it's"``), while a
    mismatched pair (``title = 'abc"``) is rejected.

    Args:
        cond: Condition string to parse. Surrounding whitespace is ignored.

    Returns:
        ``(field, value)`` when ``cond`` is a single quoted equality,
        otherwise ``None`` (empty input, non-string input, compound or
        otherwise unsupported conditions).
    """
    if not cond or not isinstance(cond, str):
        return None

    # Group 2 captures the opening quote; the value (group 3) may contain any
    # character except that same quote, and the string must end right after
    # the matching closing quote. This keeps compound conditions such as
    # "a = 'b' OR c = 'd'" from being mistaken for a single term.
    matched = re.fullmatch(
        r"(\w+)\s*=\s*(['\"])((?:(?!\2).)*)\2",
        cond.strip(),
        flags=re.DOTALL,
    )
    if matched is None:
        return None
    return matched.group(1), matched.group(3)
|
|
|
|
|
+
|
|
|
def update(
|
|
def update(
|
|
|
self,
|
|
self,
|
|
|
table_name: str,
|
|
table_name: str,
|
|
@@ -228,25 +251,66 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
"""
|
|
"""
|
|
|
更新文档
|
|
更新文档
|
|
|
|
|
|
|
|
- 注意:ES 的更新方式与 Infinity 不同,这里使用 update_by_query。
|
|
|
|
|
|
|
+ 使用 update_by_query,且 conflicts="proceed" 避免版本冲突导致整次操作失败。
|
|
|
|
|
+ 若 cond 为 "field = 'value'" 形式,则用 term 查询精确匹配;否则回退到 query_string。
|
|
|
|
|
+ 若 cond 为 "id = 'value'" 或 "_id = 'value'" 形式,则使用更高效的 update API。
|
|
|
"""
|
|
"""
|
|
|
- index_name = self._get_index_name(table_name, database_name)
|
|
|
|
|
|
|
+ index_name = table_name
|
|
|
|
|
|
|
|
- # 将条件字符串解析为 ES 查询
|
|
|
|
|
- # 简单实现:假设 cond 是 "field = 'value'" 格式
|
|
|
|
|
|
|
+ parsed = self._parse_cond_for_term(cond)
|
|
|
|
|
+ if parsed:
|
|
|
|
|
+ field, value = parsed
|
|
|
|
|
+ # 检查是否是id条件
|
|
|
|
|
+ if field in ['id', '_id']:
|
|
|
|
|
+ # 使用update API直接更新指定id的文档
|
|
|
|
|
+ try:
|
|
|
|
|
+ result = self._client.update(
|
|
|
|
|
+ index=index_name,
|
|
|
|
|
+ id=value,
|
|
|
|
|
+ body={"doc": data}
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(f"update 执行: index={index_name}, id={value}, data={data}")
|
|
|
|
|
+ return result
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"使用update API更新失败,回退到update_by_query: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ # 普通字段条件,使用term查询
|
|
|
|
|
+ es_query_clause = {"term": {field: value}}
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 复杂条件,使用query_string
|
|
|
|
|
+ es_query_clause = {"query_string": {"query": cond or "*"}}
|
|
|
|
|
+
|
|
|
|
|
+ # 构建update_by_query请求
|
|
|
es_query = {
|
|
es_query = {
|
|
|
- "query": {
|
|
|
|
|
- "query_string": {
|
|
|
|
|
- "query": cond
|
|
|
|
|
- }
|
|
|
|
|
- },
|
|
|
|
|
|
|
+ "query": es_query_clause,
|
|
|
"script": {
|
|
"script": {
|
|
|
"source": "; ".join([f"ctx._source.{k} = params.{k}" for k in data.keys()]),
|
|
"source": "; ".join([f"ctx._source.{k} = params.{k}" for k in data.keys()]),
|
|
|
"params": data
|
|
"params": data
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- result = self._client.update_by_query(index=index_name, body=es_query)
|
|
|
|
|
|
|
+ result = self._client.update_by_query(
|
|
|
|
|
+ index=index_name,
|
|
|
|
|
+ body=es_query,
|
|
|
|
|
+ conflicts="proceed"
|
|
|
|
|
+ )
|
|
|
|
|
+ # 打印实际执行的查询
|
|
|
|
|
+ logger.info(f"update_by_query 执行: index={index_name}, cond={cond}, data={data}")
|
|
|
|
|
+
|
|
|
|
|
+ vc = result.get("version_conflicts", 0)
|
|
|
|
|
+ failures = result.get("failures", [])
|
|
|
|
|
+ if vc or failures:
|
|
|
|
|
+ logger.warning(
|
|
|
|
|
+ "update_by_query 存在版本冲突或失败: index=%s, version_conflicts=%s, "
|
|
|
|
|
+ "updated=%s, total=%s, failures_count=%s",
|
|
|
|
|
+ index_name, vc, result.get("updated", 0), result.get("total", 0), len(failures)
|
|
|
|
|
+ )
|
|
|
|
|
+ if failures:
|
|
|
|
|
+ for f in failures[:5]:
|
|
|
|
|
+ logger.warning("update_by_query failure: %s", f)
|
|
|
|
|
+ if len(failures) > 5:
|
|
|
|
|
+ logger.warning("... 及其他 %s 个 failures", len(failures) - 5)
|
|
|
|
|
+
|
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
def get_status(self) -> Dict[str, Any]:
|
|
def get_status(self) -> Dict[str, Any]:
|
|
@@ -274,29 +338,18 @@ class ElasticsearchAdapter(VectorDBClient):
|
|
|
|
|
|
|
|
def _convert_result(self, es_result: Dict, output_fields: List[str]) -> Any:
    """Convert a raw Elasticsearch search response into the unified result type.

    Each entry of ``es_result["hits"]["hits"]`` becomes one row: a copy of
    the hit's ``_source`` with the hit's ``_score`` added under ``"_score"``.
    The rows, ``output_fields`` and a score getter are handed to
    ``build_unified_result`` and the outcome is wrapped in
    ``UnifiedSearchResult``.

    Args:
        es_result: Response dict returned by the ES search call. A missing
            ``hits`` section is treated as "no hits".
        output_fields: Field names forwarded to ``build_unified_result``
            (presumably the fields kept in each output row; confirm against
            ``result_util``).

    Returns:
        A ``UnifiedSearchResult`` wrapping the converted rows.
    """
    hits = es_result.get("hits", {}).get("hits", [])

    rows = []
    for hit in hits:
        # Copy the source so the caller's response dict is never mutated;
        # "or {}" also covers a hit whose _source is present but null.
        row = dict(hit.get("_source") or {})
        # A hit can carry a null _score (for example when results are
        # sorted on a field); normalise it to 0 so the score is always
        # numeric.
        row["_score"] = hit.get("_score") or 0
        rows.append(row)

    results = build_unified_result(
        rows, output_fields, lambda r: r.get("_score", 0)
    )
    return UnifiedSearchResult(results)
|
|
|
|
|
|
|
|
# ========== ES 特有方法(扩展) ==========
|
|
# ========== ES 特有方法(扩展) ==========
|
|
|
|
|
|