__init__.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
"""
Main class for the Infinity vector database.
Talks to Infinity through its HTTP API; the official Infinity Python SDK is not required.
"""
  5. from typing import List, Dict, Any, Optional
  6. import json
  7. class InfinityVectorDB:
  8. """
  9. Infinity向量数据库主类
  10. 提供统一的接口,整合索引、文档和搜索功能
  11. """
  12. def __init__(self):
  13. """
  14. 初始化Infinity向量数据库
  15. 使用HTTP API实现,不依赖官方SDK
  16. """
  17. from conf.config import VectorDBConfig
  18. from utils.http_client import HTTPClient
  19. import base64
  20. # 获取配置
  21. self.host = VectorDBConfig.get_infinity_host()
  22. self.port = VectorDBConfig.get_infinity_port()
  23. self.user = VectorDBConfig.get_infinity_user()
  24. self.password = VectorDBConfig.get_infinity_password()
  25. self.database = VectorDBConfig.get_infinity_database()
  26. self.headers = {
  27. "Accept": "application/json",
  28. "Content-Type": "application/json"
  29. }
  30. # 生成Basic Auth令牌
  31. auth_str = f"{self.user}:{self.password}"
  32. auth_token = base64.b64encode(auth_str.encode()).decode()
  33. # 初始化HTTP客户端
  34. self.base_url = f"http://{self.host}:{self.port}"
  35. self.http_client = HTTPClient(
  36. base_url=self.base_url,
  37. api_key=auth_token,
  38. auth_type='basic'
  39. )
  40. def create_index(self, index_name: str, mappings: Dict[str, Any] = None) -> bool:
  41. """创建索引"""
  42. try:
  43. # 使用Infinity官方HTTP API创建表(对应索引)
  44. path = f"/databases/{self.database}/tables/{index_name}"
  45. # 定义表字段
  46. with open("conf/infinity_mapping.json", "r", encoding="utf-8") as f:
  47. fields = json.load(f)
  48. data = {
  49. "create_option": "ignore_if_exists",
  50. "fields": fields
  51. }
  52. response = self.http_client.post(path, json_data=data, headers=self.headers)
  53. return response.get("error_code") == 0
  54. except Exception as e:
  55. print(f"Failed to create index: {str(e)}")
  56. return False
  57. def delete_index(self, index_name: str) -> bool:
  58. """删除索引"""
  59. try:
  60. # 使用Infinity官方HTTP API删除表(对应索引)
  61. path = f"/databases/{self.database}/tables/{index_name}"
  62. data = {
  63. "drop_option": "ignore_if_not_exists"
  64. }
  65. response = self.http_client.delete(path, json_data=data, headers=self.headers)
  66. return response.get("error_code") == 0
  67. except Exception as e:
  68. print(f"Failed to delete index: {str(e)}")
  69. return False
  70. def index_exists(self, index_name: str) -> bool:
  71. """检查索引是否存在"""
  72. try:
  73. # 使用Infinity官方HTTP API获取表列表
  74. path = f"/databases/{self.database}/tables"
  75. response = self.http_client.get(path, headers=self.headers)
  76. if response.get("error_code") == 0:
  77. tables = response.get("tables", [])
  78. return index_name in tables
  79. return False
  80. except Exception as e:
  81. print(f"Failed to check index existence: {str(e)}")
  82. return False
  83. def insert_document(self, index_name: str, document: Dict[str, Any], id: str = None) -> bool:
  84. """插入单个文档"""
  85. try:
  86. # 使用Infinity官方HTTP API插入单行数据
  87. path = f"/databases/{self.database}/tables/{index_name}/docs"
  88. # 如果提供了id,将其添加到文档中
  89. if id:
  90. document["id"] = id
  91. data = [document]
  92. response = self.http_client.post(path, json_data=data, headers=self.headers)
  93. return response.get("error_code") == 0
  94. except Exception as e:
  95. print(f"Failed to insert document: {str(e)}")
  96. return False
  97. def bulk_insert(self, index_name: str, documents: List[Dict[str, Any]]) -> Dict[str, Any]:
  98. """批量插入文档"""
  99. try:
  100. # 使用Infinity官方HTTP API批量插入数据
  101. path = f"/databases/{self.database}/tables/{index_name}/docs"
  102. data = documents
  103. response = self.http_client.post(path, json_data=data, headers=self.headers)
  104. if response.get("error_code") == 0:
  105. return {
  106. "success": True,
  107. "inserted": len(documents)
  108. }
  109. else:
  110. return {
  111. "success": False,
  112. "error": response.get("error_msg", "Unknown error"),
  113. "inserted": 0
  114. }
  115. except Exception as e:
  116. print(f"Failed to bulk insert documents: {str(e)}")
  117. return {
  118. "success": False,
  119. "error": str(e),
  120. "inserted": 0
  121. }
  122. def update_document(self, index_name: str, document_id: str, document: Dict[str, Any]) -> bool:
  123. """更新单个文档"""
  124. try:
  125. # 使用Infinity官方HTTP API更新行
  126. path = f"/databases/{self.database}/tables/{index_name}/rows"
  127. data = {
  128. "update_by": {
  129. "column": "id",
  130. "value": document_id
  131. },
  132. "update_data": document
  133. }
  134. response = self.http_client.put(path, json_data=data, headers=self.headers)
  135. return response.get("error_code") == 0
  136. except Exception as e:
  137. print(f"Failed to update document: {str(e)}")
  138. return False
  139. def delete_document(self, index_name: str, document_id: str) -> bool:
  140. """删除单个文档"""
  141. try:
  142. # 使用Infinity官方HTTP API删除行
  143. path = f"/databases/{self.database}/tables/{index_name}/rows"
  144. data = {
  145. "delete_by": {
  146. "column": "id",
  147. "value": document_id
  148. }
  149. }
  150. response = self.http_client.delete(path, json_data=data, headers=self.headers)
  151. return response.get("error_code") == 0
  152. except Exception as e:
  153. print(f"Failed to delete document: {str(e)}")
  154. return False
  155. def get_document(self, index_name: str, document_id: str) -> Optional[Dict[str, Any]]:
  156. """获取单个文档"""
  157. try:
  158. # 使用Infinity官方HTTP API查询单行数据
  159. path = f"/databases/{self.database}/tables/{index_name}/query"
  160. data = {
  161. "filter": {
  162. "column": "id",
  163. "operator": "=",
  164. "value": document_id
  165. },
  166. "limit": 1
  167. }
  168. response = self.http_client.post(path, json_data=data, headers=self.headers)
  169. if response.get("error_code") == 0:
  170. rows = response.get("rows", [])
  171. if rows:
  172. return rows[0]
  173. return None
  174. except Exception as e:
  175. print(f"Failed to get document: {str(e)}")
  176. return None
  177. def delete_by_query(self, index_name: str, query: Dict[str, Any]) -> Dict[str, Any]:
  178. """按查询条件删除文档"""
  179. try:
  180. # 使用Infinity官方HTTP API按条件删除行
  181. path = f"/databases/{self.database}/tables/{index_name}/rows"
  182. # 构建删除条件
  183. # 这里假设query是一个简单的字典,如{"dataset_id": "xxx"}
  184. filter_conditions = []
  185. for key, value in query.items():
  186. filter_conditions.append({
  187. "column": key,
  188. "operator": "=",
  189. "value": value
  190. })
  191. data = {
  192. "delete_by": {
  193. "and": filter_conditions
  194. }
  195. }
  196. response = self.http_client.delete(path, json_data=data, headers=self.headers)
  197. if response.get("error_code") == 0:
  198. return {"success": True}
  199. else:
  200. return {"success": False, "error": response.get("error_msg", "Unknown error")}
  201. except Exception as e:
  202. print(f"Failed to delete by query: {str(e)}")
  203. return {"success": False, "error": str(e)}
  204. def search(self, index_name: str, query: Dict[str, Any], size: int = 10) -> Dict[str, Any]:
  205. """搜索文档"""
  206. try:
  207. # 使用Infinity官方HTTP API查询数据
  208. path = f"/databases/{self.database}/tables/{index_name}/docs"
  209. data = {
  210. "filter": query,
  211. "limit": size
  212. }
  213. response = self.http_client.post(path, json_data=data, headers=self.headers)
  214. if response.get("error_code") == 0:
  215. rows = response.get("output", [])
  216. return {
  217. "output": rows,
  218. }
  219. else:
  220. return {"hits": [], "total": 0, "error": response.get("error_msg", "Unknown error")}
  221. except Exception as e:
  222. print(f"Failed to search: {str(e)}")
  223. return {"output": [], "error": str(e)}
  224. def vector_search(self, index_name: str, vector_field: str, vector: List[float], size: int = 10, filter: Dict[str, Any] = None) -> Dict[str, Any]:
  225. """向量检索"""
  226. try:
  227. # 使用Infinity官方HTTP API进行向量检索
  228. path = f"/databases/{self.database}/tables/{index_name}/docs"
  229. data = {
  230. "vector_field": vector_field,
  231. "vector": vector,
  232. "limit": size
  233. }
  234. if filter:
  235. data["filter"] = filter
  236. response = self.http_client.post(path, json_data=data, headers=self.headers)
  237. if response.get("error_code") == 0:
  238. rows = response.get("output", [])
  239. return {
  240. "hits": rows,
  241. "total": len(rows)
  242. }
  243. else:
  244. return {"hits": [], "total": 0, "error": response.get("error_msg", "Unknown error")}
  245. except Exception as e:
  246. print(f"Failed to vector search: {str(e)}")
  247. return {"hits": [], "total": 0, "error": str(e)}
  248. def hybrid_search(self, index_name: str, match_method: str, vector_field: str, query_vector: List[float], element_type: str,
  249. metric_type: str = "cosine", topn: int = 3, rank_constant: int = 60,
  250. text_query: str = "", text_field: str = "file_name"
  251. ) -> Dict[str, Any]:
  252. """混合检索"""
  253. try:
  254. # 使用Infinity官方HTTP API进行混合检索
  255. path = f"/databases/{self.database}/tables/{index_name}/docs"
  256. # 构建搜索配置列表
  257. search_config = [
  258. {
  259. "match_method": match_method,
  260. "fields": vector_field,
  261. "query_vector": query_vector,
  262. "element_type": element_type,
  263. "metric_type": metric_type,
  264. "topn": topn,
  265. "params": {
  266. "ef": "10"
  267. }
  268. }
  269. ]
  270. # 只有当text_query和text_field都不为空时,才添加文本搜索配置
  271. # if text_query and text_field:
  272. # search_config.append(
  273. # {
  274. # "match_method": "text",
  275. # "fields": text_field,
  276. # "matching_text": text_query,
  277. # "topn": 1,
  278. # "params":
  279. # {
  280. # "default_fields": text_field,
  281. # "operator": "or"
  282. # }
  283. # }
  284. # )
  285. # 添加融合方法配置
  286. # if vector_field and vector and text_query and text_field:
  287. # search_config.append(
  288. # {
  289. # "fusion_method": "rrf",
  290. # "topn": topn,
  291. # "params":{"rank_constant": rank_constant}
  292. # }
  293. # )
  294. data = {
  295. "output": [
  296. "file_name",
  297. "page_number",
  298. "content",
  299. "image_path",
  300. "dataset_id",
  301. "document_id",
  302. "_similarity"
  303. ],
  304. "search": search_config
  305. }
  306. response = self.http_client.get_json(path, json_data=data, headers=self.headers)
  307. if response["error_code"] == 0:
  308. rows = response["output"]
  309. # 将列表的列表转换为字典列表
  310. output_fields = ["file_name", "page_number", "content", "image_path", "dataset_id", "document_id", "_similarity"]
  311. formatted_rows = []
  312. for row in rows:
  313. # 创建字典,将每个字段名与对应的值匹配
  314. formatted_row = {}
  315. for i, field in enumerate(output_fields):
  316. if i < len(row):
  317. # 处理字段值,确保是字典类型
  318. if isinstance(row[i], dict):
  319. formatted_row.update(row[i])
  320. else:
  321. formatted_row[field] = row[i]
  322. formatted_rows.append(formatted_row)
  323. return {
  324. "output": formatted_rows,
  325. "total": len(formatted_rows)
  326. }
  327. else:
  328. return {"output": [], "total": 0, "error": response["error_msg"]}
  329. except Exception as e:
  330. print(f"Failed to hybrid search: {str(e)}")
  331. return {"output": [], "total": 0, "error": str(e)}