本文档详细解析 es_conn.py 文件的设计思路与具体实现。该文件实现了一个基于 Elasticsearch 的数据库连接器,提供完整的 CRUD 操作支持,并具备混合搜索(文本+向量)、批量处理、重试机制等高级功能。
文件开头的 Apache License 2.0 声明表明这是一个开源项目,遵循 Apache 开源协议。该协议允许自由使用、修改和分发代码,但需要保留原始版权声明和许可证文本。
import re # 正则表达式处理
import json # JSON 序列化
import time # 时间控制
import copy # 深拷贝操作
from elasticsearch_dsl import UpdateByQuery, Q, Search # ES DSL 查询构建器
from elastic_transport import ConnectionTimeout # 连接超时异常
from common.decorator import singleton # 单例装饰器
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr # 查询表达式基类
from common.doc_store.es_conn_base import ESConnectionBase # 基础连接类
from common.float_utils import get_float # 浮点数处理工具
from common.constants import PAGERANK_FLD, TAG_FLD # 常量定义
设计思路:
elasticsearch_dsl 库构建类型安全的查询 DSL(领域特定语言);singleton 装饰器确保全局只有一个连接实例,避免连接池浪费;ConnectionTimeout 用于优雅处理超时情况;ESConnectionBase 遵循开放/封闭原则,便于扩展。

ATTEMPT_TIME = 2
设计思路:定义重试次数为 2 次,在可靠性和性能之间取得平衡。网络操作可能因临时故障失败,轻量级重试机制可提高成功率。
@singleton
class ESConnection(ESConnectionBase):
设计思路:
继承 ESConnectionBase 抽象基类,遵循依赖倒置原则;@singleton 装饰器确保单例模式,整个应用生命周期内只有一个实例。

def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
参数设计思路:
select_fields 和 highlight_fields 分离了字段选择和高亮显示的关注点;condition 使用字典存储查询条件,提供了灵活的条件组合方式;match_expressions 使用表达式模式,支持多种匹配类型的组合;knowledgebase_ids 支持多知识库搜索场景。

if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
设计思路:
禁止将 _id 作为查询条件,防止潜在的注入风险。

bool_query = Q("bool", must=[])
condition["kb_id"] = knowledgebase_ids
for k, v in condition.items():
if k == "available_int":
if v == 0:
bool_query.filter.append(Q("range", available_int={"lt": 1}))
else:
bool_query.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(...)
设计思路:
针对 available_int 字段实现可用性过滤逻辑;根据值类型选择 terms(列表)或 term(单值)查询方式。

s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
设计思路:
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
设计思路:
query_string 查询支持复杂的查询语法;best_fields 类型在多字段搜索中返回最佳匹配字段的分数。

elif isinstance(m, MatchDenseExpr):
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bool_query.to_dict(),
similarity=similarity,
)
设计思路:
topn * 2 的设计是为了在应用其他过滤条件后仍能返回足够的候选结果。

if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
设计思路:
rank_feature 查询提升特定特征的权重;非 PageRank 字段加上 tag. 前缀,符合数据模型的命名规范;should 查询不会排除结果,只会调整相关性分数。

for field in highlight_fields:
s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
设计思路:
指定 unmapped_type,避免排序时因字段不存在而失败。

if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
设计思路:
terms 聚合统计字段值的分布;size=1000000 设置较大的聚合桶数量,确保结果完整性。

if limit > 0:
s = s[offset:offset + limit]
设计思路:
q = s.to_dict()
for i in range(ATTEMPT_TIME):
try:
res = self.es.search(index=index_names,
body=q,
timeout="600s",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
设计思路:
track_total_hits=True 确保返回准确的命中总数。

def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
设计思路:
documents 参数接收字典列表,支持批量插入;knowledgebase_id 参数将知识库 ID 注入每条文档,实现数据隔离。

operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
d_copy["kb_id"] = knowledgebase_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
设计思路:
禁止使用 _id 作为字段名,防止与 ES 内部 ID 冲突;要求文档存在 id 字段,作为 ES 文档 ID;将 id 字段重命名为 ES 的 _id,并添加 kb_id 字段。

res = []
for _ in range(ATTEMPT_TIME):
try:
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
设计思路:
refresh=False 提高写入性能,不要求立即刷新到磁盘。

def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
设计思路:
doc = copy.deepcopy(new_value)
doc.pop("id", None)
if "id" in condition and isinstance(condition["id"], str):
chunk_id = condition["id"]
for i in range(ATTEMPT_TIME):
for k in doc.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(...)
try:
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
self.logger.exception(...)
break
return False
设计思路:
先移除以 feas 结尾的字段,使用脚本执行删除操作。

bool_query = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(...)
设计思路:
exists 查询检查字段是否存在。

scripts = []
params = {}
for k, v in new_value.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(...)
设计思路:
remove 操作支持单个字段删除和数组元素删除;add 操作支持向数组添加元素。

ubq = UpdateByQuery(
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
设计思路:
UpdateByQuery 用于批量更新符合条件的文档;refresh=True 确保更新后立即对搜索可见;slices=5 启用并行处理,提高大数据量更新的性能;conflicts="proceed" 允许在文档被其他操作修改时继续执行。

def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
设计思路:
assert "_id" not in condition
condition["kb_id"] = knowledgebase_id
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
chunk_ids = [chunk_ids]
if not chunk_ids:
qry = Q("match_all")
else:
qry = Q("ids", values=chunk_ids)
else:
qry = Q("bool")
for k, v in condition.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
设计思路:
id 参数支持单个 ID 或 ID 列表,灵活处理删除范围;match_all 删除全部文档;ids 查询高效删除指定文档;支持 exists 和 must_not 高级条件。

self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
设计思路:
delete_by_query API 批量删除文档;refresh=True 确保删除操作立即生效;deleted 字段表示实际删除的文档数量。

def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self._get_source(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
m[n] = v
continue
if n == "available_int" and isinstance(v, (int, float)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(v)
# if n.find("tks") > 0:
# m[n] = remove_redundant_spaces(m[n])
if m:
res_fields[d["id"]] = m
return res_fields
设计思路:
_get_source 方法提取搜索结果的源文档(由基类提供)。

使用 @singleton 装饰器确保全局唯一的连接实例,避免资源浪费。
通过 match_expressions 参数支持多种匹配策略(文本、向量、融合)的组合。
继承自 ESConnectionBase,遵循基类定义的接口规范。
重试机制在捕获 ConnectionTimeout 后重新建立连接并再次尝试,实现优雅降级。
使用 Elasticsearch DSL 构建复杂的查询对象。
insert 方法使用 Bulk API 批量插入,减少网络往返开销。
UpdateByQuery 的 slices=5 参数启用多线程并行处理。
refresh=False 避免每次写入后立即刷新,提高写入性能。
600 秒超时适应复杂查询,避免无限期等待。
空值条件自动跳过,减少不必要的查询开销。
检测到 ConnectionTimeout 异常后重连并重试。
收集批量操作中每个文档的错误信息。
使用断言确保关键参数的有效性。
所有异常和关键操作都有日志记录,便于问题排查。
MatchExpr 系列类支持扩展新的匹配类型。
继承 ESConnectionBase 可实现其他数据库的连接器。
超时、重试次数等参数可配置,适应不同环境需求。