""" 提示词管理服务 提供维度和提示词版本的 CRUD 操作。 服务初始化时会自动检查并创建对应的向量数据库表/索引。 """ from typing import List, Dict, Any, Optional from src.utils.mysql import get_global_mysql_client from src.utils.vector_db import get_vector_db_client from src.conf.settings import vector_db_settings, ragflow_settings from src.common.logging_config import get_logger logger = get_logger(__name__) # Infinity 表结构定义(用于 book_{dimension_id} 表) INFINITY_TABLE_COLUMNS = [ {"name": "id", "type": "varchar", "default": ""}, {"name": "file_name", "type": "varchar", "default": ""}, {"name": "page_number", "type": "int", "default": 0}, {"name": "content", "type": "varchar", "default": ""}, {"name": "image_path", "type": "varchar", "default": ""}, {"name": "document_id", "type": "varchar", "default": ""}, {"name": "chunk_id", "type": "varchar", "default": ""}, {"name": "ability_tags", "type": "varchar", "default": ""}, {"name": "content_tag", "type": "varchar", "default": ""}, {"name": "metadata", "type": "varchar", "default": ""}, {"name": "dense_vector_1024", "type": "vector,1024,float"}, ] # Elasticsearch 索引映射定义 ES_INDEX_MAPPINGS = { "properties": { "id": {"type": "keyword"}, "file_name": {"type": "keyword"}, "page_number": {"type": "integer"}, "content": {"type": "text", "analyzer": "ik_smart"}, "image_path": {"type": "keyword"}, "document_id": {"type": "keyword"}, "chunk_id": {"type": "keyword"}, "ability_tags": {"type": "keyword"}, "content_tag": {"type": "keyword"}, "metadata": {"type": "object"}, "dense_vector_1024": { "type": "dense_vector", "dims": 1024, "index": True, "similarity": "cosine" } } } # Elasticsearch 索引设置定义 ES_INDEX_SETTINGS = { "index": { "number_of_shards": 2, "number_of_replicas": 0 } } class PromptService: """提示词管理服务""" def __init__(self): """初始化提示词服务""" self._db = get_global_mysql_client() self._vector_client = None def _get_vector_client(self): """获取向量数据库客户端""" if self._vector_client is None: self._vector_client = get_vector_db_client() return self._vector_client def _get_table_name(self, dimension: Dict) -> str: """获取维度对应的表名/索引名""" return f"{ragflow_settings.custom_dataset_prefix}_{dimension['id']}_{dimension['dataset_id']}" def init_vector_db_tables(self): """ 初始化向量数据库表/索引 检查所有已存在的维度,如果对应的表/索引不存在则创建。 支持 Infinity 和 Elasticsearch。 此方法由 main.py 生命周期管理调用。 """ db_type = vector_db_settings.vector_db_type try: # 获取所有维度 dimensions = self.get_all_dimensions() if db_type == "infinity": self._init_infinity_tables(dimensions) elif db_type == "es": self._init_es_indices(dimensions) else: logger.warning(f"未知的向量数据库类型: {db_type},跳过初始化") except Exception as e: logger.error(f"初始化向量数据库表/索引失败: {str(e)}") def _init_infinity_tables(self, dimensions: List[Dict]): """初始化 Infinity 表""" try: client = self._get_vector_client() existing_tables = client.list_tables() logger.info(f"Infinity 现有表: {existing_tables}") for dim in dimensions: table_name = self._get_table_name(dim['id']) if table_name not in existing_tables: self._create_infinity_table(table_name) logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dim['name']})") except Exception as e: logger.error(f"初始化 Infinity 表失败: {str(e)}") def _init_es_indices(self, dimensions: List[Dict]): """初始化 Elasticsearch 索引""" try: client = self._get_vector_client() for dim in dimensions: index_name = self._get_table_name(dim) if not client.index_exists(index_name): self._create_es_index(index_name) logger.info(f"✅ 创建 ES 索引: {index_name} (维度: {dim['name']})") else: logger.debug(f"ES 索引已存在: {index_name}") except Exception as e: logger.error(f"初始化 ES 索引失败: {str(e)}") def _create_infinity_table(self, table_name: str): """创建 Infinity 表""" client = self._get_vector_client() client.create_table( table_name=table_name, columns_definition=INFINITY_TABLE_COLUMNS ) def _create_es_index(self, index_name: str): """创建 Elasticsearch 索引""" client = self._get_vector_client() res = client.create_index( index_name=index_name, mappings=ES_INDEX_MAPPINGS, settings=ES_INDEX_SETTINGS ) logger.info(f"✅ 创建 ES 索引: {index_name} (响应: {res})") def _create_vector_db_table(self, dimension: Dict): """ 为维度创建向量数据库表/索引 根据配置自动选择 Infinity 或 Elasticsearch。 """ db_type = vector_db_settings.vector_db_type # table_name = self._get_table_name(dimension_id) table_name = self._get_table_name(dimension) try: if db_type == "infinity": self._create_infinity_table(table_name) logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dimension['name']})") elif db_type == "es": self._create_es_index(table_name) logger.info(f"✅ 创建 ES 索引: {table_name} (维度: {dimension['name']})") except Exception as e: logger.error(f"创建向量数据库表/索引失败: {str(e)}") # ==================== 维度管理 ==================== def add_dimension(self, name: str, description: str = None) -> Dict[str, Any]: """ 添加维度 创建维度后会自动创建对应的向量数据库表/索引 (book_{dataset_id})。 Args: name: 维度名称 description: 维度描述 Returns: 新建的维度信息 """ # 1. 先创建 RAGFlow 数据集 from src.utils.ragflow.ragflow_service import RAGFlowService from src.conf.rag_parser_config import RagParserDefaults ragflow_service = RAGFlowService(api_key="ragflow-XelVBvv8Uc6dZLNb1aBIKdbsupucEjESotOPTZZBrG4") logger.info(f"开始创建 RAGFlow 数据集: {name}") dataset = ragflow_service.create_dataset( name=name, description=description or f"维度: {name}", permission=RagParserDefaults.DATASET_PERMISSION, chunk_method=RagParserDefaults.DATASET_CHUNK_METHOD, parser_config=RagParserDefaults.DATASET_CONFIG_DICT ) dataset_id = dataset['id'] logger.info(f"RAGFlow 数据集创建成功,ID: {dataset_id}") # 2. 将维度信息存入 MySQL,同时保存 dataset_id sql = """ INSERT INTO prompt_dimensions (name, description, dataset_id) VALUES (%s, %s, %s) """ self._db.execute(sql, [name, description, dataset_id]) # 3. 获取新建的维度 dimension = self.get_dimension_by_name(name) # 2. 创建对应的向量数据库表/索引(使用 dataset_id 作为维度ID) self._create_vector_db_table(dimension) return dimension def get_dimension_by_name(self, name: str) -> Optional[Dict[str, Any]]: """根据名称获取维度""" sql = "SELECT * FROM prompt_dimensions WHERE name = %s" return self._db.fetch_one(sql, [name]) def get_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]: """根据ID获取维度""" sql = "SELECT * FROM prompt_dimensions WHERE id = %s" return self._db.fetch_one(sql, [dimension_id]) def get_all_dimensions(self) -> List[Dict[str, Any]]: """获取所有维度""" sql = "SELECT * FROM prompt_dimensions ORDER BY created_at DESC" return self._db.fetch_all(sql) def get_dataset_id_by_dimension_id(self, dimension_id: int) -> Optional[str]: """根据维度ID获取对应的RAGFlow数据集ID""" sql = "SELECT dataset_id FROM prompt_dimensions WHERE id = %s" result = self._db.fetch_one(sql, [dimension_id]) return result['dataset_id'] if result else None def update_dimension(self, dimension_id: int, name: str = None, description: str = None) -> int: """更新维度信息""" updates = [] params = [] if name is not None: updates.append("name = %s") params.append(name) if description is not None: updates.append("description = %s") params.append(description) if not updates: return 0 params.append(dimension_id) sql = f"UPDATE prompt_dimensions SET {', '.join(updates)} WHERE id = %s" return self._db.execute(sql, params) def delete_dimension(self, dimension_id: int) -> int: """删除维度(级联删除所有版本)""" sql = "DELETE FROM prompt_dimensions WHERE id = %s" return self._db.execute(sql, [dimension_id]) # ==================== 版本管理 ==================== def add_version( self, dimension_id: int, content: str, remark: str = None, created_by: str = None, set_active: bool = True ) -> Dict[str, Any]: """ 添加新版本 Args: dimension_id: 维度ID content: 提示词内容(原始格式保留) remark: 版本备注 created_by: 创建人 set_active: 是否设为激活版本 Returns: 新建的版本信息 """ # 获取当前最大版本号 max_version_sql = """ SELECT COALESCE(MAX(version_number), 0) as max_version FROM prompt_versions WHERE dimension_id = %s """ result = self._db.fetch_one(max_version_sql, [dimension_id]) new_version = result['max_version'] + 1 # 如果设为激活版本,先取消当前激活版本 if set_active: self._db.execute( "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s", [dimension_id] ) # 插入新版本 insert_sql = """ INSERT INTO prompt_versions (dimension_id, version_number, content, is_active, remark, created_by) VALUES (%s, %s, %s, %s, %s, %s) """ self._db.execute(insert_sql, [ dimension_id, new_version, content, 1 if set_active else 0, remark, created_by ]) return self.get_version(dimension_id, new_version) def get_version(self, dimension_id: int, version_number: int) -> Optional[Dict[str, Any]]: """获取指定版本""" sql = """ SELECT * FROM prompt_versions WHERE dimension_id = %s AND version_number = %s """ return self._db.fetch_one(sql, [dimension_id, version_number]) def get_versions_by_dimension(self, dimension_id: int) -> List[Dict[str, Any]]: """获取维度的所有版本""" sql = """ SELECT * FROM prompt_versions WHERE dimension_id = %s ORDER BY version_number DESC """ return self._db.fetch_all(sql, [dimension_id]) def get_active_version(self, dimension_id: int) -> Optional[Dict[str, Any]]: """获取当前激活版本""" sql = """ SELECT * FROM prompt_versions WHERE dimension_id = %s AND is_active = 1 """ return self._db.fetch_one(sql, [dimension_id]) def get_active_prompt(self, dimension_name: str) -> Optional[str]: """ 根据维度名称获取当前激活的提示词内容 这是最常用的方法,用于在工作流中获取提示词。 Args: dimension_name: 维度名称 Returns: 提示词内容,若不存在则返回 None """ sql = """ SELECT pv.content FROM prompt_versions pv JOIN prompt_dimensions pd ON pv.dimension_id = pd.id WHERE pd.name = %s AND pv.is_active = 1 """ result = self._db.fetch_one(sql, [dimension_name]) return result['content'] if result else None def get_active_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]: """ 根据维度ID获取当前激活的提示词内容和数据集ID Args: dimension_id: 维度ID Returns: 包含提示词内容和数据集ID的字典,若不存在则返回 None """ sql = """ SELECT pv.content, pd.dataset_id FROM prompt_versions pv JOIN prompt_dimensions pd ON pv.dimension_id = pd.id WHERE pv.dimension_id = %s AND pv.is_active = 1 """ result = self._db.fetch_one(sql, [dimension_id]) return result if result else None def get_active_prompt_by_id(self, dimension_id: int) -> Optional[str]: """ 根据维度ID获取当前激活的提示词内容 Args: dimension_id: 维度ID Returns: 提示词内容,若不存在则返回 None """ sql = """ SELECT content FROM prompt_versions WHERE dimension_id = %s AND is_active = 1 """ result = self._db.fetch_one(sql, [dimension_id]) return result['content'] if result else None def set_active_version(self, dimension_id: int, version_number: int) -> int: """设置激活版本""" # 先取消当前激活版本 self._db.execute( "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s", [dimension_id] ) # 设置新的激活版本 sql = """ UPDATE prompt_versions SET is_active = 1 WHERE dimension_id = %s AND version_number = %s """ return self._db.execute(sql, [dimension_id, version_number]) def delete_version(self, dimension_id: int, version_number: int) -> int: """删除版本""" sql = """ DELETE FROM prompt_versions WHERE dimension_id = %s AND version_number = %s """ return self._db.execute(sql, [dimension_id, version_number]) # 全局服务实例 _prompt_service: Optional[PromptService] = None def get_prompt_service() -> PromptService: """获取提示词服务实例""" global _prompt_service if _prompt_service is None: _prompt_service = PromptService() return _prompt_service