| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- """
- 提示词管理服务
- 提供维度和提示词版本的 CRUD 操作。
- 服务初始化时会自动检查并创建对应的向量数据库表/索引。
- """
- from typing import List, Dict, Any, Optional
- from src.utils.mysql import get_global_mysql_client
- from src.utils.vector_db import get_vector_db_client
- from src.conf.settings import vector_db_settings, ragflow_settings
- from src.common.logging_config import get_logger
- logger = get_logger(__name__)
# Column definitions handed to Infinity's create_table for each dimension's table.
# Every scalar column declares a default value.
INFINITY_TABLE_COLUMNS = [
    {"name": column, "type": kind, "default": fallback}
    for column, kind, fallback in (
        ("id", "varchar", ""),
        ("file_name", "varchar", ""),
        ("page_number", "int", 0),
        ("content", "varchar", ""),
        ("image_path", "varchar", ""),
        ("document_id", "varchar", ""),
        ("chunk_id", "varchar", ""),
        ("ability_tags", "varchar", ""),
        ("content_tag", "varchar", ""),
        ("metadata", "varchar", ""),
    )
]
# The embedding column is declared last and has no "default" key.
INFINITY_TABLE_COLUMNS.append(
    {"name": "dense_vector_1024", "type": "vector,1024,float"}
)
# Field mappings passed to create_index for each dimension's Elasticsearch index.
# Field order matches the Infinity column list.
ES_INDEX_MAPPINGS = {
    "properties": {
        **{field: {"type": "keyword"} for field in ("id", "file_name")},
        "page_number": {"type": "integer"},
        # Full-text field, analyzed with the "ik_smart" analyzer.
        "content": {"type": "text", "analyzer": "ik_smart"},
        **{
            field: {"type": "keyword"}
            for field in (
                "image_path",
                "document_id",
                "chunk_id",
                "ability_tags",
                "content_tag",
            )
        },
        "metadata": {"type": "object"},
        # 1024-dimensional embedding, indexed with cosine similarity.
        "dense_vector_1024": {
            "type": "dense_vector",
            "dims": 1024,
            "index": True,
            "similarity": "cosine",
        },
    }
}
# Index-level settings passed to create_index together with the mappings:
# two primary shards and no replicas.
ES_INDEX_SETTINGS = {
    "index": {"number_of_shards": 2, "number_of_replicas": 0},
}
class PromptService:
    """Manage prompt dimensions and their versioned prompt texts.

    Dimension and version rows are read and written through the global MySQL
    client. Each dimension is paired with one vector-database table (Infinity)
    or index (Elasticsearch) whose name is produced by ``_get_table_name``.
    """

    def __init__(self):
        """Bind the global MySQL client; the vector client is created lazily."""
        self._db = get_global_mysql_client()
        self._vector_client = None

    def _get_vector_client(self):
        """Return the vector-database client, creating it on first use."""
        if self._vector_client is None:
            self._vector_client = get_vector_db_client()
        return self._vector_client

    def _get_table_name(self, dimension: Dict) -> str:
        """Build the table/index name for a dimension row.

        Args:
            dimension: Dimension row; must contain ``id`` and ``dataset_id``.

        Returns:
            ``{custom_dataset_prefix}_{id}_{dataset_id}``.
        """
        return f"{ragflow_settings.custom_dataset_prefix}_{dimension['id']}_{dimension['dataset_id']}"

    def init_vector_db_tables(self):
        """Create any missing vector-database table/index for existing dimensions.

        Loads every dimension and dispatches on ``vector_db_settings.vector_db_type``
        (``"infinity"`` or ``"es"``). Failures are logged, never raised, so a
        broken vector store does not abort the caller. Per the original note,
        this is invoked from the ``main.py`` lifecycle hook.
        """
        db_type = vector_db_settings.vector_db_type

        try:
            dimensions = self.get_all_dimensions()

            if db_type == "infinity":
                self._init_infinity_tables(dimensions)
            elif db_type == "es":
                self._init_es_indices(dimensions)
            else:
                logger.warning(f"未知的向量数据库类型: {db_type},跳过初始化")
        except Exception as e:
            logger.error(f"初始化向量数据库表/索引失败: {str(e)}")

    def _init_infinity_tables(self, dimensions: List[Dict]):
        """Create the Infinity table of every dimension that does not have one yet.

        Args:
            dimensions: Dimension rows (each needs ``id``, ``dataset_id``, ``name``).
        """
        try:
            client = self._get_vector_client()
            existing_tables = client.list_tables()
            logger.info(f"Infinity 现有表: {existing_tables}")

            for dim in dimensions:
                # Pass the whole row: the name is derived from both ``id`` and
                # ``dataset_id``. Passing only ``dim['id']`` raised a TypeError
                # that was swallowed below, so no table was ever created.
                table_name = self._get_table_name(dim)
                if table_name not in existing_tables:
                    self._create_infinity_table(table_name)
                    logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dim['name']})")
                else:
                    logger.debug(f"Infinity 表已存在: {table_name}")
        except Exception as e:
            logger.error(f"初始化 Infinity 表失败: {str(e)}")

    def _init_es_indices(self, dimensions: List[Dict]):
        """Create the Elasticsearch index of every dimension that does not have one yet.

        Args:
            dimensions: Dimension rows (each needs ``id``, ``dataset_id``, ``name``).
        """
        try:
            client = self._get_vector_client()

            for dim in dimensions:
                index_name = self._get_table_name(dim)
                if not client.index_exists(index_name):
                    self._create_es_index(index_name)
                    logger.info(f"✅ 创建 ES 索引: {index_name} (维度: {dim['name']})")
                else:
                    logger.debug(f"ES 索引已存在: {index_name}")
        except Exception as e:
            logger.error(f"初始化 ES 索引失败: {str(e)}")

    def _create_infinity_table(self, table_name: str):
        """Create one Infinity table using ``INFINITY_TABLE_COLUMNS``."""
        client = self._get_vector_client()
        client.create_table(
            table_name=table_name,
            columns_definition=INFINITY_TABLE_COLUMNS,
        )

    def _create_es_index(self, index_name: str):
        """Create one Elasticsearch index and log the client's response."""
        client = self._get_vector_client()
        res = client.create_index(
            index_name=index_name,
            mappings=ES_INDEX_MAPPINGS,
            settings=ES_INDEX_SETTINGS,
        )
        logger.info(f"✅ 创建 ES 索引: {index_name} (响应: {res})")

    def _create_vector_db_table(self, dimension: Dict):
        """Create the vector table/index for a single dimension.

        The backend is chosen from ``vector_db_settings.vector_db_type``.
        Creation errors are logged rather than raised; an unknown backend type
        is reported with a warning instead of being ignored silently.

        Args:
            dimension: Dimension row (needs ``id``, ``dataset_id``, ``name``).
        """
        db_type = vector_db_settings.vector_db_type
        table_name = self._get_table_name(dimension)
        try:
            if db_type == "infinity":
                self._create_infinity_table(table_name)
                logger.info(f"✅ 创建 Infinity 表: {table_name} (维度: {dimension['name']})")
            elif db_type == "es":
                self._create_es_index(table_name)
                logger.info(f"✅ 创建 ES 索引: {table_name} (维度: {dimension['name']})")
            else:
                logger.warning(f"未知的向量数据库类型: {db_type},跳过创建")
        except Exception as e:
            logger.error(f"创建向量数据库表/索引失败: {str(e)}")

    # ==================== Dimension management ====================

    def add_dimension(self, name: str, description: Optional[str] = None) -> Dict[str, Any]:
        """Create a dimension together with its RAGFlow dataset and vector table.

        Steps: create the RAGFlow dataset, insert the dimension row (including
        the new ``dataset_id``), read the row back by name, then create the
        vector table/index named by ``_get_table_name``.

        Args:
            name: Dimension name.
            description: Optional description; RAGFlow receives ``维度: {name}``
                when it is empty.

        Returns:
            The newly stored dimension row.

        Raises:
            RuntimeError: If the row cannot be read back after the insert.
        """
        # Imported locally, as in the original; presumably to avoid import cycles.
        from src.utils.ragflow.ragflow_service import RAGFlowService
        from src.conf.rag_parser_config import RagParserDefaults
        from src.utils.ragflow.ragflow_user_service import get_ragflow_user_service

        # 1. Create the RAGFlow dataset.
        # NOTE(review): the user id 3 is hard-coded; presumably a service
        # account. Confirm and consider moving it to settings.
        ragflow_user = get_ragflow_user_service().get_ragflow_id_and_api_key(3)
        ragflow_api_key = ragflow_user.get("api_key")
        ragflow_service = RAGFlowService(api_key=ragflow_api_key)
        logger.info(f"开始创建 RAGFlow 数据集: {name}")

        dataset = ragflow_service.create_dataset(
            name=name,
            description=description or f"维度: {name}",
            permission=RagParserDefaults.DATASET_PERMISSION,
            chunk_method=RagParserDefaults.DATASET_CHUNK_METHOD,
            parser_config=RagParserDefaults.DATASET_CONFIG_DICT,
        )

        dataset_id = dataset['id']
        logger.info(f"RAGFlow 数据集创建成功,ID: {dataset_id}")

        # 2. Persist the dimension in MySQL together with its dataset_id.
        sql = """
            INSERT INTO prompt_dimensions (name, description, dataset_id)
            VALUES (%s, %s, %s)
        """
        self._db.execute(sql, [name, description, dataset_id])

        # 3. Read the new row back.
        # NOTE(review): the lookup is by name, so it assumes names are unique
        # in prompt_dimensions. Confirm against the schema.
        dimension = self.get_dimension_by_name(name)
        if dimension is None:
            # Fail with a clear message instead of a TypeError further down.
            raise RuntimeError(f"维度创建后未能查询到记录: {name}")

        # 4. Create the matching vector table/index (best effort, errors are logged).
        self._create_vector_db_table(dimension)

        return dimension

    def get_dimension_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the dimension row with the given name, or the client's "no row" value."""
        sql = "SELECT * FROM prompt_dimensions WHERE name = %s"
        return self._db.fetch_one(sql, [name])

    def get_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
        """Return the dimension row with the given id, or the client's "no row" value."""
        sql = "SELECT * FROM prompt_dimensions WHERE id = %s"
        return self._db.fetch_one(sql, [dimension_id])

    def get_all_dimensions(self) -> List[Dict[str, Any]]:
        """Return all dimension rows, newest ``created_at`` first."""
        sql = "SELECT * FROM prompt_dimensions ORDER BY created_at DESC"
        return self._db.fetch_all(sql)

    def get_dataset_id_by_dimension_id(self, dimension_id: int) -> Optional[str]:
        """Return the RAGFlow ``dataset_id`` stored for a dimension, or ``None``."""
        sql = "SELECT dataset_id FROM prompt_dimensions WHERE id = %s"
        result = self._db.fetch_one(sql, [dimension_id])
        return result['dataset_id'] if result else None

    def update_dimension(
        self,
        dimension_id: int,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> int:
        """Update the name and/or description of a dimension.

        Only arguments that are not ``None`` are written.

        Args:
            dimension_id: Id of the dimension to update.
            name: New name, or ``None`` to leave it unchanged.
            description: New description, or ``None`` to leave it unchanged.

        Returns:
            ``0`` when nothing was requested; otherwise the result of
            ``_db.execute`` (presumably the affected row count).
        """
        requested = (("name", name), ("description", description))
        changes = [(column, value) for column, value in requested if value is not None]
        if not changes:
            return 0

        # Column names come from the fixed tuple above; only values are bound.
        assignments = ', '.join(f"{column} = %s" for column, _ in changes)
        params = [value for _, value in changes]
        params.append(dimension_id)
        sql = f"UPDATE prompt_dimensions SET {assignments} WHERE id = %s"
        return self._db.execute(sql, params)

    def delete_dimension(self, dimension_id: int) -> int:
        """Delete a dimension row.

        Per the original note, its versions are removed by a database cascade
        (the schema is not visible here; confirm). The vector table/index and
        the RAGFlow dataset are not touched by this method.

        Returns:
            The result of ``_db.execute`` (presumably the affected row count).
        """
        sql = "DELETE FROM prompt_dimensions WHERE id = %s"
        return self._db.execute(sql, [dimension_id])

    # ==================== Version management ====================

    def add_version(
        self,
        dimension_id: int,
        content: str,
        remark: Optional[str] = None,
        created_by: Optional[str] = None,
        set_active: bool = True,
    ) -> Dict[str, Any]:
        """Append a new prompt version to a dimension.

        The version number is the current maximum for the dimension plus one
        (starting at 1). When ``set_active`` is true, every existing version of
        the dimension is deactivated before the new one is inserted as active.

        NOTE(review): reading the maximum and inserting are separate statements;
        concurrent callers could pick the same number unless the schema or a
        transaction prevents it. Confirm.

        Args:
            dimension_id: Owning dimension id.
            content: Prompt text, stored as given.
            remark: Optional note for this version.
            created_by: Optional author.
            set_active: Whether the new version becomes the active one.

        Returns:
            The newly stored version row.
        """
        max_version_sql = """
            SELECT COALESCE(MAX(version_number), 0) as max_version
            FROM prompt_versions WHERE dimension_id = %s
        """
        result = self._db.fetch_one(max_version_sql, [dimension_id])
        new_version = result['max_version'] + 1

        # Clear the current active flag first when the new version takes over.
        if set_active:
            self._db.execute(
                "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
                [dimension_id],
            )

        insert_sql = """
            INSERT INTO prompt_versions (dimension_id, version_number, content, is_active, remark, created_by)
            VALUES (%s, %s, %s, %s, %s, %s)
        """
        self._db.execute(insert_sql, [
            dimension_id, new_version, content,
            1 if set_active else 0, remark, created_by,
        ])

        return self.get_version(dimension_id, new_version)

    def get_version(self, dimension_id: int, version_number: int) -> Optional[Dict[str, Any]]:
        """Return one specific version row of a dimension."""
        sql = """
            SELECT * FROM prompt_versions
            WHERE dimension_id = %s AND version_number = %s
        """
        return self._db.fetch_one(sql, [dimension_id, version_number])

    def get_versions_by_dimension(self, dimension_id: int) -> List[Dict[str, Any]]:
        """Return all version rows of a dimension, highest version number first."""
        sql = """
            SELECT * FROM prompt_versions
            WHERE dimension_id = %s
            ORDER BY version_number DESC
        """
        return self._db.fetch_all(sql, [dimension_id])

    def get_active_version(self, dimension_id: int) -> Optional[Dict[str, Any]]:
        """Return the version row flagged active for a dimension."""
        sql = """
            SELECT * FROM prompt_versions
            WHERE dimension_id = %s AND is_active = 1
        """
        return self._db.fetch_one(sql, [dimension_id])

    def get_active_prompt(self, dimension_name: str) -> Optional[str]:
        """Return the active prompt text of the dimension with the given name.

        Described in the original as the most commonly used lookup, used by
        workflows to fetch their prompt.

        Args:
            dimension_name: Dimension name.

        Returns:
            The prompt text, or ``None`` when no active version is found.
        """
        sql = """
            SELECT pv.content
            FROM prompt_versions pv
            JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
            WHERE pd.name = %s AND pv.is_active = 1
        """
        result = self._db.fetch_one(sql, [dimension_name])
        return result['content'] if result else None

    def get_active_dimension_by_id(self, dimension_id: int) -> Optional[Dict[str, Any]]:
        """Return the active prompt text and dataset id of a dimension.

        Args:
            dimension_id: Dimension id.

        Returns:
            A row with ``content`` and ``dataset_id``, or ``None`` when no
            active version is found.
        """
        sql = """
            SELECT pv.content, pd.dataset_id
            FROM prompt_versions pv
            JOIN prompt_dimensions pd ON pv.dimension_id = pd.id
            WHERE pv.dimension_id = %s AND pv.is_active = 1
        """
        # Normalise any falsy "no row" value from the client to None.
        return self._db.fetch_one(sql, [dimension_id]) or None

    def get_active_prompt_by_id(self, dimension_id: int) -> Optional[str]:
        """Return the active prompt text of a dimension looked up by id.

        Args:
            dimension_id: Dimension id.

        Returns:
            The prompt text, or ``None`` when no active version is found.
        """
        sql = """
            SELECT content FROM prompt_versions
            WHERE dimension_id = %s AND is_active = 1
        """
        result = self._db.fetch_one(sql, [dimension_id])
        return result['content'] if result else None

    def set_active_version(self, dimension_id: int, version_number: int) -> int:
        """Make one existing version the active version of its dimension.

        The target is looked up first. If it does not exist, nothing is
        changed, so a wrong ``version_number`` cannot leave the dimension
        without any active version.

        Args:
            dimension_id: Dimension id.
            version_number: Version to activate.

        Returns:
            ``0`` when the version does not exist; otherwise the result of the
            activating ``_db.execute`` call (presumably the affected row count).
        """
        if not self.get_version(dimension_id, version_number):
            return 0

        # Deactivate every version of the dimension, then flag the target.
        self._db.execute(
            "UPDATE prompt_versions SET is_active = 0 WHERE dimension_id = %s",
            [dimension_id],
        )
        sql = """
            UPDATE prompt_versions SET is_active = 1
            WHERE dimension_id = %s AND version_number = %s
        """
        return self._db.execute(sql, [dimension_id, version_number])

    def delete_version(self, dimension_id: int, version_number: int) -> int:
        """Delete one version row of a dimension.

        Returns:
            The result of ``_db.execute`` (presumably the affected row count).
        """
        sql = """
            DELETE FROM prompt_versions
            WHERE dimension_id = %s AND version_number = %s
        """
        return self._db.execute(sql, [dimension_id, version_number])
# Module-wide PromptService instance, built on first request.
_prompt_service: Optional[PromptService] = None


def get_prompt_service() -> PromptService:
    """Return the shared PromptService, constructing it on the first call."""
    global _prompt_service
    if _prompt_service is not None:
        return _prompt_service
    _prompt_service = PromptService()
    return _prompt_service
|