# Infinity database client implementation
import threading
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

from infinity.common import ConflictType

from conf.settings import vector_db_settings
from .pool import InfinityConnectionPool


class InfinityClient:
    """
    Infinity database client backed by a connection pool.

    Provides thin wrappers around the common operations:
    - database operations
    - table operations
    - document operations
    - search operations

    Every wrapper borrows a connection from ``self.pool`` for the duration of
    a single call and hands it back when the ``with`` block exits.
    """

    def __init__(
        self,
        host: str = vector_db_settings.infinity_host,
        port: str = vector_db_settings.infinity_sdk_port,
        database: str = vector_db_settings.infinity_database,
        min_connections: int = 2,
        max_connections: int = 10
    ):
        """
        Initialise the client and its connection pool.

        Args:
            host: Infinity service address.
            port: Infinity service port.
            database: Default database name.
            min_connections: Minimum number of pooled connections.
            max_connections: Maximum number of pooled connections.
        """
        self.pool = InfinityConnectionPool(
            host=host,
            port=port,
            database=database,
            min_connections=min_connections,
            max_connections=max_connections
        )

    def _get_database(self, conn: Any, database_name: Optional[str] = None):
        """
        Resolve a database object on the given connection.

        Args:
            conn: Infinity connection object.
            database_name: Database name; when omitted (or empty) the pool's
                current ``database`` attribute is used.

        Returns:
            The object returned by ``conn.get_database``.
        """
        db_name = database_name or self.pool.database
        return conn.get_database(db_name)

    def _get_table(self, conn: Any, table_name: str, database_name: Optional[str] = None):
        """
        Resolve a table object on the given connection.

        Args:
            conn: Infinity connection object.
            table_name: Table name.
            database_name: Database name; when omitted the pool's current
                ``database`` attribute is used.

        Returns:
            The object returned by ``get_table`` on the resolved database.
        """
        db = self._get_database(conn, database_name)
        return db.get_table(table_name)

    @contextmanager
    def get_connection(self, timeout: Optional[int] = None):
        """
        Context manager yielding a pooled connection.

        Args:
            timeout: Timeout (seconds) for acquiring a connection; forwarded
                unchanged to the pool.
        """
        with self.pool.get_connection(timeout) as conn:
            yield conn

    def list_databases(self) -> List[str]:
        """List all databases."""
        # Use the same pool accessor as every other method; the previous
        # `list_connection` name is not used anywhere else in this client.
        with self.pool.get_connection() as conn:
            return conn.list_databases()

    def get_databases(self) -> List[str]:
        """List all databases (alias of ``list_databases``, kept for backward compatibility)."""
        return self.list_databases()

    def create_database(
        self,
        database_name: str,
        conflict_type: ConflictType = ConflictType.Error,
        comment: Optional[str] = None
    ):
        """
        Create a database.

        Args:
            database_name: Name of the database to create.
            conflict_type: Conflict policy forwarded to the SDK.
            comment: Optional comment forwarded to the SDK.

        Returns:
            The result of ``conn.create_database``.
        """
        with self.pool.get_connection() as conn:
            return conn.create_database(database_name, conflict_type, comment)

    def drop_database(self, database_name: str, conflict_type: ConflictType = ConflictType.Error):
        """
        Drop a database.

        Args:
            database_name: Name of the database to drop.
            conflict_type: Conflict policy forwarded to the SDK. Defaults to
                ``ConflictType.Error``, matching the other methods of this
                client (previously the plain string ``"Error"``).

        Returns:
            The result of ``conn.drop_database``.
        """
        with self.pool.get_connection() as conn:
            return conn.drop_database(database_name, conflict_type)

    def use_database(self, database_name: str):
        """
        Switch the client's default database.

        Only ``self.pool.database`` is updated. ``_get_database`` reads that
        attribute on every call, so subsequent operations of this client that
        do not pass ``database_name`` will target the new database.

        NOTE(review): whether the pool also uses this attribute when opening
        new connections cannot be seen from this file - confirm in the pool.
        """
        self.pool.database = database_name

    def list_tables(self, database_name: Optional[str] = None) -> List[str]:
        """
        List all tables of a database.

        Args:
            database_name: Database name; defaults to the client's database.
        """
        with self.pool.get_connection() as conn:
            # Tables are listed on the Database object, not on the connection.
            db = self._get_database(conn, database_name)
            return db.list_tables()

    def get_tables(self, database_name: Optional[str] = None) -> List[str]:
        """List all tables (alias of ``list_tables``, kept for backward compatibility)."""
        return self.list_tables(database_name)

    def get_table(self, table_name: str, database_name: Optional[str] = None):
        """
        Fetch a table object.

        NOTE(review): the table object is returned after the connection
        context has exited - confirm it stays usable once the connection is
        back in the pool.

        Args:
            table_name: Table name.
            database_name: Database name; defaults to the client's database.
        """
        with self.pool.get_connection() as conn:
            return self._get_table(conn, table_name, database_name)

    def create_table(
        self,
        table_name: str,
        columns_definition: List[Dict[str, Any]],
        conflict_type: ConflictType = ConflictType.Error,
        database_name: Optional[str] = None
    ):
        """
        Create a table.

        Args:
            table_name: Table name.
            columns_definition: Column definitions forwarded to the SDK.
            conflict_type: Conflict policy forwarded to the SDK.
            database_name: Database name; defaults to the client's database.

        Returns:
            The result of ``create_table`` on the resolved database.
        """
        with self.pool.get_connection() as conn:
            # Tables are created on the Database object, not on the connection.
            db = self._get_database(conn, database_name)
            return db.create_table(table_name, columns_definition, conflict_type)

    def drop_table(
        self,
        table_name: str,
        database_name: Optional[str] = None,
        conflict_type=ConflictType.Error
    ):
        """
        Drop a table.

        Args:
            table_name: Table name.
            database_name: Database name; defaults to the client's database.
            conflict_type: Conflict policy forwarded to the SDK.

        Returns:
            The result of ``drop_table`` on the resolved database.
        """
        with self.pool.get_connection() as conn:
            # Tables are dropped on the Database object, not on the connection.
            db = self._get_database(conn, database_name)
            return db.drop_table(table_name, conflict_type)

    def create_index(
        self,
        table_name: str,
        index_name: str,
        index_info: Dict[str, Any],
        database_name: Optional[str] = None,
        conflict_type=ConflictType.Error,
        comment=None
    ):
        """
        Create an index on a table.

        Args:
            table_name: Table name.
            index_name: Index name.
            index_info: Index description forwarded to the SDK.
            database_name: Database name; defaults to the client's database.
            conflict_type: Conflict policy forwarded to the SDK.
            comment: Optional comment forwarded to the SDK.

        Returns:
            The result of ``table.create_index``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.create_index(index_name, index_info, conflict_type, comment)

    def optimize(self, table_name: str, database_name: Optional[str] = None):
        """
        Optimize a table.

        Args:
            table_name: Table name.
            database_name: Database name; defaults to the client's database.

        Returns:
            The result of ``table.optimize``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.optimize()

    def drop_index(
        self,
        table_name: str,
        index_name: str,
        database_name: Optional[str] = None,
        conflict_type=ConflictType.Error
    ):
        """
        Drop an index from a table.

        Args:
            table_name: Table name.
            index_name: Index name.
            database_name: Database name; defaults to the client's database.
            conflict_type: Conflict policy forwarded to the SDK.

        Returns:
            The result of ``table.drop_index``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.drop_index(index_name, conflict_type)

    def insert(
        self,
        table_name: str,
        documents: List[Dict[str, Any]],
        database_name: Optional[str] = None
    ):
        """
        Insert documents into a table.

        Args:
            table_name: Table name.
            documents: Rows to insert, forwarded unchanged to ``table.insert``.
            database_name: Database name; defaults to the client's database.

        Returns:
            The result of ``table.insert``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.insert(documents)

    def update(
        self,
        table_name: str,
        cond: str,
        data: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Update documents matching a condition.

        Args:
            table_name: Table name.
            cond: Filter condition forwarded to ``table.update``.
            data: New column values forwarded to ``table.update``.
            database_name: Database name; defaults to the client's database.

        Returns:
            The result of ``table.update``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.update(cond, data)

    def search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Build a full-text search on a table.

        NOTE(review): no result-fetching call is made here and the connection
        context exits on return; presumably the caller materialises the
        result - confirm the returned object stays usable.

        Args:
            table_name: Table name.
            output_fields: Columns to output.
            query: Must contain the keys ``"match_field"``,
                ``"matching_text"`` and ``"topn"``.
            database_name: Database name; defaults to the client's database.

        Returns:
            The object returned by ``table.output(...).match_text(...)``.

        Raises:
            KeyError: If a required key is missing from ``query``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return table.output(output_fields).match_text(
                query["match_field"], query["matching_text"], query["topn"]
            )

    def hybrid_search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Build a hybrid (dense vector + full-text) search fused with "rrf".

        The dense match uses element type "float" and distance "cosine".

        NOTE(review): no result-fetching call is made here and the connection
        context exits on return; presumably the caller materialises the
        result - confirm the returned object stays usable.

        Args:
            table_name: Table name.
            output_fields: Columns to output.
            query: Must contain the keys ``"vector_field"``,
                ``"query_vector"``, ``"match_field"``, ``"matching_text"``
                and ``"topn"``.
            database_name: Database name; defaults to the client's database.

        Returns:
            The object returned by the final ``fusion`` call.

        Raises:
            KeyError: If a required key is missing from ``query``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return (
                table.output(output_fields)
                .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"])
                .match_text(query["match_field"], query["matching_text"], query["topn"])
                .fusion("rrf", query["topn"])
            )

    def vector_search(
        self,
        table_name: str,
        output_fields: List[str],
        query: Dict[str, Any],
        database_name: Optional[str] = None
    ):
        """
        Build a dense-vector search on a table.

        The dense match uses element type "float" and distance "cosine".

        NOTE(review): no result-fetching call is made here and the connection
        context exits on return; presumably the caller materialises the
        result - confirm the returned object stays usable.

        Args:
            table_name: Table name.
            output_fields: Columns to output.
            query: Must contain the keys ``"vector_field"``,
                ``"query_vector"`` and ``"topn"``.
            database_name: Database name; defaults to the client's database.

        Returns:
            The object returned by ``table.output(...).match_dense(...)``.

        Raises:
            KeyError: If a required key is missing from ``query``.
        """
        with self.pool.get_connection() as conn:
            table = self._get_table(conn, table_name, database_name)
            return (
                table.output(output_fields)
                .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"])
            )

    def get_status(self) -> Dict[str, Any]:
        """Return the status reported by the connection pool."""
        return self.pool.get_status()

    def close(self):
        """Close the client by closing its connection pool."""
        self.pool.close()


# Global client instance, created lazily by get_client().
_global_client: Optional[InfinityClient] = None
# Guards creation and teardown of the global client.
_client_lock = threading.Lock()


def get_client(
    host: str = vector_db_settings.infinity_host,
    port: str = vector_db_settings.infinity_sdk_port,
    database: str = vector_db_settings.infinity_database,
    min_connections: int = 1,
    max_connections: int = 1
) -> InfinityClient:
    """
    Return the global client instance, creating it on first use.

    The arguments are only used when the global client does not exist yet;
    on later calls the existing instance is returned and the arguments are
    ignored.

    Args:
        host: Infinity service address.
        port: Infinity service port.
        database: Database name.
        min_connections: Minimum number of pooled connections.
        max_connections: Maximum number of pooled connections.

    Returns:
        The global Infinity client instance.
    """
    global _global_client

    with _client_lock:
        if _global_client is None:
            _global_client = InfinityClient(
                host=host,
                port=port,
                database=database,
                min_connections=min_connections,
                max_connections=max_connections
            )
        return _global_client


def close_client():
    """Close the global client, if one exists, and reset it to ``None``."""
    global _global_client

    with _client_lock:
        if _global_client is not None:
            _global_client.close()
            _global_client = None