prompt_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """
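Prompt management service.

Provides CRUD operations for prompt dimensions and prompt versions.
The matching vector-database tables/indices are checked and created by
``PromptService.init_vector_db_tables`` (documented there as being invoked
from the main.py lifecycle) and whenever a new dimension is added.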
  5. """
  6. from typing import List, Dict, Any, Optional
  7. from src.utils.mysql import get_global_mysql_client
  8. from src.utils.vector_db import get_vector_db_client
  9. from src.conf.settings import vector_db_settings, ragflow_settings
  10. from src.common.logging_config import get_logger
  11. logger = get_logger(__name__)
  12. # Infinity 表结构定义(用于 book_{dimension_id} 表)
  13. INFINITY_TABLE_COLUMNS = [
  14. {"name": "id", "type": "varchar", "default": ""},
  15. {"name": "file_name", "type": "varchar", "default": ""},
  16. {"name": "page_number", "type": "int", "default": 0},
  17. {"name": "content", "type": "varchar", "default": ""},
  18. {"name": "image_path", "type": "varchar", "default": ""},
  19. {"name": "document_id", "type": "varchar", "default": ""},
  20. {"name": "chunk_id", "type": "varchar", "default": ""},
  21. {"name": "ability_tags", "type": "varchar", "default": ""},
  22. {"name": "content_tag", "type": "varchar", "default": ""},
  23. {"name": "metadata", "type": "varchar", "default": ""},
  24. {"name": "dense_vector_1024", "type": "vector,1024,float"},
  25. ]
  26. # Elasticsearch 索引映射定义
  27. ES_INDEX_MAPPINGS = {
  28. "properties": {
  29. "id": {"type": "keyword"},
  30. "file_name": {"type": "keyword"},
  31. "page_number": {"type": "integer"},
  32. "content": {"type": "text", "analyzer": "ik_smart"},
  33. "image_path": {"type": "keyword"},
  34. "document_id": {"type": "keyword"},
  35. "chunk_id": {"type": "keyword"},
  36. "ability_tags": {"type": "keyword"},
  37. "content_tag": {"type": "keyword"},
  38. "metadata": {"type": "object"},
  39. "dense_vector_1024": {
  40. "type": "dense_vector",
  41. "dims": 1024,
  42. "index": True,
  43. "similarity": "cosine"
  44. }
  45. }
  46. }
  47. # Elasticsearch 索引设置定义
  48. ES_INDEX_SETTINGS = {
  49. "index": {
  50. "number_of_shards": 2,
  51. "number_of_replicas": 0
  52. }
  53. }
  54. class PromptService:
  55. """提示词管理服务"""
  56. def __init__(self):
  57. """初始化提示词服务"""
  58. self._db = get_global_mysql_client()
  59. self._vector_client = None
  60. def _get_vector_client(self):
  61. """获取向量数据库客户端"""
  62. if self._vector_client is None:
  63. self._vector_client = get_vector_db_client()
  64. return self._vector_client
  65. def _get_table_name(self, dimension: Dict) -> str:
  66. """获取维度对应的表名/索引名"""
  67. return f"{ragflow_settings.custom_dataset_prefix}_{dimension['id']}_{dimension['dataset_id']}"
  68. def init_vector_db_tables(self):
  69. """
  70. 初始化向量数据库表/索引
  71. 检查所有已存在的维度,如果对应的表/索引不存在则创建。
  72. 支持 Infinity 和 Elasticsearch。
  73. 此方法由 main.py 生命周期管理调用。
  74. """
  75. db_type = vector_db_settings.vector_db_type
  76. try:
  77. # 获取所有维度
  78. dimensions = self.get_all_dimensions()
  79. if db_type == "infinity":
  80. self._init_infinity_tables(dimensions)
  81. elif db_type == "es":
  82. self._init_es_indices(dimensions)
  83. else:
  84. logger.warning(f"未知的向量数据库类型: {db_type},跳过初始化")
  85. except Exception as e:
  86. logger.error(f"初始化向量数据库表/索引失败: {str(e)}")
  87. def _init_infinity_tables(self, dimensions: List[Dict]):
  88. """初始化 Infinity 表"""
  89. try:
  90. client = self._get_vector_client()
  91. existing_tables = client.list_tables()
  92. logger.info(f"Infinity 现有表: {existing_tables}")
  93. for dim in dimensions:
  94. table_name = self._get_table_name(dim['id'])
  95. if table_name not in existing_tables:
  96. self._create_infinity_table(table_name)
  97. logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dim['name']})")
  98. except Exception as e:
  99. logger.error(f"初始化 Infinity 表失败: {str(e)}")
  100. def _init_es_indices(self, dimensions: List[Dict]):
  101. """初始化 Elasticsearch 索引"""
  102. try:
  103. client = self._get_vector_client()
  104. for dim in dimensions:
  105. index_name = self._get_table_name(dim)
  106. if not client.index_exists(index_name):
  107. self._create_es_index(index_name)
  108. logger.info(f"✅ 创建 ES 索引: {index_name} (维度: {dim['name']})")
  109. else:
  110. logger.debug(f"ES 索引已存在: {index_name}")
  111. except Exception as e:
  112. logger.error(f"初始化 ES 索引失败: {str(e)}")
  113. def _create_infinity_table(self, table_name: str):
  114. """创建 Infinity 表"""
  115. client = self._get_vector_client()
  116. client.create_table(
  117. table_name=table_name,
  118. columns_definition=INFINITY_TABLE_COLUMNS
  119. )
  120. def _create_es_index(self, index_name: str):
  121. """创建 Elasticsearch 索引"""
  122. client = self._get_vector_client()
  123. res = client.create_index(
  124. index_name=index_name,
  125. mappings=ES_INDEX_MAPPINGS,
  126. settings=ES_INDEX_SETTINGS
  127. )
  128. logger.info(f"✅ 创建 ES 索引: {index_name} (响应: {res})")
  129. def _create_vector_db_table(self, dimension: Dict):
  130. """
  131. 为维度创建向量数据库表/索引
  132. 根据配置自动选择 Infinity 或 Elasticsearch。
  133. """
  134. db_type = vector_db_settings.vector_db_type
  135. # table_name = self._get_table_name(dimension_id)
  136. table_name = self._get_table_name(dimension)
  137. try:
  138. if db_type == "infinity":
  139. self._create_infinity_table(table_name)
  140. logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dimension['name']})")
  141. elif db_type == "es":
  142. self._create_es_index(table_name)
  143. logger.info(f"✅ 创建 ES 索引: {table_name} (维度: {dimension['name']})")
  144. except Exception as e:
  145. logger.error(f"创建向量数据库表/索引失败: {str(e)}")
  146. # ==================== 维度管理 ====================
  147. def add_dimension(self, name: str, description: str = None) -> Dict[str, Any]:
  148. """
  149. 添加维度
  150. 创建维度后会自动创建对应的向量数据库表/索引 (book_{dataset_id})。
  151. Args:
  152. name: 维度名称
  153. description: 维度描述
  154. Returns:
  155. 新建的维度信息
  156. """
  157. # 1. 先创建 RAGFlow 数据集
  158. from src.utils.ragflow.ragflow_service import RAGFlowService
  159. from src.conf.rag_parser_config import RagParserDefaults
  160. ragflow_service = RAGFlowService(api_key="ragflow-XelVBvv8Uc6dZLNb1aBIKdbsupucEjESotOPTZZBrG4")
  161. logger.info(f"开始创建 RAGFlow 数据集: {name}")
  162. dataset = ragflow_service.create_dataset(
  163. name=name,
  164. description=description or f"维度: {name}",
  165. permission=RagParserDefaults.DATASET_PERMISSION,
  166. chunk_method=RagParserDefaults.DATASET_CHUNK_METHOD,
  167. parser_config=RagParserDefaults.DATASET_CONFIG_DICT
  168. )
  169. dataset_id = dataset['id']
  170. logger.info(f"RAGFlow 数据集创建成功,ID: {dataset_id}")
  171. # 2. 将维度信息存入 MySQL,同时保存 dataset_id
  172. sql = """
  173. INSERT INTO prompt_dimensions (name, description, dataset_id)
  174. VALUES (%s, %s, %s)
  175. """
  176. self._db.execute(sql, [name, description, dataset_id])
  177. # 3. 获取新建的维度
  178. dimension = self.get_dimension_by_name(name)
  179. # 2. 创建对应的向量数据库表/索引(使用 dataset_id 作为维度ID)
  180. self._create_vector_db_table(dimension)
  181. return dimension
  182. def get_dimension_by_name(self, name: str) -> Optional[Dict[str, Any]]:
  183. """根据名称获取维度"""
  184. sql = "SELECT * FROM prompt_dimensions WHERE name = %s"
  185. return self._db.fetch_one(sql, [name])
  186. def get_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
  187. """根据ID获取维度"""
  188. sql = "SELECT * FROM prompt_dimensions WHERE id = %s"
  189. return self._db.fetch_one(sql, [dimension_id])
  190. def get_all_dimensions(self) -> List[Dict[str, Any]]:
  191. """获取所有维度"""
  192. sql = "SELECT * FROM prompt_dimensions ORDER BY created_at DESC"
  193. return self._db.fetch_all(sql)
  194. def get_dataset_id_by_dimension_id(self, dimension_id: int) -> Optional[str]:
  195. """根据维度ID获取对应的RAGFlow数据集ID"""
  196. sql = "SELECT dataset_id FROM prompt_dimensions WHERE id = %s"
  197. result = self._db.fetch_one(sql, [dimension_id])
  198. return result['dataset_id'] if result else None
  199. def update_dimension(self, dimension_id: int, name: str = None, description: str = None) -> int:
  200. """更新维度信息"""
  201. updates = []
  202. params = []
  203. if name is not None:
  204. updates.append("name = %s")
  205. params.append(name)
  206. if description is not None:
  207. updates.append("description = %s")
  208. params.append(description)
  209. if not updates:
  210. return 0
  211. params.append(dimension_id)
  212. sql = f"UPDATE prompt_dimensions SET {', '.join(updates)} WHERE id = %s"
  213. return self._db.execute(sql, params)
  214. def delete_dimension(self, dimension_id: int) -> int:
  215. """删除维度(级联删除所有版本)"""
  216. sql = "DELETE FROM prompt_dimensions WHERE id = %s"
  217. return self._db.execute(sql, [dimension_id])
  218. # ==================== 版本管理 ====================
  219. def add_version(
  220. self,
  221. dimension_id: int,
  222. content: str,
  223. remark: str = None,
  224. created_by: str = None,
  225. set_active: bool = True
  226. ) -> Dict[str, Any]:
  227. """
  228. 添加新版本
  229. Args:
  230. dimension_id: 维度ID
  231. content: 提示词内容(原始格式保留)
  232. remark: 版本备注
  233. created_by: 创建人
  234. set_active: 是否设为激活版本
  235. Returns:
  236. 新建的版本信息
  237. """
  238. # 获取当前最大版本号
  239. max_version_sql = """
  240. SELECT COALESCE(MAX(version_number), 0) as max_version
  241. FROM prompt_versions WHERE dimension_id = %s
  242. """
  243. result = self._db.fetch_one(max_version_sql, [dimension_id])
  244. new_version = result['max_version'] + 1
  245. # 如果设为激活版本,先取消当前激活版本
  246. if set_active:
  247. self._db.execute(
  248. "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
  249. [dimension_id]
  250. )
  251. # 插入新版本
  252. insert_sql = """
  253. INSERT INTO prompt_versions (dimension_id, version_number, content, is_active, remark, created_by)
  254. VALUES (%s, %s, %s, %s, %s, %s)
  255. """
  256. self._db.execute(insert_sql, [
  257. dimension_id, new_version, content,
  258. 1 if set_active else 0, remark, created_by
  259. ])
  260. return self.get_version(dimension_id, new_version)
  261. def get_version(self, dimension_id: int, version_number: int) -> Optional[Dict[str, Any]]:
  262. """获取指定版本"""
  263. sql = """
  264. SELECT * FROM prompt_versions
  265. WHERE dimension_id = %s AND version_number = %s
  266. """
  267. return self._db.fetch_one(sql, [dimension_id, version_number])
  268. def get_versions_by_dimension(self, dimension_id: int) -> List[Dict[str, Any]]:
  269. """获取维度的所有版本"""
  270. sql = """
  271. SELECT * FROM prompt_versions
  272. WHERE dimension_id = %s
  273. ORDER BY version_number DESC
  274. """
  275. return self._db.fetch_all(sql, [dimension_id])
  276. def get_active_version(self, dimension_id: int) -> Optional[Dict[str, Any]]:
  277. """获取当前激活版本"""
  278. sql = """
  279. SELECT * FROM prompt_versions
  280. WHERE dimension_id = %s AND is_active = 1
  281. """
  282. return self._db.fetch_one(sql, [dimension_id])
  283. def get_active_prompt(self, dimension_name: str) -> Optional[str]:
  284. """
  285. 根据维度名称获取当前激活的提示词内容
  286. 这是最常用的方法,用于在工作流中获取提示词。
  287. Args:
  288. dimension_name: 维度名称
  289. Returns:
  290. 提示词内容,若不存在则返回 None
  291. """
  292. sql = """
  293. SELECT pv.content
  294. FROM prompt_versions pv
  295. JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
  296. WHERE pd.name = %s AND pv.is_active = 1
  297. """
  298. result = self._db.fetch_one(sql, [dimension_name])
  299. return result['content'] if result else None
  300. def get_active_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
  301. """
  302. 根据维度ID获取当前激活的提示词内容和数据集ID
  303. Args:
  304. dimension_id: 维度ID
  305. Returns:
  306. 包含提示词内容和数据集ID的字典,若不存在则返回 None
  307. """
  308. sql = """
  309. SELECT pv.content, pd.dataset_id
  310. FROM prompt_versions pv
  311. JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
  312. WHERE pv.dimension_id = %s AND pv.is_active = 1
  313. """
  314. result = self._db.fetch_one(sql, [dimension_id])
  315. return result if result else None
  316. def get_active_prompt_by_id(self, dimension_id: int) -> Optional[str]:
  317. """
  318. 根据维度ID获取当前激活的提示词内容
  319. Args:
  320. dimension_id: 维度ID
  321. Returns:
  322. 提示词内容,若不存在则返回 None
  323. """
  324. sql = """
  325. SELECT content FROM prompt_versions
  326. WHERE dimension_id = %s AND is_active = 1
  327. """
  328. result = self._db.fetch_one(sql, [dimension_id])
  329. return result['content'] if result else None
  330. def set_active_version(self, dimension_id: int, version_number: int) -> int:
  331. """设置激活版本"""
  332. # 先取消当前激活版本
  333. self._db.execute(
  334. "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
  335. [dimension_id]
  336. )
  337. # 设置新的激活版本
  338. sql = """
  339. UPDATE prompt_versions SET is_active = 1
  340. WHERE dimension_id = %s AND version_number = %s
  341. """
  342. return self._db.execute(sql, [dimension_id, version_number])
  343. def delete_version(self, dimension_id: int, version_number: int) -> int:
  344. """删除版本"""
  345. sql = """
  346. DELETE FROM prompt_versions
  347. WHERE dimension_id = %s AND version_number = %s
  348. """
  349. return self._db.execute(sql, [dimension_id, version_number])
  350. # 全局服务实例
  351. _prompt_service: Optional[PromptService] = None
  352. def get_prompt_service() -> PromptService:
  353. """获取提示词服务实例"""
  354. global _prompt_service
  355. if _prompt_service is None:
  356. _prompt_service = PromptService()
  357. return _prompt_service