base.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """
  2. 节点抽象基类和基础状态定义
  3. 提供所有节点组件的抽象基类和通用状态类。
  4. """
  5. from abc import ABC, abstractmethod
  6. from typing import Dict, Any, Optional, List
  7. from pydantic import BaseModel, Field, ConfigDict
  8. class BaseState(BaseModel):
  9. """
  10. 基础状态类
  11. 所有工作流状态类都应继承此类。
  12. 提供通用的状态字段和配置。
  13. Attributes:
  14. is_complete: 是否处理完成
  15. error_message: 错误信息
  16. """
  17. model_config = ConfigDict(arbitrary_types_allowed=True)
  18. is_complete: bool = Field(default=False, description="是否处理完成")
  19. error_message: Optional[str] = Field(default=None, description="错误信息")
  20. class BaseNode(ABC):
  21. """
  22. 节点抽象基类
  23. 所有节点组件都应继承此类并实现抽象方法。
  24. 支持作为callable使用,可直接传递给LangGraph。
  25. Example:
  26. >>> class MyNode(BaseNode):
  27. ... @property
  28. ... def name(self) -> str:
  29. ... return "my_node"
  30. ...
  31. ... def execute(self, state: BaseState) -> Dict[str, Any]:
  32. ... return {"processed": True}
  33. >>>
  34. >>> node = MyNode()
  35. >>> graph.add_node(node.name, node)
  36. """
  37. @property
  38. @abstractmethod
  39. def name(self) -> str:
  40. """
  41. 节点名称
  42. 用于在工作流中标识节点,必须唯一。
  43. Returns:
  44. 节点名称字符串
  45. """
  46. pass
  47. @property
  48. def description(self) -> str:
  49. """
  50. 节点描述
  51. 可选,用于文档和调试。
  52. Returns:
  53. 节点描述字符串
  54. """
  55. return self.__class__.__doc__ or ""
  56. @abstractmethod
  57. def execute(self, state: BaseState) -> Dict[str, Any]:
  58. """
  59. 执行节点逻辑
  60. Args:
  61. state: 当前工作流状态
  62. Returns:
  63. 状态更新字典,仅包含需要更新的字段
  64. """
  65. pass
  66. def __call__(self, state) -> Dict[str, Any]:
  67. """
  68. 使节点可作为callable使用
  69. Args:
  70. state: 当前工作流状态
  71. Returns:
  72. 状态更新字典
  73. """
  74. return self.execute(state)
  75. def __repr__(self) -> str:
  76. return f"<{self.__class__.__name__}(name='{self.name}')>"
  77. class ConditionalNode(BaseNode):
  78. """
  79. 条件节点抽象基类
  80. 用于需要进行条件判断并返回路由结果的节点。
  81. """
  82. @abstractmethod
  83. def check_condition(self, state: BaseState) -> str:
  84. """
  85. 检查条件并返回路由结果
  86. Args:
  87. state: 当前工作流状态
  88. Returns:
  89. 路由目标名称字符串
  90. """
  91. pass
  92. def execute(self, state: BaseState) -> Dict[str, Any]:
  93. """条件节点不执行逻辑,仅用于路由判断"""
  94. return {}