"""
Copyright 2026 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import queue
import re
from threading import Thread
from typing import List, Any, Dict

import torch
from are.simulation.agents.agent_builder import AgentBuilder
from are.simulation.agents.agent_config_builder import AgentConfigBuilder
from are.simulation.agents.are_simulation_agent import RunnableARESimulationAgent
from are.simulation.agents.are_simulation_agent_config import LLMEngineConfig, MainAgentConfig
from are.simulation.agents.default_agent.base_agent import DEFAULT_STEP_2_MESSAGE
from are.simulation.agents.default_agent.tools.action_executor import BaseActionExecutor, AgentAction
from are.simulation.agents.llm.llm_engine import LLMEngine
from are.simulation.agents.llm.llm_engine_builder import LLMEngineBuilder
from are.simulation.agents.llm.types import MessageRole
from are.simulation.benchmark.scenario_loader import load_scenario
from are.simulation.environment import EnvironmentConfig, Environment
from are.simulation.exceptions import InvalidActionAgentError
from are.simulation.notification_system import VerboseNotificationSystem
from are.simulation.scenarios.scenario_imported_from_json.utils import preprocess_scenario
from are.simulation.types import EnvironmentType, SimulatedGenerationTimeConfig
from are.simulation.validation import GraphPerEventJudgeConfig

from agentic_rl import BaseEngineWrapper
from agentic_rl.runner import Trajectory

# Override Meta-are's template entry for the "llm_output" step with the bare "{content}" placeholder.
# NOTE(review): presumably this makes the agent's LLM output appear in the history verbatim, without the
# library's default wrapper text -- confirm against are.simulation.agents.default_agent.base_agent.
# This mutates the mapping imported from the library module, so it affects every agent in this process.
DEFAULT_STEP_2_MESSAGE["llm_output"] = "{content}"
# Passed to preprocess_scenario() as max_scenario_duration; presumably seconds (30 minutes) -- TODO confirm.
MAX_SCENARIO_DURATION = 1800


def extract_action(self, llm_output: str, split_token: str) -> AgentAction:
    """
    Split an LLM response into rationale and action at the LAST occurrence of ``split_token``.

    This function is installed as ``BaseActionExecutor.extract_action``. The model's reasoning (think) section may
    itself contain ``split_token``, which made the original parser fail. Treating only the final occurrence as the
    action marker avoids that.

    Args:
        self: The executor this function is bound to; its ``logger`` records parsing failures.
        llm_output (str): LLM response.
        split_token (str): Token that introduces the action.

    Returns:
        AgentAction: Built from the text before and the text after the last ``split_token``.

    Raises:
        InvalidActionAgentError: If ``split_token`` is absent from ``llm_output`` or splitting fails.
    """
    try:
        head_and_tail = llm_output.rsplit(split_token, 1)
        if len(head_and_tail) < 2:
            raise IndexError(f"Expected 2 parts after splitting by last '{split_token}', got {len(head_and_tail)}")
        # rsplit with maxsplit=1 yields at most two parts, so this unpacking is exact.
        rationale, action = head_and_tail
    except Exception as err:
        self.logger.error(err, exc_info=True)
        raise InvalidActionAgentError(
            f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. "
            f"Be sure to include an action, prefaced with '{split_token}'!\n Exception: {err}"
        )
    else:
        return AgentAction(rationale, action)


# Monkey-patch Meta-are's action parser for the whole process: every BaseActionExecutor (and every subclass that
# does not override extract_action) now splits the LLM output on the LAST occurrence of the split token.
BaseActionExecutor.extract_action = extract_action


def transform_role(value):
    """
    Transform role from Meta-are's format to universal str format.
    """
    if isinstance(value, MessageRole):
        match value:
            case MessageRole.USER:
                return "user"
            case MessageRole.SYSTEM:
                return "system"
            case MessageRole.ASSISTANT:
                return "assistant"
            case MessageRole.TOOL_RESPONSE:
                return "user"
        raise ValueError(f"Unable to transform invalid role: {value} ")
    return value


# Prefix that this module treats as a failed tool call in a Meta-are tool response.
_ARE_TOOL_ERROR_PREFIX = re.compile(r"\[OUTPUT OF STEP \d+\] ERROR:")


def get_score_from_are_tool_response(content):
    """
    Score a tool response produced by Meta-are.

    Returns -1.0 when ``content`` starts with ``"[OUTPUT OF STEP <n>] ERROR:"`` and 1.0 otherwise. Only the start of
    the string is checked; an ``ERROR`` marker appearing later does not count.
    """
    is_error = _ARE_TOOL_ERROR_PREFIX.match(content) is not None
    return -1.0 if is_error else 1.0


def _transform_messages(messages: list[dict[str, Any]]):
    """
    Convert Meta-are chat messages into plain ``{"role": ..., "content": ...}`` dicts.

    Every value of every message is first passed through ``transform_role`` (turning ``MessageRole`` enums into
    strings). Afterwards only the ``"role"`` and ``"content"`` keys are kept, in their original order. New dicts are
    built; the input list and its dicts are left untouched.
    """
    kept_keys = ("role", "content")
    converted = []
    for message in messages:
        normalized = {key: transform_role(value) for key, value in message.items()}
        converted.append({key: normalized[key] for key in normalized if key in kept_keys})
    return converted


class _AgentSDKLLMEngine(LLMEngine):
    """
    Meta-are ``LLMEngine`` implementation that routes inference through an AgentSDK completion interface.

    After every successful call, the full conversation (prompt messages plus the generated assistant message) is
    appended to ``trajectory_store``, so the caller can build a training trajectory from it.
    """

    def __init__(self,
                 completion=None,
                 tokenizer=None,
                 trajectory_store: list | None = None,
                 sampling_params: dict | None = None,
                 max_model_len: int | None = None):
        """
        Initialization for _AgentSDKLLMEngine.

        Args:
            completion: AgentSDK inference interface. It is called with a request dict (``prompt`` plus sampling
                params) and the returned awaitable is run with ``asyncio.run``.
            tokenizer: Tokenizer used to apply the chat template and to count prompt tokens.
            trajectory_store: List that receives one conversation snapshot per successful call. A new list is
                created when ``None`` is given.
            sampling_params: Sampling params sent with every request. This engine never modifies the dict.
            max_model_len: Total context length. The generation budget is whatever the prompt leaves of it.
        """

        super().__init__(model_name="")

        self.completion = completion
        self.tokenizer = tokenizer

        # `is None` rather than `or []`: a caller-provided empty list must be kept so its owner sees the steps.
        self.trajectory_store = [] if trajectory_store is None else trajectory_store

        self.sampling_params = sampling_params
        self.max_model_len = max_model_len

    def chat_completion(self,
                        messages: list[dict[str, Any]],
                        stop_sequences=None,
                        **kwargs) -> tuple[str, dict | None]:
        """
        Run inference through AgentSDK and record the resulting conversation as a trajectory step.

        Args:
            messages: Chat history in Meta-are format.
            stop_sequences: Accepted for interface compatibility; not forwarded to the completion interface.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            tuple[str, dict | None]: The generated text and ``None`` (no extra metadata).

        Raises:
            ValueError: If the prompt already fills ``max_model_len``, or the response contains no choices.
        """

        messages = _transform_messages(messages)

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        tokens = len(self.tokenizer.encode(prompt))

        if tokens >= self.max_model_len:
            raise ValueError(f"Prompt token is {tokens}, exceed max_model_len: {self.max_model_len}")

        # Build a per-request dict instead of writing max_tokens into self.sampling_params. That dict may be shared
        # with other engines (e.g. the judge engine, or engines in other worker threads), and mutating it would leak
        # a budget computed for this prompt into unrelated requests.
        request = {"prompt": prompt, **(self.sampling_params or {}), "max_tokens": self.max_model_len - tokens}

        response = asyncio.run(self.completion(request))

        choices = response.get("choices", [])
        if not choices:
            raise ValueError("No choices found in response")

        text = choices[0].get("text", "")

        # `messages` is a fresh list built by _transform_messages, so appending does not touch the caller's history.
        messages.append({"role": "assistant", "content": text})
        self.trajectory_store.append(messages)

        return text, None


class _JudgeLLMEngine(LLMEngine):
    """
    Meta-are ``LLMEngine`` used by the scenario judge, routing inference through an AgentSDK completion interface.

    It does not record trajectories and does not check the prompt length before sending a request.
    """

    def __init__(self,
                 completion=None,
                 tokenizer=None,
                 sampling_params=None):
        """
        Initialization for _JudgeLLMEngine.

        Args:
            completion: AgentSDK inference interface. It is called with a request dict (``prompt`` plus sampling
                params) and the returned awaitable is run with ``asyncio.run``.
            tokenizer: Tokenizer used to apply the chat template to messages from Meta-are.
            sampling_params: Sampling params used for LLM inference. ``None`` means no extra params.
        """
        super().__init__(model_name="")
        self.completion = completion
        self.tokenizer = tokenizer
        self.sampling_params = sampling_params

    def chat_completion(self,
                        messages: list[dict[str, Any]],
                        stop_sequences=None,
                        **kwargs) -> tuple[str, dict | None]:
        """
        Run a judgement request through AgentSDK.

        Args:
            messages: Chat history in Meta-are format.
            stop_sequences: Accepted for interface compatibility; not forwarded to the completion interface.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            tuple[str, dict | None]: The generated text and ``None`` (no extra metadata).

        Raises:
            ValueError: If the response contains no choices.
        """
        messages = _transform_messages(messages)

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # NOTE(review): any "max_tokens" present in sampling_params is forwarded as-is and is not checked against
        # the judge prompt length.
        response = asyncio.run(self.completion({"prompt": prompt, **(self.sampling_params or {})}))

        choices = response.get("choices", [])
        if not choices:
            raise ValueError("No choices found in response for Judge")

        text = choices[0].get("text", "")

        return text, None


class AreEngineWrapper(BaseEngineWrapper):
    """
    Implementation of AgentSDK's BaseEngineWrapper that generates trajectories with Meta-are.

    Each task's ``"data"`` entry holds a serialized Meta-are scenario. The scenario is played in a Meta-are
    environment by Meta-are's default agent, whose LLM calls are served through AgentSDK completion interfaces. The
    recorded conversation is then scored and tokenized into a ``Trajectory``.
    """

    def __init__(self,
                 agent_name: str,
                 tokenizer: Any,
                 sampling_params: Dict[str, Any],
                 max_prompt_length: int,
                 max_response_length: int,
                 n_parallel_agents: int = 8,
                 max_steps: int = 10):
        """
        Initialization.

        Args:
            agent_name: Choose which agent to use. Unused by Meta-are.
            tokenizer: Huggingface style tokenizer.
            sampling_params: Inference sampling params.
            max_prompt_length: Max prompt token length set by AgentSDK.
            max_response_length: Max response token length set by AgentSDK.
            n_parallel_agents: Maximum number of scenarios run concurrently (one thread each).
            max_steps: Maximum number of agent iterations in one trajectory.
        """

        super().__init__(agent_name, tokenizer, sampling_params,
                         max_prompt_length, max_response_length, n_parallel_agents, max_steps)

        # Whole context budget: sizes per-call generation and filters recorded trajectories by length.
        self.max_model_len = max_prompt_length + max_response_length

    def initialize(self):
        """
        Additional initialization hook of BaseEngineWrapper. Nothing to do for Meta-are.
        """
        pass

    def _parse(self, messages, add_generation_prompt=False) -> str:
        """
        Apply the tokenizer's chat template to messages, without tokenizing.

        Args:
            messages: Messages to be rendered.
            add_generation_prompt: Whether the generation prompt should be appended.

        Returns:
            str: Rendered prompt text.
        """
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)

    def tokenize_and_mask(self, messages):
        """
        Tokenize messages and generate the response mask.

        In token-mode training, tokens before the first assistant message become prompt_ids and the rest of the
        dialogue becomes response_ids. response_mask is 1 only on assistant tokens, so only model-generated responses
        contribute to the loss (tool call outputs are excluded).

        Args:
            messages: Message list with string ``"role"`` entries.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: prompt_ids, response_ids and response_mask (all long).

        Raises:
            RuntimeError: If ``messages`` contains no assistant message.
        """
        first_assistant_idx = next((i for i, msg in enumerate(messages) if msg["role"] == "assistant"), None)
        if first_assistant_idx is None:
            raise RuntimeError("No assistant message found in completions")

        prompt_ids = []
        response_ids = []
        response_mask = []

        for i, message in enumerate(messages):
            # NOTE(review): each message is rendered through the chat template on its own. Depending on the template,
            # this may add per-message extras (e.g. a default system prompt or a BOS token) -- verify against the
            # tokenizer in use.
            ids = self.tokenizer.encode(self._parse([message], add_generation_prompt=False))
            if i < first_assistant_idx:
                prompt_ids.extend(ids)
            else:
                response_ids.extend(ids)
                response_mask.extend([1 if message["role"] == "assistant" else 0] * len(ids))

        return (torch.tensor(prompt_ids, dtype=torch.long),
                torch.tensor(response_ids, dtype=torch.long),
                torch.tensor(response_mask, dtype=torch.long))

    def worker(self, task_queue, result, completion, timeout: float | None = 10):
        """
        Worker thread body: take tasks from the queue and run them until the queue stays empty.

        A task that raises is logged and its ``result`` slot is left unchanged. The worker keeps processing the
        remaining tasks instead of dying.

        Args:
            task_queue: Queue of ``(task_id, task)`` tuples.
            result: List whose ``task_id`` slot receives each task's Trajectory.
            completion: AgentSDK provided completion interface.
            timeout: Seconds to wait for a new task before exiting. ``0`` exits as soon as the queue is empty, which
                suits a queue filled before the workers start. ``None`` waits forever.
        """

        while True:
            try:
                # With block=False the timeout is ignored, so timeout=0 means "stop once the queue is drained".
                task_id, task = task_queue.get(block=timeout != 0, timeout=timeout)
            except queue.Empty:
                return

            try:
                result[task_id] = self.run(task, task_id, completion)
            except Exception:
                logging.getLogger(__name__).exception("Meta-are task %s failed", task_id)

    def run(self, task, idx, completion):
        """
        Execute one task: generate its trajectory with Meta-are and compute the reward.

        The trajectory reward is ``res_reward + toolcall_reward``:
          * ``res_reward`` is 4.0 if Meta-are validation succeeded, else -2.0.
          * ``toolcall_reward`` is (sum of the scores of tool responses that directly follow an assistant message)
            divided by the number of assistant messages followed by any message, minus 0.01 per such assistant
            message. It is -1 when there are no such assistant messages.

        Args:
            task: Gaia2 task; ``task["data"]`` must hold Meta-are scenario data as a JSON string.
            idx: Index of task in the train_iter.
            completion: AgentSDK provided completion interface.

        Returns:
            Trajectory: Tokenized trajectory with reward and metrics.

        Raises:
            ValueError: If the task has no ``"data"`` entry.
            RuntimeError: If no recorded conversation fits within ``max_model_len``.
        """
        if "data" not in task:
            raise ValueError("Task from AgentSDK does not have data! please check config of data_path or "
                             "datasets_additional_keys.")

        trajectory_store = []

        result = self.generate_traj(scenario_data=task["data"],
                                    trajectory_store=trajectory_store,
                                    completion=completion,
                                    sampling_params=self.sampling_params,
                                    max_steps=self.max_steps)

        # outcome reward from Meta-are's validation result
        res_reward = 4.0 if result.success else -2.0

        # Each stored entry is the whole conversation after one agent LLM call, so later entries are typically
        # longer. Select the LAST entry that fits within max_model_len. Scanning from the end and stopping at the
        # first fit picks the same entry without tokenizing every snapshot.
        valid_length_traj = next(
            (chat_completions for chat_completions in reversed(trajectory_store)
             if len(self.tokenizer.encode(self.tokenizer.apply_chat_template(chat_completions, tokenize=False)))
             < self.max_model_len),
            None,
        )

        if not valid_length_traj:
            raise RuntimeError("No trajectory within valid length")

        # Count assistant messages that are followed by another message and score the tool responses after them
        # (tool responses carry the "user" role after transform_role).
        score = 0
        assistant_cnt = 0
        for message, next_message in zip(valid_length_traj, valid_length_traj[1:]):
            if message["role"] != "assistant":
                continue
            assistant_cnt += 1
            if next_message["role"] == "user":
                score += get_score_from_are_tool_response(next_message["content"])

        if assistant_cnt > 0:
            # average tool-call score minus a small per-call length penalty
            toolcall_reward = score / assistant_cnt - 0.01 * assistant_cnt
        else:
            # no tool calls at all
            toolcall_reward = -1

        trajectory_reward = res_reward + toolcall_reward

        prompt_ids, response_ids, response_mask = self.tokenize_and_mask(valid_length_traj)

        return Trajectory(
            prompt_tokens=prompt_ids,
            response_tokens=response_ids,
            response_masks=response_mask,
            idx=idx,
            trajectory_reward=trajectory_reward,
            chat_completions=valid_length_traj,
            metrics={
                "steps": assistant_cnt + 1,
                "reward_time": None,
                "env_time": None,
                "llm_time": None,
                "total_time": result.duration,
                "res_reward": res_reward,
                "toolcall_reward": toolcall_reward,
            }
        )

    def generate_agent_trajectories_async(self, tasks: List[dict]) -> List[Trajectory]:
        """
        Generate trajectories for several tasks concurrently with a pool of worker threads.

        Args:
            tasks: Tasks whose ``"data"`` entries hold Meta-are scenario data.

        Returns:
            List[Trajectory]: One entry per task, in task order. An entry stays ``None`` if its task failed (the
            failure is logged).

        Raises:
            RuntimeError: If there are tasks but no completion interfaces to run them with.
        """
        if not tasks:
            return []

        completions_size = len(self.completions)
        if completions_size == 0:
            raise RuntimeError("No completion interfaces available to run Meta-are tasks")

        task_queue = queue.Queue()
        for idx, task in enumerate(tasks):
            task_queue.put((idx, task))

        result = [None] * len(tasks)

        # More threads than tasks would only sit idle.
        num_worker = min(self.n_parallel_agents, len(tasks))

        runners = []
        for wid in range(num_worker):
            # The queue is fully populated before any worker starts, so workers may exit as soon as it is empty
            # (timeout=0) instead of idling for the default timeout after the last task.
            p = Thread(target=self.worker,
                       args=(task_queue, result, self.completions[wid % completions_size]),
                       kwargs={"timeout": 0})
            p.start()
            runners.append(p)

        for p in runners:
            p.join()

        return result

    def generate_traj(self,
                      scenario_data: str,
                      trajectory_store: list = None,
                      completion=None,
                      sampling_params: dict = None,
                      max_steps: int = 5):
        """
        Generate one trajectory with Meta-are.

        Args:
            scenario_data: Gaia2 style scenario data (JSON string), deserialized by Meta-are.
            trajectory_store: List that receives the conversation after every agent LLM call. A new list is used when
                ``None`` is given.
            completion: AgentSDK provided completion interface, used by both the agent and the judge.
            sampling_params: Sampling params for agent and judge requests. Each engine receives its own copy, so the
                caller's dict is never modified.
            max_steps: Maximum number of agent iterations in this trajectory.

        Returns:
            The scenario's validation result (``run`` reads its ``success`` and ``duration``).
        """
        tokenizer = self.tokenizer
        max_model_len = self.max_model_len
        if trajectory_store is None:
            trajectory_store = []

        # builder that makes Meta-are use our AgentSDK-backed LLM engine
        class _AgentSDKEngineBuilder(LLMEngineBuilder):
            def create_engine(self,
                              engine_config: LLMEngineConfig,
                              mock_responses: list[str] | None = None) -> LLMEngine:
                # engine_config and mock_responses are ignored here.
                return _AgentSDKLLMEngine(completion=completion,
                                          tokenizer=tokenizer,
                                          trajectory_store=trajectory_store,
                                          sampling_params=dict(sampling_params or {}),
                                          max_model_len=max_model_len)

        # deserialize the scenario from its JSON string
        scenario, _ = load_scenario(
            scenario_data, "scenario_id", False
        )

        # judge also runs through AgentSDK, with its own copy of the sampling params
        judge_engine = _JudgeLLMEngine(completion=completion,
                                       tokenizer=tokenizer,
                                       sampling_params=dict(sampling_params or {}))
        judge_config = GraphPerEventJudgeConfig(engine=judge_engine)

        # preprocess scenario: check oracle events, initialize turns and create the validation function
        preprocess_scenario(scenario=scenario,
                            judge_config=judge_config,
                            offline_validation=False,
                            max_scenario_duration=MAX_SCENARIO_DURATION,
                            tool_augmentation_config=None,
                            env_events_config=None)

        # create and start env
        env_config = EnvironmentConfig(oracle_mode=False,
                                       queue_based_loop=False,
                                       wait_for_user_input_timeout=None,
                                       dump_dir=None,
                                       time_increment_in_seconds=scenario.time_increment_in_seconds,
                                       exit_when_no_events=False)
        if scenario.start_time and scenario.start_time > 0:
            env_config.start_time = scenario.start_time
        env = Environment(environment_type=EnvironmentType.CLI,
                          config=env_config,
                          notification_system=VerboseNotificationSystem())
        env.run(scenario, wait_for_end=False)

        try:
            # create the agent and run it on the scenario
            agent_config = AgentConfigBuilder().build("default")
            simulated_generation_time_config = SimulatedGenerationTimeConfig(mode="measured")
            agent_config.get_base_agent_config().simulated_generation_time_config = simulated_generation_time_config
            if isinstance(agent_config, MainAgentConfig) and scenario.nb_turns is not None:
                agent_config.max_turns = scenario.nb_turns

            agent_builder = AgentBuilder(_AgentSDKEngineBuilder())
            are_simulation_agent: RunnableARESimulationAgent = agent_builder.build(
                agent_config=agent_config, env=env
            )
            are_simulation_agent.react_agent.max_iterations = max_steps
            are_simulation_agent.run_scenario(scenario=scenario, notification_system=env.notification_system)

            # validate result before the environment is stopped
            return scenario.validate(env)
        finally:
            # env.run was started with wait_for_end=False, so presumably it keeps running until stopped; stop it
            # even when the agent run or validation raises.
            env.stop()