client.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Infinity数据库客户端实现
  2. from infinity.common import ConflictType
  3. from typing import Dict, Any, List, Optional
  4. import threading
  5. from contextlib import contextmanager
  6. from conf.config import VectorDBConfig
  7. from .pool import InfinityConnectionPool
  8. class InfinityClient:
  9. """
  10. Infinity数据库客户端,基于连接池
  11. 提供了常用的数据库操作方法,包括:
  12. - 数据库操作
  13. - 表操作
  14. - 文档操作
  15. - 搜索操作
  16. """
  17. def __init__(
  18. self,
  19. host: str = VectorDBConfig.get_infinity_host(),
  20. port: str = VectorDBConfig.get_infinity_sdk_port(),
  21. database: str = VectorDBConfig.get_infinity_database(),
  22. min_connections: int = 2,
  23. max_connections: int = 10
  24. ):
  25. """
  26. 初始化Infinity客户端
  27. Args:
  28. host: Infinity服务地址
  29. port: Infinity服务端口
  30. database: 数据库名称
  31. min_connections: 最小连接数
  32. max_connections: 最大连接数
  33. """
  34. self.pool = InfinityConnectionPool(
  35. host=host,
  36. port=port,
  37. database=database,
  38. min_connections=min_connections,
  39. max_connections=max_connections
  40. )
  41. def _get_database(self, conn: Any, database_name: Optional[str] = None):
  42. """
  43. 获取数据库对象
  44. Args:
  45. conn: Infinity连接对象
  46. database_name: 数据库名称(可选,默认使用客户端配置的数据库)
  47. Returns:
  48. Database对象
  49. """
  50. db_name = database_name or self.pool.database
  51. return conn.get_database(db_name)
  52. def _get_table(self, conn: Any, table_name: str, database_name: Optional[str] = None):
  53. """
  54. 获取表对象
  55. Args:
  56. conn: Infinity连接对象
  57. table_name: 表名
  58. database_name: 数据库名称(可选,默认使用客户端配置的数据库)
  59. Returns:
  60. Table对象
  61. """
  62. db = self._get_database(conn, database_name)
  63. return db.get_table(table_name)
  64. @contextmanager
  65. def get_connection(self, timeout: Optional[int] = None):
  66. """
  67. 获取一个连接上下文
  68. Args:
  69. timeout: 获取连接的超时时间(秒)
  70. """
  71. with self.pool.get_connection(timeout) as conn:
  72. yield conn
  73. def list_databases(self) -> List[str]:
  74. """获取所有数据库"""
  75. with self.pool.list_connection() as conn:
  76. return conn.list_databases()
  77. def get_databases(self) -> List[str]:
  78. """获取所有数据库(别名,向后兼容)"""
  79. return self.list_databases()
  80. def create_database(self, database_name: str, conflict_type: str = ConflictType.Error, comment: Optional[str] = None):
  81. """创建数据库"""
  82. with self.pool.get_connection() as conn:
  83. return conn.create_database(database_name, conflict_type, comment)
  84. def drop_database(self, database_name: str, conflict_type: str = "Error"):
  85. """删除数据库"""
  86. with self.pool.get_connection() as conn:
  87. return conn.drop_database(database_name, conflict_type)
  88. def use_database(self, database_name: str):
  89. """切换客户端默认数据库
  90. 注意:此操作只会影响当前客户端实例创建的新连接,不会影响已存在的连接
  91. """
  92. # 更新客户端的默认数据库
  93. self.pool.database = database_name
  94. def list_tables(self, database_name: Optional[str] = None) -> List[str]:
  95. """获取所有表"""
  96. with self.pool.get_connection() as conn:
  97. # 根据官方API,应该先获取Database对象,然后在Database对象上调用get_tables方法
  98. db = self._get_database(conn, database_name)
  99. return db.list_tables()
  100. def get_tables(self, database_name: Optional[str] = None) -> List[str]:
  101. """获取所有表(别名,向后兼容)"""
  102. return self.list_tables(database_name)
  103. def get_table(self, table_name: str, database_name: Optional[str] = None):
  104. """获取表"""
  105. with self.pool.get_connection() as conn:
  106. # 根据官方API,应该先获取Database对象,然后在Database对象上调用get_table方法
  107. return self._get_table(conn, table_name, database_name)
  108. def create_table(
  109. self,
  110. table_name: str,
  111. columns_definition: List[Dict[str, Any]],
  112. conflict_type: str = ConflictType.Error,
  113. database_name: Optional[str] = None
  114. ):
  115. """创建表"""
  116. with self.pool.get_connection() as conn:
  117. # 根据官方API,应该先获取Database对象,然后在Database对象上调用create_table方法
  118. db = self._get_database(conn, database_name)
  119. return db.create_table(table_name, columns_definition, conflict_type)
  120. def drop_table(self, table_name: str, database_name: Optional[str] = None, conflict_type = ConflictType.Error):
  121. """删除表"""
  122. with self.pool.get_connection() as conn:
  123. # 根据官方API,应该先获取Database对象,然后在Database对象上调用drop_table方法
  124. db = self._get_database(conn, database_name)
  125. return db.drop_table(table_name, conflict_type)
  126. def create_index(self, table_name: str, index_name: str, index_info: Dict[str, Any], database_name: Optional[str] = None, conflict_type = ConflictType.Error, comment = None):
  127. """创建索引"""
  128. with self.pool.get_connection() as conn:
  129. # 使用辅助方法获取Table对象
  130. table = self._get_table(conn, table_name, database_name)
  131. return table.create_index(index_name, index_info, conflict_type, comment)
  132. def optimize(self, table_name: str, database_name: Optional[str] = None):
  133. """优化表"""
  134. with self.pool.get_connection() as conn:
  135. # 使用辅助方法获取Table对象
  136. table = self._get_table(conn, table_name, database_name)
  137. return table.optimize()
  138. def drop_index(self, table_name: str, index_name: str, database_name: Optional[str] = None, conflict_type = ConflictType.Error):
  139. """删除索引"""
  140. with self.pool.get_connection() as conn:
  141. # 使用辅助方法获取Table对象
  142. table = self._get_table(conn, table_name, database_name)
  143. return table.drop_index(index_name, conflict_type)
  144. def insert(
  145. self,
  146. table_name: str,
  147. documents: List[Dict[str, Any]],
  148. database_name: Optional[str] = None
  149. ):
  150. """插入文档"""
  151. with self.pool.get_connection() as conn:
  152. table = self._get_table(conn, table_name, database_name)
  153. return table.insert(documents)
  154. def search(
  155. self,
  156. table_name: str,
  157. output_fields: List[str],
  158. query: Dict[str, Any],
  159. database_name: Optional[str] = None
  160. ):
  161. """搜索文档"""
  162. with self.pool.get_connection() as conn:
  163. # 获取Table对象
  164. table = self._get_table(conn, table_name, database_name)
  165. # 获取结果集
  166. return table.output(output_fields).match_text(query["match_field"], query["matching_text"], query["topn"])
  167. def hybrid_search(
  168. self,
  169. table_name: str,
  170. output_fields: List[str],
  171. query: Dict[str, Any],
  172. database_name: Optional[str] = None
  173. ):
  174. """混合检索"""
  175. with self.pool.get_connection() as conn:
  176. # 获取Table对象
  177. table = self._get_table(conn, table_name, database_name)
  178. # 获取结果集
  179. return table.output(output_fields) \
  180. .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"]) \
  181. .match_text(query["match_field"], query["matching_text"], query["topn"]) \
  182. .fusion("rrf", query["topn"])
  183. def vector_search(
  184. self,
  185. table_name: str,
  186. output_fields: List[str],
  187. query: Dict[str, Any],
  188. database_name: Optional[str] = None
  189. ):
  190. """向量检索"""
  191. with self.pool.get_connection() as conn:
  192. # 获取Table对象
  193. table = self._get_table(conn, table_name, database_name)
  194. # 获取结果集
  195. return table.output(output_fields) \
  196. .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"])
  197. def get_status(self) -> Dict[str, Any]:
  198. """获取客户端状态"""
  199. return self.pool.get_status()
  200. def close(self):
  201. """关闭客户端,释放所有连接"""
  202. self.pool.close()
  203. # 全局客户端实例
  204. _global_client: Optional[InfinityClient] = None
  205. _client_lock = threading.Lock()
  206. def get_client(
  207. host: str = VectorDBConfig.get_infinity_host(),
  208. port: str = VectorDBConfig.get_infinity_sdk_port(),
  209. database: str = VectorDBConfig.get_infinity_database(),
  210. min_connections: int = 2,
  211. max_connections: int = 10
  212. ) -> InfinityClient:
  213. """
  214. 获取全局客户端实例(单例模式)
  215. Args:
  216. host: Infinity服务地址
  217. port: Infinity服务端口
  218. database: 数据库名称
  219. min_connections: 最小连接数
  220. max_connections: 最大连接数
  221. Returns:
  222. 全局Infinity客户端实例
  223. """
  224. global _global_client
  225. with _client_lock:
  226. if _global_client is None:
  227. _global_client = InfinityClient(
  228. host=host,
  229. port=port,
  230. database=database,
  231. min_connections=min_connections,
  232. max_connections=max_connections
  233. )
  234. return _global_client
  235. def close_client():
  236. """关闭全局客户端"""
  237. global _global_client
  238. with _client_lock:
  239. if _global_client:
  240. _global_client.close()
  241. _global_client = None