from typing import Optional

from ragflow_sdk import DataSet

from conf.settings import ragflow_settings, model_settings
from .conf.rag_parser_config import RagParserDefaults
from .base_util import RAGFlowBaseUtil


class DataSetUtil(RAGFlowBaseUtil):
    """Convenience wrapper around the dataset calls of ``self.ragflow_client``.

    ``self.ragflow_client`` is not assigned in this class; it is presumably
    provided by ``RAGFlowBaseUtil.__init__`` -- confirm against the base class.
    """

    def __init__(self):
        super().__init__()

    def create_dataset(self,
                       chunk_method: str = RagParserDefaults.DATASET_CHUNK_METHOD,
                       dataset_name: Optional[str] = None,
                       dataset_desc: str = "",
                       embedding_model: str = f"{model_settings.embedding_model_name}@SILICONFLOW",
                       parser_config: dict = RagParserDefaults.DATASET_CONFIG_DICT) -> DataSet:
        """Create a dataset through the RAGFlow client and return it.

        Unlike the other methods of this class, exceptions raised by the
        client are NOT caught here; they propagate to the caller.

        Args:
            chunk_method: Forwarded as ``chunk_method``.
            dataset_name: Forwarded as ``name``. The default ``None`` is passed
                through unchanged -- NOTE(review): the server presumably
                rejects a missing name, so callers should always supply one.
            dataset_desc: Forwarded as ``description``.
            embedding_model: Forwarded as ``embedding_model``. The default is
                computed once, when this module is imported.
            parser_config: Plain dict wrapped into ``DataSet.ParserConfig``
                before being sent. The default is a shared object owned by
                ``RagParserDefaults``; it is only read here, never modified.

        Returns:
            Whatever ``ragflow_client.create_dataset`` returns (a ``DataSet``
            according to the SDK documentation).
        """
        # Wrap the raw parser settings in the SDK's config object.
        parser_obj = DataSet.ParserConfig(self.ragflow_client, parser_config)

        # Create the dataset; the permission always comes from the defaults.
        dataset = self.ragflow_client.create_dataset(
            name=dataset_name,
            description=dataset_desc,
            embedding_model=embedding_model,
            permission=RagParserDefaults.DATASET_PERMISSION,
            chunk_method=chunk_method,
            parser_config=parser_obj,
        )
        return dataset

    def delete_dataset(self, ids: list[str] | None = None) -> None:
        """Delete the datasets with the given ids.

        When ``ids`` is ``None`` the client is not called at all. Any
        exception raised by the client is printed and swallowed.

        Args:
            ids: Dataset ids forwarded to ``ragflow_client.delete_datasets``.
        """
        try:
            # Guard kept on purpose: ``None`` must never reach the client.
            if ids is not None:
                self.ragflow_client.delete_datasets(ids)
        except Exception as e:
            print(f"删除数据集失败: {e}")

    def get_dataset(self, name: Optional[str] = None) -> DataSet | None:
        """Fetch a single dataset by name.

        Args:
            name: Forwarded as ``name`` to ``ragflow_client.get_dataset``.

        Returns:
            The client's result, or ``None`` if the client raised (the error
            is printed, not re-raised).
        """
        try:
            dataset = self.ragflow_client.get_dataset(name=name)
        except Exception as e:
            print(f"获取数据集失败: {e}")
            return None
        return dataset

    def list_datasets(self,
                      page: int = 1,
                      page_size: int = 30,
                      orderby: str = "create_time",
                      desc: bool = True,
                      id: str | None = None,
                      name: str | None = None) -> list[DataSet] | None:
        """List datasets, forwarding every argument unchanged to the client.

        Args:
            page: Forwarded as ``page``.
            page_size: Forwarded as ``page_size``.
            orderby: Forwarded as ``orderby``.
            desc: Forwarded as ``desc``.
            id: Forwarded as ``id``. The name shadows the builtin but is part
                of the public keyword interface, so it is kept.
            name: Forwarded as ``name``.

        Returns:
            The client's result, or ``None`` (not an empty list) if the client
            raised; the error is printed, not re-raised.
        """
        try:
            dataset_list = self.ragflow_client.list_datasets(
                page=page,
                page_size=page_size,
                orderby=orderby,
                desc=desc,
                id=id,
                name=name,
            )
        except Exception as e:
            print(f"获取数据集列表失败: {e}")
            return None
        return dataset_list