| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- # Infinity数据库客户端实现
- from infinity.common import ConflictType
- from typing import Dict, Any, List, Optional
- import threading
- from contextlib import contextmanager
- from conf.settings import vector_db_settings
- from .pool import InfinityConnectionPool
class InfinityClient:
    """
    Infinity database client backed by a connection pool.

    Each public method borrows one connection from ``InfinityConnectionPool``
    for the duration of the call. Provided operations:
    - database operations
    - table and index operations
    - document operations (insert / update)
    - search operations (full-text, dense vector, hybrid)

    NOTE(review): ``search``, ``vector_search`` and ``hybrid_search`` return
    the SDK query chain without executing it. They return only after the
    ``with`` block that borrowed the connection has exited. In the Infinity
    SDK that chain is a lazy builder, run by a later call such as
    ``.to_pl()`` / ``.to_result()``. Execution would therefore happen on a
    connection that has already been handed back to the pool. The return
    type is kept for backward compatibility. Confirm the pool tolerates
    this, or execute the query inside ``get_connection()`` instead.
    """

    def __init__(
        self,
        host: str = vector_db_settings.infinity_host,
        port: str = vector_db_settings.infinity_sdk_port,
        database: str = vector_db_settings.infinity_database,
        min_connections: int = 2,
        max_connections: int = 10
    ):
        """
        Create the client and its connection pool.

        The default values are read from ``vector_db_settings`` once, when
        this class is defined (i.e. at module import).

        Args:
            host: Infinity server host.
            port: Infinity SDK port (passed to the pool unchanged).
            database: Default database, used when a method is called without
                ``database_name``.
            min_connections: Minimum number of pooled connections.
            max_connections: Maximum number of pooled connections.
        """
        self.pool = InfinityConnectionPool(
            host=host,
            port=port,
            database=database,
            min_connections=min_connections,
            max_connections=max_connections
        )

    @staticmethod
    def _normalize_conflict_type(conflict_type: Any) -> Any:
        """
        Return the ``ConflictType`` member for a member or its name.

        The Infinity SDK expects ``ConflictType`` members. A raw name such as
        ``"Error"`` (the former ``drop_database`` default) would not match
        them. Names are matched case-insensitively ("error" -> Error).
        Unknown values are returned unchanged so the SDK can report them
        itself.
        """
        if isinstance(conflict_type, str) and not conflict_type.startswith("_"):
            return getattr(ConflictType, conflict_type.capitalize(), conflict_type)
        return conflict_type

    def _get_database(self, conn: Any, database_name: Optional[str] = None):
        """
        Return the Database object for ``database_name``.

        Args:
            conn: Infinity connection object.
            database_name: Database name. Any falsy value (None or "") falls
                back to the client's current default database.

        Returns:
            Database object.
        """
        db_name = database_name or self.pool.database
        return conn.get_database(db_name)

    def _get_table(self, conn: Any, table_name: str, database_name: Optional[str] = None):
        """
        Return the Table object for ``table_name``.

        Args:
            conn: Infinity connection object.
            table_name: Table name.
            database_name: Database name (defaults to the client's database).

        Returns:
            Table object.
        """
        db = self._get_database(conn, database_name)
        return db.get_table(table_name)

    @contextmanager
    def get_connection(self, timeout: Optional[int] = None):
        """
        Borrow a raw connection from the pool as a context manager.

        Args:
            timeout: Seconds to wait for a connection (forwarded to the pool).

        Yields:
            An Infinity connection object. Its release is governed by the
            pool's own context manager when the ``with`` block exits.
        """
        with self.pool.get_connection(timeout) as conn:
            yield conn

    def list_databases(self) -> List[str]:
        """List all databases on the server."""
        # Borrow through get_connection(), like every other method here.
        # This replaces the former call to pool.list_connection().
        with self.pool.get_connection() as conn:
            return conn.list_databases()

    def get_databases(self) -> List[str]:
        """List all databases (alias of ``list_databases``, kept for backward compatibility)."""
        return self.list_databases()

    def create_database(self, database_name: str, conflict_type: Any = ConflictType.Error, comment: Optional[str] = None):
        """
        Create a database.

        Args:
            database_name: Name of the database to create.
            conflict_type: ``ConflictType`` member or its name (e.g. "Ignore").
            comment: Optional database comment.
        """
        with self.pool.get_connection() as conn:
            return conn.create_database(
                database_name, self._normalize_conflict_type(conflict_type), comment
            )

    def drop_database(self, database_name: str, conflict_type: Any = ConflictType.Error):
        """
        Drop a database.

        Args:
            database_name: Name of the database to drop.
            conflict_type: ``ConflictType`` member or its name. The old string
                default "Error" is still accepted and mapped to
                ``ConflictType.Error``.
        """
        with self.pool.get_connection() as conn:
            return conn.drop_database(
                database_name, self._normalize_conflict_type(conflict_type)
            )

    def use_database(self, database_name: str):
        """
        Change this client's default database.

        This sets ``self.pool.database``, which ``_get_database`` reads on
        every call. All later calls on this client that omit
        ``database_name`` therefore target the new database, whichever
        pooled connection serves them. Whether the pool also uses this
        attribute when opening new connections depends on
        ``InfinityConnectionPool``. On a shared instance (e.g. the one
        returned by ``get_client()``), the switch affects every user of that
        instance.
        """
        self.pool.database = database_name

    def list_tables(self, database_name: Optional[str] = None) -> List[str]:
        """List all tables in ``database_name`` (defaults to the client's database)."""
        with self.pool.get_connection() as conn:
            # Tables are listed through the Database object, not the connection.
            db = self._get_database(conn, database_name)
            return db.list_tables()

    def get_tables(self, database_name: Optional[str] = None) -> List[str]:
        """List all tables (alias of ``list_tables``, kept for backward compatibility)."""
        return self.list_tables(database_name)

    def get_table(self, table_name: str, database_name: Optional[str] = None):
        """
        Return the Table object for ``table_name``.

        NOTE(review): the returned Table is bound to a pooled connection whose
        ``with`` block has already exited when this method returns.
        """
        with self.pool.get_connection() as conn:
            return self._get_table(conn, table_name, database_name)

    def create_table(
        self,
        table_name: str,
        columns_definition: List[Dict[str, Any]],
        conflict_type: Any = ConflictType.Error,
        database_name: Optional[str] = None
    ):
        """
        Create a table.

        Args:
            table_name: Name of the table.
            columns_definition: Column definitions, passed to the SDK unchanged.
            conflict_type: ``ConflictType`` member or its name.
            database_name: Target database (defaults to the client's database).
        """
        with self.pool.get_connection() as conn:
            db = self._get_database(conn, database_name)
            return db.create_table(
                table_name, columns_definition, self._normalize_conflict_type(conflict_type)
            )

    def drop_table(self, table_name: str, database_name: Optional[str] = None, conflict_type: Any = ConflictType.Error):
        """
        Drop a table.

        Args:
            table_name: Name of the table.
            database_name: Target database (defaults to the client's database).
            conflict_type: ``ConflictType`` member or its name.
        """
        with self.pool.get_connection() as conn:
            db = self._get_database(conn, database_name)
            return db.drop_table(table_name, self._normalize_conflict_type(conflict_type))

    def create_index(self, table_name: str, index_name: str, index_info: Dict[str, Any], database_name: Optional[str] = None, conflict_type: Any = ConflictType.Error, comment: Optional[str] = None):
        """
        Create an index on a table.

        Args:
            table_name: Table to index.
            index_name: Name of the new index.
            index_info: Index description, passed to the SDK unchanged.
            database_name: Target database (defaults to the client's database).
            conflict_type: ``ConflictType`` member or its name.
            comment: Optional index comment.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.create_index(
                index_name, index_info, self._normalize_conflict_type(conflict_type), comment
            )

    def optimize(self, table_name: str, database_name: Optional[str] = None):
        """Call the SDK's ``optimize()`` on the given table."""
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.optimize()

    def drop_index(self, table_name: str, index_name: str, database_name: Optional[str] = None, conflict_type: Any = ConflictType.Error):
        """
        Drop an index from a table.

        Args:
            table_name: Table owning the index.
            index_name: Index to drop.
            database_name: Target database (defaults to the client's database).
            conflict_type: ``ConflictType`` member or its name.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.drop_index(index_name, self._normalize_conflict_type(conflict_type))

    def insert(
        self,
        table_name: str,
        documents: List[Dict[str, Any]],
        database_name: Optional[str] = None
    ):
        """Insert ``documents`` (row dicts) into ``table_name``."""
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.insert(documents)

    def update(
        self,
        table_name: str,
        cond: str,
        data: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Update rows of ``table_name``.

        Args:
            cond: Filter expression selecting the rows to update.
            data: Column -> new value mapping.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.update(cond, data)

    def search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Full-text search.

        Args:
            output_fields: Columns to return.
            query: Must contain "match_field", "matching_text" and "topn".

        Returns:
            The unexecuted query chain (see the class-level NOTE about
            connection lifetime).
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.output(output_fields).match_text(
                query["match_field"], query["matching_text"], query["topn"]
            )

    def hybrid_search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Hybrid search: a dense-vector match and a full-text match, fused.

        Args:
            output_fields: Columns to return.
            query: Must contain "vector_field", "query_vector", "match_field",
                "matching_text" and "topn". Optional keys, whose defaults
                reproduce the previous hard-coded values: "element_type"
                ("float"), "distance_type" ("cosine") and "fusion_method"
                ("rrf").

        Returns:
            The unexecuted query chain ending in ``.fusion(...)`` (see the
            class-level NOTE about connection lifetime).
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            topn = query["topn"]
            return (
                table.output(output_fields)
                .match_dense(
                    query["vector_field"],
                    query["query_vector"],
                    query.get("element_type", "float"),
                    query.get("distance_type", "cosine"),
                    topn,
                )
                .match_text(query["match_field"], query["matching_text"], topn)
                .fusion(query.get("fusion_method", "rrf"), topn)
            )

    def vector_search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Dense-vector search.

        Args:
            output_fields: Columns to return.
            query: Must contain "vector_field", "query_vector" and "topn".
                Optional keys: "element_type" (default "float") and
                "distance_type" (default "cosine").

        Returns:
            The unexecuted query chain (see the class-level NOTE about
            connection lifetime).
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.output(output_fields).match_dense(
                query["vector_field"],
                query["query_vector"],
                query.get("element_type", "float"),
                query.get("distance_type", "cosine"),
                query["topn"],
            )

    def get_status(self) -> Dict[str, Any]:
        """Return the pool's status report."""
        return self.pool.get_status()

    def close(self):
        """Close the client by closing its connection pool."""
        self.pool.close()
# Process-wide singleton state shared by get_client() and close_client().
# The lock serializes creation and teardown of the shared client.
_client_lock = threading.Lock()
_global_client: Optional[InfinityClient] = None
def get_client(
    host: str = vector_db_settings.infinity_host,
    port: str = vector_db_settings.infinity_sdk_port,
    database: str = vector_db_settings.infinity_database,
    min_connections: int = 2,
    max_connections: int = 10
) -> InfinityClient:
    """
    Return the process-wide InfinityClient, creating it on first use.

    The arguments are only used when no client exists yet. Once the shared
    client has been created, later calls return it as-is and ignore their
    arguments. Call close_client() first to rebuild it with new settings.
    The default values are read from ``vector_db_settings`` once, when this
    function is defined.

    Args:
        host: Infinity server host.
        port: Infinity SDK port.
        database: Default database for the client.
        min_connections: Minimum number of pooled connections.
        max_connections: Maximum number of pooled connections.

    Returns:
        The shared InfinityClient instance (never None).
    """
    global _global_client

    with _client_lock:
        if _global_client is None:
            _global_client = InfinityClient(
                host=host,
                port=port,
                database=database,
                min_connections=min_connections,
                max_connections=max_connections
            )
        # Capture the instance while still holding the lock. Reading the
        # global after releasing it could observe a concurrent
        # close_client() reset and return None.
        client = _global_client

    return client
def close_client():
    """
    Close and discard the process-wide client created by get_client().

    Calling this when no client exists does nothing. The global reference is
    cleared before closing, so if ``close()`` raises, a later get_client()
    builds a fresh client instead of returning a half-closed one. The
    exception still propagates to the caller.
    """
    global _global_client

    with _client_lock:
        client, _global_client = _global_client, None
        if client is not None:
            client.close()
|