Python接口说明

注意:AgentSDK可通过Python接口进行应用开发,从代码调用角度上来说所有Python侧接口都可以被调用。本章节仅列出业务提供的对外接口,其余未进行说明的接口用户请勿直接调用。

AgentSDK 是一个 Agent 训推调框架,支持对接任意 Agent 引擎、训练引擎、推理引擎。本文档介绍框架对外暴露的核心接口。


一、核心基类

核心基类是用户必须继承并实现的抽象类,用于自定义 Agent、环境、工具等核心组件。

1.1 BaseAgent - Agent 抽象基类

功能描述

Agent 抽象基类,负责与模型交互、维护对话状态、解析模型响应、记录轨迹。用户需要继承此类实现自定义 Agent。

类定义

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

class BaseAgent(ABC):
    @property
    def chat_completions(self) -> list[dict[str, str]]: ...

    @property
    def trajectory(self) -> "Trajectory": ...

    @abstractmethod
    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs): ...

    @abstractmethod
    def update_from_model(self, response: str, **kwargs) -> "Action": ...

    @abstractmethod
    def reset(self): ...

    def get_current_state(self) -> "Step | None": ...

抽象方法说明

方法名 说明
update_from_env 从环境接收观测、奖励、终止信号,更新 Agent 内部状态
update_from_model 从模型接收响应,解析并返回动作
reset 重置 Agent 状态,开始新的轨迹

文件位置: aura/runner/agent_engine_wrapper/base/agent/base_agent.py


1.2 BaseEnv - 环境抽象基类

功能描述

环境抽象基类,负责工具执行、奖励计算、状态管理。用户需要继承此类实现自定义环境。

类定义

from abc import ABC, abstractmethod
from typing import Any, tuple

class BaseEnv(ABC):
    @abstractmethod
    def reset(self) -> tuple[dict, dict]: ...

    @abstractmethod
    def step(self, action: Any) -> tuple[Any, float, bool, dict]: ...

    def close(self): ...

    @staticmethod
    @abstractmethod
    def from_dict(info: dict) -> "BaseEnv": ...

    @staticmethod
    def is_multithread_safe() -> bool: ...

抽象方法说明

方法名 说明
reset 重置环境,返回初始观测和附加信息
step 执行动作,返回 (观测, 奖励, 是否终止, 附加信息)
from_dict 从配置字典创建环境实例

文件位置: aura/runner/agent_engine_wrapper/base/environment/base_env.py


1.3 BaseEngineWrapper - 引擎包装器基类

功能描述

引擎包装器抽象基类,提供统一的 Agent 引擎适配接口。用户可继承此类对接不同的 Agent 引擎。

类定义

from abc import ABC, abstractmethod
from typing import List
from pydantic import BaseModel

class AgentTask(BaseModel):
    task_id: str
    sample_id: int
    iteration: int
    agent_name: str
    problem: str
    ground_truth: str = ""
    prompt_id: int = 0
    content: str = ""
    extra_args: dict = None

class BaseEngineWrapper(ABC):
    @abstractmethod
    async def generate_trajectory(self, task: AgentTask, stream_queue = None, *args, **kwargs) -> "Trajectory": ...

参数说明

参数名 类型 说明
agent_name str Agent 场景名称
tokenizer object 文本分词器对象
sampling_params dict 模型推理时的采样参数
max_prompt_length int 输入提示的最大长度,默认 128K
max_response_length int 输出响应的最大长度,默认 8K
n_parallel_agents int 并行执行的 Agent 数量,默认 8
max_steps int Agent 执行的最大步骤数,默认 5

文件位置: aura/runner/agent_engine_wrapper/base_engine_wrapper.py


二、注册表接口

注册表接口用于注册自定义的训练引擎、推理引擎、数据管理器等组件。

2.1 TrainRegistry - 训练引擎注册表

功能描述

训练引擎注册表,用于注册自定义的训练方法和 Rollout 方法。

类定义

class TrainRegistry:
    def register(self, train_engine: str, cluster_mode: str,
                 rollout_method: Callable | None, train_method: Callable) -> None: ...

    def get_method(self, train_engine: str, cluster_mode: str) -> tuple | None: ...

# 全局实例
registry = TrainRegistry()

已注册引擎

train_engine cluster_mode 说明
verl hybrid verl 共卡模式
verl one_step_off verl 全异步模式

文件位置: aura/trainer/train_register.py


2.2 InferBackendRegistry - 推理引擎注册表

功能描述

推理引擎注册表,用于注册自定义的推理服务后端。

类定义

class InferBackendRegistry:
    def register(self, name: str, cls: type) -> None: ...

    def get_class(self, name: str) -> type | None: ...

# 全局实例
registry = InferBackendRegistry()

已注册后端

名称 说明
vllm vLLM 推理服务
vllm_pd vLLM PD 分离推理服务

文件位置: aura/runner/infer_adapter/infer_registry.py


2.3 DataManagerRegistry - 数据管理器注册表

功能描述

数据管理器注册表,用于注册自定义的数据管理器。

类定义

class DataManagerRegistry:
    def register(self, train_backend: str, service_mode: str, cls: type) -> None: ...

    def get_class(self, train_backend: str, service_mode: str) -> type | None: ...

# 全局实例
registry = DataManagerRegistry()

已注册管理器

train_backend service_mode 数据管理器类 说明
verl train VerlDataManager verl 训练数据管理器
verl infer InferDataManager 统一推理数据管理器

文件位置: aura/data_manager/data_registry.py


2.4 AGENTS_MAPPING - Agent 配置映射

功能描述

Agent 配置映射,存储已注册的 Agent 配置信息。

数据结构

AGENTS_MAPPING = [
    {
        "name": "my_agent",
        ...
    }
]

def get_agent_by_name(name: str) -> Optional[dict]:
    for agent_config in AGENTS_MAPPING:
        if name == agent_config.get("name", ""):
            return agent_config

    return None

配置项说明

字段 类型 说明
name str Agent 名称,配置文件中通过 agent_name 引用
env_class class 环境类,必须继承自 BaseEnv
env_args dict 环境初始化参数,传递给 env_class 构造函数
agent_class class Agent 类,必须继承自 BaseAgent
agent_args dict Agent 初始化参数,传递给 agent_class 构造函数
compute_trajectory_reward_fn callable 轨迹奖励计算函数,用于计算最终奖励

使用方式

在配置文件中通过 agent_name 引用已注册的 Agent:

agent_instances:
  - name: MY-AGENT
    executor_kwargs:
      agent_engine: rllm
      agent_engine_kwargs:
        agent_name: my_agent    # 引用注册的 Agent 名称

文件位置: agents/agents_mapping.py


三、数据类

数据类定义了 Agent 运行过程中的核心数据结构。

3.1 Step - 单步数据

功能描述

记录 Agent 运行的单步信息,包含对话上下文、动作、观测、奖励等。

类定义

@dataclass
class Step:
    chat_completions: list[dict[str, str]] = field(default_factory=list)
    thought: str = ""
    action: Any = None
    observation: Any = None
    model_response: str = ""
    info: dict = field(default_factory=dict)
    reward: float = 0.0
    done: bool = False
    mc_return: float = 0.0
    step_id: int = 0

参数说明

参数名 类型 说明
chat_completions list[dict[str, str]] 推理所有的完整对话上下文(含历史轮次),用于构造模型输入
thought str 模型回复中 <think> 标签内的内容,表示模型在本步骤的内部推理
action Any 模型回复中 <tool call> 标签内的内容,表示模型决定执行的动作(如工具调用)
observation Any 本步骤接收到的外部观测:第 0 轮为用户原始提问,后续轮次为上一轮动作的执行结果(如工具返回)
model_response str 大模型生成的完整回复内容(即 'role': 'assistant'content
info dict 附加信息字典,默认为空,可用于记录工具 ID、耗时等元数据
reward float 本步骤获得的即时奖励,默认为 0.0,反映当前动作的质量
done bool 是否在本步骤终止轨迹,默认为 False,标识任务是否完成
mc_return float 从本步骤开始的 Monte Carlo 回报,默认为 0.0,用于策略梯度训练
step_id int 步骤编号

3.2 Trajectory - 轨迹数据

功能描述

记录 Agent 运行的完整轨迹信息,包含所有步骤和整体奖励。

类定义

@dataclass
class Trajectory:
    task: Any = None
    steps: list[Step] = field(default_factory=list)
    reward: float = 0.0
    toolcall_reward: float = 0.0
    res_reward: float = 0.0
    prompt_id: int = 0
    data_id: str = None
    training_id: str = None
    epoch_id: int = 0
    iteration_id: int = 0
    sample_id: int = 0
    trajectory_id: int = 0
    application_id: str = ""
    termination_reason: str = "unknown"

参数说明

参数名 类型 说明
task Any 原始任务输入
steps list[Step] 所有步骤列表
reward float 轨迹总奖励
toolcall_reward float 工具调用奖励
res_reward float 最终结果奖励
termination_reason str 终止原因

3.3 Action - 动作数据

功能描述

记录 Agent 决定的动作信息。

类定义

@dataclass
class Action:
    action: Any = None

3.4 AgentTask - 任务数据

功能描述

定义 Agent 任务的数据结构。

类定义

from pydantic import BaseModel, Field
import uuid

class AgentTask(BaseModel):
    task_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    sample_id: int
    iteration: int
    agent_name: str
    problem: str
    ground_truth: str = ""
    prompt_id: int = 0
    content: str = ""
    extra_args: dict[str, Any] = None

参数说明

参数名 类型 说明
task_id str 任务唯一标识
sample_id int 样本编号
iteration int 迭代次数
agent_name str Agent 名称
problem str 问题描述
ground_truth str 正确答案
content str 额外内容
extra_args dict 额外参数