"""
Copyright 2026 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import queue
import re
from threading import Thread
from typing import List, Any, Dict

import torch
from are.simulation.agents.agent_builder import AgentBuilder
from are.simulation.agents.agent_config_builder import AgentConfigBuilder
from are.simulation.agents.are_simulation_agent import RunnableARESimulationAgent
from are.simulation.agents.are_simulation_agent_config import LLMEngineConfig, MainAgentConfig
from are.simulation.agents.default_agent.base_agent import DEFAULT_STEP_2_MESSAGE
from are.simulation.agents.default_agent.tools.action_executor import BaseActionExecutor, AgentAction
from are.simulation.agents.llm.llm_engine import LLMEngine
from are.simulation.agents.llm.llm_engine_builder import LLMEngineBuilder
from are.simulation.agents.llm.types import MessageRole
from are.simulation.benchmark.scenario_loader import load_scenario
from are.simulation.environment import EnvironmentConfig, Environment
from are.simulation.exceptions import InvalidActionAgentError
from are.simulation.notification_system import VerboseNotificationSystem
from are.simulation.scenarios.scenario_imported_from_json.utils import preprocess_scenario
from are.simulation.types import EnvironmentType, SimulatedGenerationTimeConfig
from are.simulation.validation import GraphPerEventJudgeConfig

from agentic_rl import BaseEngineWrapper
from agentic_rl.runner import Trajectory
# Import-time side effect: replaces the "llm_output" entry of the mapping imported from Meta-are with the bare
# "{content}" placeholder. The imported object itself is mutated, so the change is visible to every user of it.
# NOTE(review): presumably this makes Meta-are replay the raw LLM output without its default wrapper text — confirm.
DEFAULT_STEP_2_MESSAGE["llm_output"] = "{content}"
# Passed to preprocess_scenario() as max_scenario_duration (presumably seconds — TODO confirm the unit).
MAX_SCENARIO_DURATION = 1800
def extract_action(self, llm_output: str, split_token: str) -> AgentAction:
    """
    Split an LLM response into its rationale and its action.

    The reasoning (think) section of a response may itself contain split_token, which breaks parsing that
    splits on every occurrence. Here the response is cut at the LAST occurrence of split_token only: the text
    before it becomes the rationale and the text after it becomes the action.

    Args:
        llm_output (str): LLM response.
        split_token (str): token that introduces the action.

    Returns:
        AgentAction: built from (rationale, action).

    Raises:
        InvalidActionAgentError: if split_token is absent from llm_output, or the split fails for any other
            reason. The underlying error is logged through self.logger first.
    """
    try:
        pieces = llm_output.rsplit(split_token, 1)
        piece_count = len(pieces)
        if piece_count < 2:
            # No split_token at all: rsplit returned the whole response as a single piece.
            raise IndexError(f"Expected 2 parts after splitting by last '{split_token}', got {piece_count}")
        rationale = pieces[-2]
        action = pieces[-1]
    except Exception as e:
        self.logger.error(e, exc_info=True)
        raise InvalidActionAgentError(
            f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. "
            f"Be sure to include an action, prefaced with '{split_token}'!\n Exception: {e}"
        )
    return AgentAction(rationale, action)
# Monkey-patch applied at import time: replaces BaseActionExecutor.extract_action with the module-level
# extract_action defined in this file (which splits on the last occurrence of the split token).
BaseActionExecutor.extract_action = extract_action
def transform_role(value):
    """
    Convert a Meta-are MessageRole into the plain role string used in chat messages.

    Values that are not MessageRole instances are returned unchanged. TOOL_RESPONSE is mapped to "user".

    Raises:
        ValueError: if value is a MessageRole other than USER, SYSTEM, ASSISTANT or TOOL_RESPONSE.
    """
    if not isinstance(value, MessageRole):
        return value
    if value == MessageRole.USER:
        return "user"
    if value == MessageRole.SYSTEM:
        return "system"
    if value == MessageRole.ASSISTANT:
        return "assistant"
    if value == MessageRole.TOOL_RESPONSE:
        return "user"
    raise ValueError(f"Unable to transform invalid role: {value} ")
def get_score_from_are_tool_response(content):
    """
    Score a tool response generated by Meta-are.

    Returns -1.0 when content starts with "[OUTPUT OF STEP <digits>] ERROR:", otherwise 1.0.
    """
    step_failed = re.match(r"^\[OUTPUT OF STEP \d+] ERROR:", content) is not None
    return -1.0 if step_failed else 1.0
def _transform_messages(messages: list[dict[str, Any]]):
    """
    Normalise Meta-are messages into plain chat messages.

    Every value of every message is passed through transform_role, then only the "role" and "content" keys
    are kept. A new list of new dicts is returned; the input is not modified.
    """
    kept_keys = ("role", "content")
    normalised = []
    for message in messages:
        # Convert first, filter second, so that an invalid role under any key still raises.
        converted = {key: transform_role(value) for key, value in message.items()}
        normalised.append({key: converted[key] for key in converted if key in kept_keys})
    return normalised
class _AgentSDKLLMEngine(LLMEngine):
    """
    Meta-are LLMEngine implemented on top of the inference interface provided by AgentSDK.

    After every successful completion the full message history (prompt messages plus the new assistant reply)
    is appended to trajectory_store.
    """

    def __init__(self,
                 completion=None,
                 tokenizer=None,
                 trajectory_store: list = None,
                 sampling_params: Dict = None,
                 max_model_len: int = None):
        """
        Initialization for _AgentSDKLLMEngine.

        Args:
            completion: The inference interface provided by AgentSDK; called with a request dict and awaited.
            tokenizer: Tokenizer used to apply the chat template and to count prompt tokens.
            trajectory_store: List that receives one message history per successful completion.
            sampling_params: Sampling params sent with every request. The dict is only read, never modified.
            max_model_len: Token budget for prompt plus response.
        """
        super().__init__(model_name="")
        self.completion = completion
        self.tokenizer = tokenizer
        self.trajectory_store = trajectory_store
        self.sampling_params = sampling_params
        self.max_model_len = max_model_len

    def chat_completion(self,
                        messages: list[dict[str, Any]],
                        stop_sequences=[],
                        **kwargs) -> tuple[str, dict | None]:
        """
        Run one inference call through AgentSDK, with a prompt length check.

        Args:
            messages: History chat completions.
            stop_sequences: Stop words. Accepted but not used by this implementation.
            **kwargs: Extended arguments. Accepted but not used by this implementation.

        Returns:
            tuple[str, dict | None]: the generated text and None.

        Raises:
            ValueError: if the prompt already uses max_model_len tokens or more, or if the response
                contains no choices.
        """
        messages = _transform_messages(messages)
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        tokens = len(self.tokenizer.encode(prompt))
        if tokens >= self.max_model_len:
            raise ValueError(f"Prompt token is {tokens}, exceed max_model_len: {self.max_model_len}")

        # Build a per-call request instead of writing max_tokens into self.sampling_params. That dict is
        # passed in by reference and may be shared with other engines and threads; mutating it would let a
        # concurrent call send a max_tokens value computed for a different prompt.
        request = {
            "prompt": prompt,
            **(self.sampling_params or {}),
            "max_tokens": self.max_model_len - tokens,
        }
        response = asyncio.run(self.completion(request))

        choices = response.get("choices", [])
        if not choices:
            raise ValueError("No choices found in response")
        text = choices[0].get("text", "")

        # messages is the fresh list returned by _transform_messages, so appending does not touch the
        # caller's list.
        messages.append({"role": "assistant", "content": text})
        if self.trajectory_store is not None:
            self.trajectory_store.append(messages)
        return text, None
class _JudgeLLMEngine(LLMEngine):
    """
    Meta-are LLMEngine for judge calls, implemented on top of the inference interface provided by AgentSDK.

    Unlike the agent engine it performs no prompt length check and records nothing.
    """

    def __init__(self,
                 completion=None,
                 tokenizer=None,
                 sampling_params=None):
        """
        Initialization for _JudgeLLMEngine.

        Args:
            completion: The inference interface provided by AgentSDK.
            tokenizer: Tokenizer used to apply the chat template to messages from Meta-are.
            sampling_params: Sampling params sent with every request.
        """
        super().__init__(model_name="")
        self.completion, self.tokenizer, self.sampling_params = completion, tokenizer, sampling_params

    def chat_completion(self,
                        messages: list[dict[str, Any]],
                        stop_sequences=[],
                        **kwargs):
        """
        Run one judge inference call through AgentSDK.

        Args:
            messages: History chat completions.
            stop_sequences: Stop words. Accepted but not used by this implementation.
            **kwargs: Extended arguments. Accepted but not used by this implementation.

        Returns:
            The generated text and None, as a tuple.

        Raises:
            ValueError: if the response contains no choices.
        """
        chat = _transform_messages(messages)
        rendered_prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        reply = asyncio.run(self.completion({"prompt": rendered_prompt, **self.sampling_params}))
        candidates = reply.get("choices", [])
        if not candidates:
            raise ValueError("No choices found in response for Judge")
        return candidates[0].get("text", ""), None
class AreEngineWrapper(BaseEngineWrapper):
    """
    Implementation of AgentSDK's BaseEngineWrapper that generates trajectories by running Meta-are scenarios.
    """

    def __init__(self,
                 agent_name: str,
                 tokenizer: Any,
                 sampling_params: Dict[str, Any],
                 max_prompt_length: int,
                 max_response_length: int,
                 n_parallel_agents: int = 8,
                 max_steps: int = 10):
        """
        Initialization.

        Args:
            agent_name: Choose which agent to use. Useless in Meta-are.
            tokenizer: Huggingface style tokenizer.
            sampling_params: Inference sampling params.
            max_prompt_length: Max prompt token length set by AgentSDK.
            max_response_length: Max response token length set by AgentSDK.
            n_parallel_agents: How many agents will be running at same time.
            max_steps: How many rounds of chat will be performed during one trajectory.
        """
        super().__init__(agent_name, tokenizer, sampling_params,
                         max_prompt_length, max_response_length, n_parallel_agents, max_steps)
        # Total token budget of one conversation: prompt plus response.
        self.max_model_len = max_prompt_length + max_response_length

    def initialize(self):
        """
        Additional initialize. Useless in Meta-are.
        """

    def _parse(self, messages, add_generation_prompt=False) -> str:
        """
        Parse messages, apply chat template.

        Args:
            messages: Messages to be parsed.
            add_generation_prompt: If generation prompt should be added.

        Returns:
            str: Parsed prompt.
        """
        return self.tokenizer.apply_chat_template(messages, tokenize=False,
                                                  add_generation_prompt=add_generation_prompt)

    def tokenize_and_mask(self, messages):
        """
        Tokenize messages and generate mask.

        In token-mode training, tokens before the first assistant role are used as prompt_ids, and the
        remaining dialogue is used as response_ids. A response_mask (value = 1) is applied only to assistant
        tokens so that only model-generated responses contribute to the loss, excluding tool call outputs.

        Args:
            messages: message list.

        Return:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: prompt_ids, response_ids and response_masks,
            each a 1-D tensor of dtype torch.long.

        Raises:
            RuntimeError: if messages contains no assistant message.
        """
        first_assistant_idx = next(
            (i for i, msg in enumerate(messages) if msg["role"] == "assistant"),
            None,
        )
        if first_assistant_idx is None:
            raise RuntimeError("No assistant message found in completions")

        prompt_ids = []
        response_ids = []
        response_mask = []
        for i, msg in enumerate(messages):
            # Each message is templated and encoded on its own, then the ids are concatenated.
            # NOTE(review): if the tokenizer's chat template or encode() injects per-call tokens (e.g. BOS or
            # a default system prompt), the concatenation can differ from encoding the whole conversation at
            # once — confirm for the tokenizer in use.
            ids = self.tokenizer.encode(self._parse([msg], add_generation_prompt=False))
            if i < first_assistant_idx:
                prompt_ids.extend(ids)
                continue
            response_ids.extend(ids)
            mask_value = 1 if msg["role"] == "assistant" else 0
            response_mask.extend([mask_value] * len(ids))

        prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
        response_ids = torch.tensor(response_ids, dtype=torch.long)
        response_mask = torch.tensor(response_mask, dtype=torch.long)
        return prompt_ids, response_ids, response_mask

    def worker(self, task_queue, result, completion):
        """
        Worker thread body: take tasks from the queue and execute them until the queue is empty.

        The queue must be fully populated before the worker starts; the worker returns as soon as it finds
        the queue empty. A task whose execution raises is logged and skipped so that the worker keeps
        draining the queue; its slot in result is left untouched.

        Args:
            task_queue: Queue of (task_id, task) tuples.
            result: Result list to store task result, indexed by task_id.
            completion: AgentSDK provided completion interface.
        """
        log = logging.getLogger(__name__)
        while True:
            try:
                # Non-blocking: waiting on an already-drained queue would only delay the end of the batch.
                task_id, task = task_queue.get_nowait()
            except queue.Empty:
                return
            try:
                result[task_id] = self.run(task, task_id, completion)
            except Exception:
                # Letting the exception escape would end this thread and could leave queued tasks
                # unprocessed once every worker has hit a failing task.
                log.exception("Task %s failed; its result slot is left unset", task_id)

    def run(self, task, idx, completion):
        """
        Execute tasks, generate trajectory and calculate scores.

        The trajectory reward is res_reward (4.0 when validation succeeds, -2.0 otherwise) plus
        toolcall_reward (mean tool-response score over the counted assistant turns, minus 0.01 per counted
        turn; -1 when no turn is counted).

        Args:
            task: Task of gaia2, generated from gaia2 datasets, should have Meta-are scenario data in json
                str format.
            idx: Index of task in the train_iter.
            completion: AgentSDK provided completion interface.

        Returns:
            Trajectory: tokens, masks, reward and metrics of the selected message history.

        Raises:
            ValueError: if task has no "data" key.
            RuntimeError: if no recorded message history fits within max_model_len.
        """
        if "data" not in task:
            raise ValueError("Task from AgentSDK does not have data! please check config of data_path or "
                             "datasets_additional_keys.")
        trajectory_store = []
        result = self.generate_traj(scenario_data=task["data"],
                                    trajectory_store=trajectory_store,
                                    completion=completion,
                                    sampling_params=self.sampling_params,
                                    max_steps=self.max_steps)
        res_reward = 4.0 if result.success else -2.0

        # Keep the most recently recorded history that fits the token budget. Scanning from the end and
        # stopping at the first fit selects the same entry as checking every history, without tokenizing
        # all of them.
        valid_length_traj = None
        for chat_completions in reversed(trajectory_store):
            rendered = self.tokenizer.apply_chat_template(chat_completions, tokenize=False)
            if len(self.tokenizer.encode(rendered)) < self.max_model_len:
                valid_length_traj = chat_completions
                break
        if not valid_length_traj:
            raise RuntimeError("No trajectory within valid length")

        # Only assistant messages that are followed by another message are counted; the follower is scored
        # when it has the "user" role (tool responses are mapped to "user" by transform_role).
        score = 0
        assistant_cnt = 0
        for current, following in zip(valid_length_traj, valid_length_traj[1:]):
            if current["role"] != "assistant":
                continue
            assistant_cnt += 1
            if following["role"] == "user":
                score += get_score_from_are_tool_response(following["content"])
        if assistant_cnt > 0:
            toolcall_reward = score / assistant_cnt - 0.01 * assistant_cnt
        else:
            toolcall_reward = -1
        trajectory_reward = res_reward + toolcall_reward

        prompt_ids, response_ids, response_mask = self.tokenize_and_mask(valid_length_traj)
        return Trajectory(
            prompt_tokens=prompt_ids,
            response_tokens=response_ids,
            response_masks=response_mask,
            idx=idx,
            trajectory_reward=trajectory_reward,
            chat_completions=valid_length_traj,
            metrics={
                "steps": assistant_cnt + 1,
                "reward_time": None,
                "env_time": None,
                "llm_time": None,
                "total_time": result.duration,
                "res_reward": res_reward,
                "toolcall_reward": toolcall_reward,
            }
        )

    def generate_agent_trajectories_async(self, tasks: List[dict]) -> List[Trajectory]:
        """
        Generate several trajectories concurrently using multiple worker threads.

        Blocks until every worker has finished. Completion interfaces from self.completions are assigned to
        workers round-robin.

        Args:
            tasks: Meta-are scenario data.

        Returns:
            List with one entry per task, in task order. An entry is None when its task failed.

        Raises:
            RuntimeError: if there are tasks to run but self.completions is empty.
        """
        result = [None] * len(tasks)
        if not tasks:
            return result
        completions_size = len(self.completions)
        if completions_size == 0:
            raise RuntimeError("No completion interface available to generate trajectories")

        # Fill the queue completely before any worker starts: workers exit as soon as they see it empty.
        task_queue = queue.Queue()
        for idx, task in enumerate(tasks):
            task_queue.put((idx, task))

        # More workers than tasks would only create threads that exit immediately.
        num_worker = min(self.n_parallel_agents, len(tasks))
        runners = []
        for wid in range(num_worker):
            runner = Thread(target=self.worker,
                            args=(task_queue, result, self.completions[wid % completions_size]))
            runner.start()
            runners.append(runner)
        for runner in runners:
            runner.join()
        return result

    def generate_traj(self,
                      scenario_data: str,
                      trajectory_store: list = None,
                      completion=None,
                      sampling_params: dict = None,
                      max_steps: int = 5):
        """
        Generate one trajectory with Meta-are.

        Loads and preprocesses the scenario, starts an environment for it, runs the default agent with an
        AgentSDK-backed engine, validates the scenario and stops the environment. The environment is stopped
        even when building or running the agent, or validation, raises.

        Args:
            scenario_data: Gaia2 style scenario data, will be deserialized by Meta-are.
            trajectory_store: List to store trajectory step by step.
            completion: AgentSDK provided completion interface.
            sampling_params: Sampling params to generate trajectory.
            max_steps: How many rounds of chat will be performed in one trajectory.

        Return:
            Validation result: If this scenario executed successfully and execute time.
        """
        tokenizer = self.tokenizer
        max_model_len = self.max_model_len

        class _AgentSDKEngineBuilder(LLMEngineBuilder):
            """Engine builder that ignores its arguments and always returns the AgentSDK-backed engine."""

            def create_engine(self,
                              engine_config: LLMEngineConfig,
                              mock_responses: list[str] | None = None) -> LLMEngine:
                return _AgentSDKLLMEngine(completion=completion,
                                          tokenizer=tokenizer,
                                          trajectory_store=trajectory_store,
                                          sampling_params=sampling_params,
                                          max_model_len=max_model_len)

        scenario, _ = load_scenario(scenario_data, "scenario_id", False)
        judge_engine = _JudgeLLMEngine(completion=completion, tokenizer=tokenizer,
                                       sampling_params=sampling_params)
        judge_config = GraphPerEventJudgeConfig(engine=judge_engine)
        preprocess_scenario(scenario=scenario,
                            judge_config=judge_config,
                            offline_validation=False,
                            max_scenario_duration=MAX_SCENARIO_DURATION,
                            tool_augmentation_config=None,
                            env_events_config=None)

        env_config = EnvironmentConfig(oracle_mode=False,
                                       queue_based_loop=False,
                                       wait_for_user_input_timeout=None,
                                       dump_dir=None,
                                       time_increment_in_seconds=scenario.time_increment_in_seconds,
                                       exit_when_no_events=False)
        if scenario.start_time and scenario.start_time > 0:
            env_config.start_time = scenario.start_time
        env = Environment(environment_type=EnvironmentType.CLI,
                          config=env_config,
                          notification_system=VerboseNotificationSystem())
        env.run(scenario, wait_for_end=False)
        try:
            agent_config = AgentConfigBuilder().build("default")
            agent_config.get_base_agent_config().simulated_generation_time_config = \
                SimulatedGenerationTimeConfig(mode="measured")
            if isinstance(agent_config, MainAgentConfig) and scenario.nb_turns is not None:
                agent_config.max_turns = scenario.nb_turns

            agent_builder = AgentBuilder(_AgentSDKEngineBuilder())
            are_simulation_agent: RunnableARESimulationAgent = agent_builder.build(
                agent_config=agent_config, env=env
            )
            are_simulation_agent.react_agent.max_iterations = max_steps
            are_simulation_agent.run_scenario(scenario=scenario, notification_system=env.notification_system)
            return scenario.validate(env)
        finally:
            # The environment was started with wait_for_end=False; always stop it, including on failure.
            env.stop()