client.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # Infinity数据库客户端实现
  2. from infinity.common import ConflictType
  3. from typing import Dict, Any, List, Optional
  4. import threading
  5. from contextlib import contextmanager
  6. from conf.settings import vector_db_settings
  7. from .pool import InfinityConnectionPool
  8. class InfinityClient:
  9. """
  10. Infinity数据库客户端,基于连接池
  11. 提供了常用的数据库操作方法,包括:
  12. - 数据库操作
  13. - 表操作
  14. - 文档操作
  15. - 搜索操作
  16. """
  17. def __init__(
  18. self,
  19. host: str = vector_db_settings.infinity_host,
  20. port: str = vector_db_settings.infinity_sdk_port,
  21. database: str = vector_db_settings.infinity_database,
  22. min_connections: int = 2,
  23. max_connections: int = 10
  24. ):
  25. """
  26. 初始化Infinity客户端
  27. Args:
  28. host: Infinity服务地址
  29. port: Infinity服务端口
  30. database: 数据库名称
  31. min_connections: 最小连接数
  32. max_connections: 最大连接数
  33. """
  34. self.pool = InfinityConnectionPool(
  35. host=host,
  36. port=port,
  37. database=database,
  38. min_connections=min_connections,
  39. max_connections=max_connections
  40. )
  41. def _get_database(self, conn: Any, database_name: Optional[str] = None):
  42. """
  43. 获取数据库对象
  44. Args:
  45. conn: Infinity连接对象
  46. database_name: 数据库名称(可选,默认使用客户端配置的数据库)
  47. Returns:
  48. Database对象
  49. """
  50. db_name = database_name or self.pool.database
  51. return conn.get_database(db_name)
  52. def _get_table(self, conn: Any, table_name: str, database_name: Optional[str] = None):
  53. """
  54. 获取表对象
  55. Args:
  56. conn: Infinity连接对象
  57. table_name: 表名
  58. database_name: 数据库名称(可选,默认使用客户端配置的数据库)
  59. Returns:
  60. Table对象
  61. """
  62. db = self._get_database(conn, database_name)
  63. return db.get_table(table_name)
  64. @contextmanager
  65. def get_connection(self, timeout: Optional[int] = None):
  66. """
  67. 获取一个连接上下文
  68. Args:
  69. timeout: 获取连接的超时时间(秒)
  70. """
  71. with self.pool.get_connection(timeout) as conn:
  72. yield conn
  73. def list_databases(self) -> List[str]:
  74. """获取所有数据库"""
  75. with self.pool.list_connection() as conn:
  76. return conn.list_databases()
  77. def get_databases(self) -> List[str]:
  78. """获取所有数据库(别名,向后兼容)"""
  79. return self.list_databases()
  80. def create_database(self, database_name: str, conflict_type: str = ConflictType.Error, comment: Optional[str] = None):
  81. """创建数据库"""
  82. with self.pool.get_connection() as conn:
  83. return conn.create_database(database_name, conflict_type, comment)
  84. def drop_database(self, database_name: str, conflict_type: str = "Error"):
  85. """删除数据库"""
  86. with self.pool.get_connection() as conn:
  87. return conn.drop_database(database_name, conflict_type)
  88. def use_database(self, database_name: str):
  89. """切换客户端默认数据库
  90. 注意:此操作只会影响当前客户端实例创建的新连接,不会影响已存在的连接
  91. """
  92. # 更新客户端的默认数据库
  93. self.pool.database = database_name
  94. def list_tables(self, database_name: Optional[str] = None) -> List[str]:
  95. """获取所有表"""
  96. with self.pool.get_connection() as conn:
  97. # 根据官方API,应该先获取Database对象,然后在Database对象上调用get_tables方法
  98. db = self._get_database(conn, database_name)
  99. return db.list_tables()
  100. def get_tables(self, database_name: Optional[str] = None) -> List[str]:
  101. """获取所有表(别名,向后兼容)"""
  102. return self.list_tables(database_name)
  103. def get_table(self, table_name: str, database_name: Optional[str] = None):
  104. """获取表"""
  105. with self.pool.get_connection() as conn:
  106. # 根据官方API,应该先获取Database对象,然后在Database对象上调用get_table方法
  107. return self._get_table(conn, table_name, database_name)
  108. def create_table(
  109. self,
  110. table_name: str,
  111. columns_definition: List[Dict[str, Any]],
  112. conflict_type: str = ConflictType.Error,
  113. database_name: Optional[str] = None
  114. ):
  115. """创建表"""
  116. with self.pool.get_connection() as conn:
  117. # 根据官方API,应该先获取Database对象,然后在Database对象上调用create_table方法
  118. db = self._get_database(conn, database_name)
  119. return db.create_table(table_name, columns_definition, conflict_type)
  120. def drop_table(self, table_name: str, database_name: Optional[str] = None, conflict_type = ConflictType.Error):
  121. """删除表"""
  122. with self.pool.get_connection() as conn:
  123. # 根据官方API,应该先获取Database对象,然后在Database对象上调用drop_table方法
  124. db = self._get_database(conn, database_name)
  125. return db.drop_table(table_name, conflict_type)
  126. def create_index(self, table_name: str, index_name: str, index_info: Dict[str, Any], database_name: Optional[str] = None, conflict_type = ConflictType.Error, comment = None):
  127. """创建索引"""
  128. with self.pool.get_connection() as conn:
  129. # 使用辅助方法获取Table对象
  130. table = self._get_table(conn, table_name, database_name)
  131. return table.create_index(index_name, index_info, conflict_type, comment)
  132. def optimize(self, table_name: str, database_name: Optional[str] = None):
  133. """优化表"""
  134. with self.pool.get_connection() as conn:
  135. # 使用辅助方法获取Table对象
  136. table = self._get_table(conn, table_name, database_name)
  137. return table.optimize()
  138. def drop_index(self, table_name: str, index_name: str, database_name: Optional[str] = None, conflict_type = ConflictType.Error):
  139. """删除索引"""
  140. with self.pool.get_connection() as conn:
  141. # 使用辅助方法获取Table对象
  142. table = self._get_table(conn, table_name, database_name)
  143. return table.drop_index(index_name, conflict_type)
  144. def insert(
  145. self,
  146. table_name: str,
  147. documents: List[Dict[str, Any]],
  148. database_name: Optional[str] = None
  149. ):
  150. """插入文档"""
  151. with self.pool.get_connection() as conn:
  152. table = self._get_table(conn, table_name, database_name)
  153. return table.insert(documents)
  154. def update(
  155. self,
  156. table_name: str,
  157. cond: str,
  158. data: Dict[str, Any],
  159. database_name: Optional[str] = None
  160. ):
  161. """更新文档"""
  162. with self.pool.get_connection() as conn:
  163. # 获取Table对象
  164. table = self._get_table(conn, table_name, database_name)
  165. # 更新文档
  166. return table.update(cond, data)
  167. def search(
  168. self,
  169. table_name: str,
  170. output_fields: List[str],
  171. query: Dict[str, Any],
  172. database_name: Optional[str] = None
  173. ):
  174. """搜索文档"""
  175. with self.pool.get_connection() as conn:
  176. # 获取Table对象
  177. table = self._get_table(conn, table_name, database_name)
  178. # 获取结果集
  179. return table.output(output_fields).match_text(query["match_field"], query["matching_text"], query["topn"])
  180. def hybrid_search(
  181. self,
  182. table_name: str,
  183. output_fields: List[str],
  184. query: Dict[str, Any],
  185. database_name: Optional[str] = None
  186. ):
  187. """混合检索"""
  188. with self.pool.get_connection() as conn:
  189. # 获取Table对象
  190. table = self._get_table(conn, table_name, database_name)
  191. # 获取结果集
  192. return table.output(output_fields) \
  193. .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"]) \
  194. .match_text(query["match_field"], query["matching_text"], query["topn"]) \
  195. .fusion("rrf", query["topn"])
  196. def vector_search(
  197. self,
  198. table_name: str,
  199. output_fields: List[str],
  200. query: Dict[str, Any],
  201. database_name: Optional[str] = None
  202. ):
  203. """向量检索"""
  204. with self.pool.get_connection() as conn:
  205. # 获取Table对象
  206. table = self._get_table(conn, table_name, database_name)
  207. # 获取结果集
  208. return table.output(output_fields) \
  209. .match_dense(query["vector_field"], query["query_vector"], "float", "cosine", query["topn"])
  210. def get_status(self) -> Dict[str, Any]:
  211. """获取客户端状态"""
  212. return self.pool.get_status()
  213. def close(self):
  214. """关闭客户端,释放所有连接"""
  215. self.pool.close()
  216. # 全局客户端实例
  217. _global_client: Optional[InfinityClient] = None
  218. _client_lock = threading.Lock()
  219. def get_client(
  220. host: str = vector_db_settings.infinity_host,
  221. port: str = vector_db_settings.infinity_sdk_port,
  222. database: str = vector_db_settings.infinity_database,
  223. min_connections: int = 2,
  224. max_connections: int = 10
  225. ) -> InfinityClient:
  226. """
  227. 获取全局客户端实例(单例模式)
  228. Args:
  229. host: Infinity服务地址
  230. port: Infinity服务端口
  231. database: 数据库名称
  232. min_connections: 最小连接数
  233. max_connections: 最大连接数
  234. Returns:
  235. 全局Infinity客户端实例
  236. """
  237. global _global_client
  238. with _client_lock:
  239. if _global_client is None:
  240. _global_client = InfinityClient(
  241. host=host,
  242. port=port,
  243. database=database,
  244. min_connections=min_connections,
  245. max_connections=max_connections
  246. )
  247. return _global_client
  248. def close_client():
  249. """关闭全局客户端"""
  250. global _global_client
  251. with _client_lock:
  252. if _global_client:
  253. _global_client.close()
  254. _global_client = None