prompt_service.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
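  """
  Prompt management service.

  Provides CRUD operations for prompt dimensions and prompt versions.
  The vector-database table/index of each dimension is checked and created by
  PromptService.init_vector_db_tables() (invoked by the application lifecycle
  code, not by the constructor) and whenever a new dimension is added.
  """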
  6. import json
  7. from typing import List, Dict, Any, Optional
  8. from src.utils.mysql import get_global_mysql_client
  9. from src.utils.vector_db import get_vector_db_client
  10. from src.conf.settings import vector_db_settings, ragflow_settings
  11. from src.common.logging_config import get_logger
  12. logger = get_logger(__name__)
  # Column layout of the Infinity tables created for each dimension.
  # Each scalar column is listed as (name, type, default); the embedding column
  # is appended separately because it declares no default value.
  INFINITY_TABLE_COLUMNS = [
      *(
          {"name": column, "type": kind, "default": fallback}
          for column, kind, fallback in (
              ("id", "varchar", ""),
              ("file_name", "varchar", ""),
              ("page_number", "int", 0),
              ("content", "varchar", ""),
              ("image_path", "varchar", ""),
              ("document_id", "varchar", ""),
              ("chunk_id", "varchar", ""),
              ("ability_tags", "varchar", ""),
              ("content_tag", "varchar", ""),
              ("metadata", "varchar", ""),
          )
      ),
      {"name": "dense_vector_1024", "type": "vector,1024,float"},
  ]
  # Elasticsearch index mapping: one field per column of the Infinity layout.
  # `content` is the only analysed text field; `dense_vector_1024` holds the
  # indexed embedding compared with cosine similarity.
  ES_INDEX_MAPPINGS = dict(
      properties=dict(
          id=dict(type="keyword"),
          file_name=dict(type="keyword"),
          page_number=dict(type="integer"),
          content=dict(type="text", analyzer="ik_smart"),
          image_path=dict(type="keyword"),
          document_id=dict(type="keyword"),
          chunk_id=dict(type="keyword"),
          ability_tags=dict(type="keyword"),
          content_tag=dict(type="keyword"),
          metadata=dict(type="object"),
          dense_vector_1024=dict(
              type="dense_vector",
              dims=1024,
              index=True,
              similarity="cosine",
          ),
      ),
  )
  # Elasticsearch index settings applied when an index is created:
  # two primary shards and no replica shards.
  ES_INDEX_SETTINGS = dict(
      index=dict(
          number_of_shards=2,
          number_of_replicas=0,
      ),
  )
  class PromptService:
      """Prompt management service.

      Offers CRUD operations on prompt dimensions (``prompt_dimensions`` table)
      and their prompt versions (``prompt_versions`` table), and creates the
      Infinity table / Elasticsearch index that belongs to each dimension.
      """

      def __init__(self):
          """Bind the global MySQL client; the vector client is created lazily."""
          self._db = get_global_mysql_client()
          self._vector_client = None

      def _get_vector_client(self):
          """Return the vector-database client, creating it on first use."""
          if self._vector_client is None:
              self._vector_client = get_vector_db_client()
          return self._vector_client

      def _get_table_name(self, dimension: Dict) -> str:
          """Return the table/index name of a dimension.

          Args:
              dimension: Dimension row; its ``id`` and ``dataset_id`` keys are read.

          Returns:
              ``{custom_dataset_prefix}_{id}_{dataset_id}``.
          """
          return f"{ragflow_settings.custom_dataset_prefix}_{dimension['id']}_{dimension['dataset_id']}"

      def init_vector_db_tables(self):
          """Create the missing vector-database tables/indices of all dimensions.

          Loads every dimension and dispatches on
          ``vector_db_settings.vector_db_type`` (``"infinity"`` or ``"es"``).
          An unknown type is logged and skipped. Errors are logged and never
          propagated, so a failure here does not abort the caller.
          """
          db_type = vector_db_settings.vector_db_type
          try:
              dimensions = self.get_all_dimensions()
              if db_type == "infinity":
                  self._init_infinity_tables(dimensions)
              elif db_type == "es":
                  self._init_es_indices(dimensions)
              else:
                  logger.warning(f"未知的向量数据库类型: {db_type},跳过初始化")
          except Exception as e:
              logger.error(f"初始化向量数据库表/索引失败: {e}")

      def _init_infinity_tables(self, dimensions: List[Dict]):
          """Create the Infinity table of every dimension that lacks one.

          Args:
              dimensions: Dimension rows (``id``, ``dataset_id`` and ``name`` are read).
          """
          try:
              client = self._get_vector_client()
              existing_tables = client.list_tables()
              logger.info(f"Infinity 现有表: {existing_tables}")
              for dim in dimensions:
                  # Pass the whole row: _get_table_name needs both `id` and
                  # `dataset_id`. Passing only dim['id'] raised a TypeError that
                  # the handler below swallowed, so no table was ever created.
                  table_name = self._get_table_name(dim)
                  if table_name not in existing_tables:
                      self._create_infinity_table(table_name)
                      logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dim['name']})")
          except Exception as e:
              logger.error(f"初始化 Infinity 表失败: {e}")

      def _init_es_indices(self, dimensions: List[Dict]):
          """Create the Elasticsearch index of every dimension that lacks one.

          Args:
              dimensions: Dimension rows (``id``, ``dataset_id`` and ``name`` are read).
          """
          try:
              client = self._get_vector_client()
              for dim in dimensions:
                  index_name = self._get_table_name(dim)
                  if not client.index_exists(index_name):
                      self._create_es_index(index_name)
                      logger.info(f"✅ 创建 ES 索引: {index_name} (维度: {dim['name']})")
                  else:
                      logger.debug(f"ES 索引已存在: {index_name}")
          except Exception as e:
              logger.error(f"初始化 ES 索引失败: {e}")

      def _create_infinity_table(self, table_name: str):
          """Create one Infinity table using ``INFINITY_TABLE_COLUMNS``."""
          client = self._get_vector_client()
          client.create_table(
              table_name=table_name,
              columns_definition=INFINITY_TABLE_COLUMNS,
          )

      def _create_es_index(self, index_name: str):
          """Create one Elasticsearch index and log the client's response."""
          client = self._get_vector_client()
          res = client.create_index(
              index_name=index_name,
              mappings=ES_INDEX_MAPPINGS,
              settings=ES_INDEX_SETTINGS,
          )
          logger.info(f"✅ 创建 ES 索引: {index_name} (响应: {res})")

      def _create_vector_db_table(self, dimension: Dict):
          """Create the vector-database table/index of a single dimension.

          The backend is selected by ``vector_db_settings.vector_db_type``.
          Creation errors are logged, not raised.

          Args:
              dimension: Dimension row (``id``, ``dataset_id`` and ``name`` are read).
          """
          db_type = vector_db_settings.vector_db_type
          table_name = self._get_table_name(dimension)
          try:
              if db_type == "infinity":
                  self._create_infinity_table(table_name)
                  logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dimension['name']})")
              elif db_type == "es":
                  self._create_es_index(table_name)
                  logger.info(f"✅ 创建 ES 索引: {table_name} (维度: {dimension['name']})")
              else:
                  # Mirror init_vector_db_tables: make a misconfiguration visible
                  # instead of silently creating nothing.
                  logger.warning(f"未知的向量数据库类型: {db_type},跳过创建: {table_name}")
          except Exception as e:
              logger.error(f"创建向量数据库表/索引失败: {e}")

      @staticmethod
      def _parse_preced_node(raw) -> List:
          """Decode a ``preced_node`` column value; NULL or empty yields ``[]``."""
          if raw is None or raw == "":
              return []
          return json.loads(raw)

      # ==================== Dimension management ====================

      def add_dimension(self, name: str, description: Optional[str] = None) -> Dict[str, Any]:
          """Add a dimension.

          Creates a RAGFlow dataset, stores the dimension together with the
          dataset id in MySQL, then creates the matching vector-database
          table/index (named by ``_get_table_name``).

          Args:
              name: Dimension name.
              description: Dimension description; the RAGFlow dataset falls back
                  to a generated description when it is empty.

          Returns:
              The newly created dimension row.
          """
          # Imported locally in the original code; kept that way.
          from src.utils.ragflow.ragflow_service import RAGFlowService
          from src.conf.rag_parser_config import RagParserDefaults
          from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service

          # 1. Create the RAGFlow dataset.
          # NOTE(review): the user id 3 is hard-coded here - confirm which account this is.
          ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(3)
          ragflow_api_key = ragflow_user.get("api_key")
          ragflow_service = RAGFlowService(api_key=ragflow_api_key)
          logger.info(f"开始创建 RAGFlow 数据集: {name}")
          dataset = ragflow_service.create_dataset(
              name=name,
              description=description or f"维度: {name}",
              permission=RagParserDefaults.DATASET_PERMISSION,
              chunk_method=RagParserDefaults.DATASET_CHUNK_METHOD,
              parser_config=RagParserDefaults.DATASET_CONFIG_DICT,
          )
          dataset_id = dataset['id']
          logger.info(f"RAGFlow 数据集创建成功,ID: {dataset_id}")

          # 2. Persist the dimension in MySQL together with its dataset_id.
          sql = """
              INSERT INTO prompt_dimensions (name, description, dataset_id)
              VALUES (%s, %s, %s)
          """
          self._db.execute(sql, [name, description, dataset_id])

          # 3. Read the new row back (looked up by name).
          dimension = self.get_dimension_by_name(name)

          # 4. Create the vector-database table/index for the new dimension.
          self._create_vector_db_table(dimension)
          return dimension

      def get_dimension_by_name(self, name: str) -> Optional[Dict[str, Any]]:
          """Return the dimension with the given name, if any."""
          sql = "SELECT * FROM prompt_dimensions WHERE name = %s"
          return self._db.fetch_one(sql, [name])

      def get_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
          """Return the dimension with the given id, if any."""
          sql = "SELECT * FROM prompt_dimensions WHERE id = %s"
          return self._db.fetch_one(sql, [dimension_id])

      def get_all_dimensions(self) -> List[Dict[str, Any]]:
          """Return all dimensions, newest first (ordered by ``created_at``)."""
          sql = "SELECT * FROM prompt_dimensions ORDER BY created_at DESC"
          return self._db.fetch_all(sql)

      def get_dataset_id_by_dimension_id(self, dimension_id: int) -> Optional[str]:
          """Return the RAGFlow dataset id stored for a dimension, or None."""
          sql = "SELECT dataset_id FROM prompt_dimensions WHERE id = %s"
          result = self._db.fetch_one(sql, [dimension_id])
          return result['dataset_id'] if result else None

      def get_decomposition_method(self) -> Dict[str, Dict[str, Any]]:
          """Return the decomposition settings of every dimension.

          Returns:
              Mapping of the stringified dimension id to a dict with the keys
              ``decomposition_method``, ``is_preced`` and ``preced_node``
              (JSON-decoded; ``[]`` when the column is NULL or empty).
          """
          sql = "SELECT id, decomposition_method, is_preced, preced_node FROM prompt_dimensions"
          rows = self._db.fetch_all(sql)
          return {
              str(row['id']): {
                  "decomposition_method": row['decomposition_method'],
                  "is_preced": row['is_preced'],
                  "preced_node": self._parse_preced_node(row['preced_node']),
              }
              for row in rows
          }

      def get_decomposition_method_by_dimension_id(self, dimension_id: int) -> Optional[str]:
          """Return the decomposition method stored for a dimension, or None."""
          sql = "SELECT decomposition_method FROM prompt_dimensions WHERE id = %s"
          result = self._db.fetch_one(sql, [dimension_id])
          return result['decomposition_method'] if result else None

      def get_preced_node_name(self, dimension_id: int) -> List[str]:
          """Return the JSON-decoded ``preced_node`` value of a dimension.

          Returns ``[]`` when the dimension does not exist or the column is
          NULL/empty (json.loads would otherwise raise on a NULL value).
          """
          sql = "SELECT preced_node FROM prompt_dimensions WHERE id = %s"
          result = self._db.fetch_one(sql, [dimension_id])
          return self._parse_preced_node(result['preced_node']) if result else []

      def update_dimension(
          self,
          dimension_id: int,
          name: Optional[str] = None,
          description: Optional[str] = None,
      ) -> int:
          """Update the name and/or description of a dimension.

          Only arguments that are not None are written.

          Returns:
              The result of the UPDATE statement, or 0 when nothing was requested.
          """
          updates = []
          params = []
          if name is not None:
              updates.append("name = %s")
              params.append(name)
          if description is not None:
              updates.append("description = %s")
              params.append(description)
          if not updates:
              return 0
          params.append(dimension_id)
          # Only fixed column fragments are interpolated; values stay parameterised.
          sql = f"UPDATE prompt_dimensions SET {', '.join(updates)} WHERE id = %s"
          return self._db.execute(sql, params)

      def delete_dimension(self, dimension_id: int) -> int:
          """Delete a dimension row.

          Only the ``prompt_dimensions`` row is deleted here.
          NOTE(review): removal of the dimension's versions presumably relies on
          a cascading foreign key in the schema - confirm.
          """
          sql = "DELETE FROM prompt_dimensions WHERE id = %s"
          return self._db.execute(sql, [dimension_id])

      # ==================== Version management ====================

      def add_version(
          self,
          dimension_id: int,
          content: str,
          remark: Optional[str] = None,
          created_by: Optional[str] = None,
          set_active: bool = True,
      ) -> Dict[str, Any]:
          """Add a new prompt version to a dimension.

          Args:
              dimension_id: Dimension id.
              content: Prompt text, stored as given.
              remark: Version remark.
              created_by: Creator.
              set_active: Whether the new version becomes the active one.

          Returns:
              The newly created version row.
          """
          # Next version number = current maximum (0 when none exist) + 1.
          # NOTE(review): the read and the insert are separate statements;
          # concurrent callers could compute the same number.
          max_version_sql = """
              SELECT COALESCE(MAX(version_number), 0) as max_version
              FROM prompt_versions WHERE dimension_id = %s
          """
          result = self._db.fetch_one(max_version_sql, [dimension_id])
          new_version = result['max_version'] + 1

          # Deactivate the current active version before inserting an active one.
          if set_active:
              self._db.execute(
                  "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
                  [dimension_id],
              )

          insert_sql = """
              INSERT INTO prompt_versions (dimension_id, version_number, content, is_active, remark, created_by)
              VALUES (%s, %s, %s, %s, %s, %s)
          """
          self._db.execute(insert_sql, [
              dimension_id, new_version, content,
              1 if set_active else 0, remark, created_by,
          ])
          return self.get_version(dimension_id, new_version)

      def get_version(self, dimension_id: int, version_number: int) -> Optional[Dict[str, Any]]:
          """Return one version of a dimension, if it exists."""
          sql = """
              SELECT * FROM prompt_versions
              WHERE dimension_id = %s AND version_number = %s
          """
          return self._db.fetch_one(sql, [dimension_id, version_number])

      def get_versions_by_dimension(self, dimension_id: int) -> List[Dict[str, Any]]:
          """Return all versions of a dimension, highest version number first."""
          sql = """
              SELECT * FROM prompt_versions
              WHERE dimension_id = %s
              ORDER BY version_number DESC
          """
          return self._db.fetch_all(sql, [dimension_id])

      def get_active_version(self, dimension_id: int) -> Optional[Dict[str, Any]]:
          """Return the active version row of a dimension, if any."""
          sql = """
              SELECT * FROM prompt_versions
              WHERE dimension_id = %s AND is_active = 1
          """
          return self._db.fetch_one(sql, [dimension_id])

      def get_active_prompt(self, dimension_name: str) -> Optional[str]:
          """Return the active prompt text of the dimension with the given name.

          Args:
              dimension_name: Dimension name.

          Returns:
              The prompt content, or None when there is no active version.
          """
          sql = """
              SELECT pv.content
              FROM prompt_versions pv
              JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
              WHERE pd.name = %s AND pv.is_active = 1
          """
          result = self._db.fetch_one(sql, [dimension_name])
          return result['content'] if result else None

      def get_active_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
          """Return the active prompt text and dataset id of a dimension.

          Args:
              dimension_id: Dimension id.

          Returns:
              A row with ``content`` and ``dataset_id``, or None when there is
              no active version.
          """
          sql = """
              SELECT pv.content, pd.dataset_id
              FROM prompt_versions pv
              JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
              WHERE pv.dimension_id = %s AND pv.is_active = 1
          """
          result = self._db.fetch_one(sql, [dimension_id])
          return result or None

      def get_active_prompt_by_id(self, dimension_id: int) -> Optional[str]:
          """Return the active prompt text of a dimension.

          Args:
              dimension_id: Dimension id.

          Returns:
              The prompt content, or None when there is no active version.
          """
          sql = """
              SELECT content FROM prompt_versions
              WHERE dimension_id = %s AND is_active = 1
          """
          result = self._db.fetch_one(sql, [dimension_id])
          return result['content'] if result else None

      def set_active_version(self, dimension_id: int, version_number: int) -> int:
          """Make the given version the active one of its dimension.

          Returns:
              The result of the activating UPDATE, or 0 when the requested
              version does not exist (the current active version is then kept).
          """
          # Check the target first: deactivating everything and then matching no
          # row would leave the dimension without any active version.
          if not self.get_version(dimension_id, version_number):
              return 0

          self._db.execute(
              "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
              [dimension_id],
          )
          sql = """
              UPDATE prompt_versions SET is_active = 1
              WHERE dimension_id = %s AND version_number = %s
          """
          return self._db.execute(sql, [dimension_id, version_number])

      def delete_version(self, dimension_id: int, version_number: int) -> int:
          """Delete one version of a dimension."""
          sql = """
              DELETE FROM prompt_versions
              WHERE dimension_id = %s AND version_number = %s
          """
          return self._db.execute(sql, [dimension_id, version_number])
  # Module-level service instance, created lazily by get_prompt_service().
  _prompt_service: Optional[PromptService] = None


  def get_prompt_service() -> PromptService:
      """Return the shared PromptService, instantiating it on the first call."""
      global _prompt_service
      if _prompt_service is not None:
          return _prompt_service
      _prompt_service = PromptService()
      return _prompt_service