| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- from typing import Optional
- from ragflow_sdk import DataSet
- from conf.settings import ragflow_settings, model_settings
- from .conf.rag_parser_config import RagParserDefaults
- from .base_util import RAGFlowBaseUtil
class DataSetUtil(RAGFlowBaseUtil):
    """Thin helper around the dataset operations of the RAGFlow client.

    Every method delegates to ``self.ragflow_client``.
    NOTE(review): that attribute is presumably set up by
    ``RAGFlowBaseUtil.__init__`` -- confirm against the base class.
    """

    def __init__(self):
        super().__init__()

    def create_dataset(
        self,
        chunk_method: str = RagParserDefaults.DATASET_CHUNK_METHOD,
        dataset_name: Optional[str] = None,
        dataset_desc: str = "",
        embedding_model: str = f"{model_settings.embedding_model_name}@SILICONFLOW",
        parser_config: dict = RagParserDefaults.DATASET_CONFIG_DICT,
    ):
        """Create a dataset through the RAGFlow client.

        Args:
            chunk_method: Chunking method forwarded to the client.
            dataset_name: Name of the new dataset; forwarded as-is, including
                ``None``. NOTE(review): the client presumably requires a real
                name -- confirm and consider validating here.
            dataset_desc: Free-text description of the dataset.
            embedding_model: Embedding model identifier. The default is built
                once, at definition time, from ``model_settings``.
            parser_config: Mapping of parser settings; it is wrapped in a
                ``DataSet.ParserConfig`` before being sent. The default is the
                shared ``RagParserDefaults.DATASET_CONFIG_DICT`` object, so
                callers should not mutate it.

        Returns:
            Whatever ``ragflow_client.create_dataset`` returns.

        Raises:
            Any exception raised by the client propagates unchanged.
        """
        # Wrap the raw settings in the SDK's parser-config object.
        wrapped_config = DataSet.ParserConfig(self.ragflow_client, parser_config)
        return self.ragflow_client.create_dataset(
            name=dataset_name,
            description=dataset_desc,
            embedding_model=embedding_model,
            permission=RagParserDefaults.DATASET_PERMISSION,
            chunk_method=chunk_method,
            parser_config=wrapped_config,
        )

    def delete_dataset(self, ids: list[str] | None = None):
        """Delete the datasets with the given ids (best effort).

        Nothing is sent to the client when ``ids`` is ``None``.
        NOTE(review): this guard presumably exists so that a missing id list
        is never interpreted by the client as "delete everything" -- confirm
        against the SDK.

        Failures are reported on stdout and otherwise swallowed; the method
        always returns ``None``.
        """
        if ids is None:
            return
        try:
            self.ragflow_client.delete_datasets(ids)
        except Exception as e:
            print(f"删除数据集失败: {e}")

    def get_dataset(self, name: Optional[str] = None):
        """Fetch a single dataset by name.

        Returns:
            The value returned by ``ragflow_client.get_dataset``, or ``None``
            when the client raises (the error is printed to stdout).
        """
        try:
            return self.ragflow_client.get_dataset(name=name)
        except Exception as e:
            print(f"获取数据集失败: {e}")
            return None

    def list_datasets(
        self,
        page: int = 1,
        page_size: int = 30,
        orderby: str = "create_time",
        desc: bool = True,
        id: str | None = None,
        name: str | None = None,
    ) -> list[DataSet] | None:
        """List datasets, optionally filtered by id or name.

        All arguments are forwarded unchanged to
        ``ragflow_client.list_datasets``.

        Returns:
            The client's result, or ``None`` when the client raises (the
            error is printed to stdout). Callers must handle the ``None``
            case before iterating.
        """
        try:
            return self.ragflow_client.list_datasets(
                page=page,
                page_size=page_size,
                orderby=orderby,
                desc=desc,
                id=id,
                name=name,
            )
        except Exception as e:
            print(f"获取数据集列表失败: {e}")
            return None
|