| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- """
- 节点抽象基类和基础状态定义
- 提供所有节点组件的抽象基类和通用状态类。
- """
- from abc import ABC, abstractmethod
- from typing import Dict, Any, Optional, List
- from pydantic import BaseModel, Field, ConfigDict
- class BaseState(BaseModel):
- """
- 基础状态类
-
- 所有工作流状态类都应继承此类。
- 提供通用的状态字段和配置。
-
- Attributes:
- is_complete: 是否处理完成
- error_message: 错误信息
- """
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
- is_complete: bool = Field(default=False, description="是否处理完成")
- error_message: Optional[str] = Field(default=None, description="错误信息")
- class BaseNode(ABC):
- """
- 节点抽象基类
-
- 所有节点组件都应继承此类并实现抽象方法。
- 支持作为callable使用,可直接传递给LangGraph。
-
- Example:
- >>> class MyNode(BaseNode):
- ... @property
- ... def name(self) -> str:
- ... return "my_node"
- ...
- ... def execute(self, state: BaseState) -> Dict[str, Any]:
- ... return {"processed": True}
- >>>
- >>> node = MyNode()
- >>> graph.add_node(node.name, node)
- """
-
- @property
- @abstractmethod
- def name(self) -> str:
- """
- 节点名称
-
- 用于在工作流中标识节点,必须唯一。
-
- Returns:
- 节点名称字符串
- """
- pass
-
- @property
- def description(self) -> str:
- """
- 节点描述
-
- 可选,用于文档和调试。
-
- Returns:
- 节点描述字符串
- """
- return self.__class__.__doc__ or ""
-
- @abstractmethod
- def execute(self, state: BaseState) -> Dict[str, Any]:
- """
- 执行节点逻辑
-
- Args:
- state: 当前工作流状态
-
- Returns:
- 状态更新字典,仅包含需要更新的字段
- """
- pass
-
- def __call__(self, state) -> Dict[str, Any]:
- """
- 使节点可作为callable使用
-
- Args:
- state: 当前工作流状态
-
- Returns:
- 状态更新字典
- """
- return self.execute(state)
-
- def __repr__(self) -> str:
- return f"<{self.__class__.__name__}(name='{self.name}')>"
- class ConditionalNode(BaseNode):
- """
- 条件节点抽象基类
-
- 用于需要进行条件判断并返回路由结果的节点。
- """
-
- @abstractmethod
- def check_condition(self, state: BaseState) -> str:
- """
- 检查条件并返回路由结果
-
- Args:
- state: 当前工作流状态
-
- Returns:
- 路由目标名称字符串
- """
- pass
-
- def execute(self, state: BaseState) -> Dict[str, Any]:
- """条件节点不执行逻辑,仅用于路由判断"""
- return {}
|