"""
Copyright 2026 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import time
import torch
from typing import Any, Dict, List, Optional, Literal, TypedDict, Annotated
import uuid

from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, ToolMessage, AnyMessage
from langchain_core.messages.content import InvalidToolCall
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel
from rllm.parser.chat_template import ChatTemplateParser  # v0.1
from rllm.agents.agent import Trajectory, Action
from rllm.agents.utils import (
    convert_messages_to_tokens_and_masks,
    get_recent_assistant_user_messages,
)
from rllm.environments.env_utils import compute_mc_return
from agentic_rl import BaseEngineWrapper, Trajectory as AgenticRlTrajectory
from examples.agents.agents_mapping import get_agent_by_name
from examples.rllm.utils.utils import compute_trajectory_reward


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Served model name handed to ``init_chat_model`` for every task's LLM client.
MODEL = "Qwen2.5-7B-Instruct"

# ANSI escape sequences used to colour console progress output.
GREEN, RESET = "\033[92m", "\033[0m"


class AgentState(TypedDict):
    """Shared LangGraph state for the ``agent`` <-> ``tools`` workflow.

    Only ``messages`` declares a reducer (``add_messages``). All other keys
    have no reducer, so they use LangGraph's default last-value semantics.
    """

    # Conversation history, merged by ``add_messages`` rather than replaced.
    messages: Annotated[List[AnyMessage], add_messages]
    # Agent object; the agent node reads ``chat_completions`` / ``tool_parser``
    # from it and calls ``update_from_model``.
    agent: Any
    # Environment; the agent and tools nodes call its reward methods.
    env: Any
    # observation / reward / done / info: the tuple returned by the env's
    # ``calculate_llm_reward`` (agent node) or ``calculate_tool_reward``
    # (tools node), whichever ran last.
    observation: Any
    reward: float
    done: bool
    info: Any
    # NOTE(review): not written by any node in this module.
    delta_time: float
    # Tool calls of the latest assistant response (set by the agent node,
    # consumed by the tools node's reward computation).
    tool_calls: Any
    # Wall-clock seconds of the latest LLM call (agent node).
    llm_time: float
    # Wall-clock seconds of env reward computation (agent node) or of tool
    # execution plus reward computation (tools node).
    env_delta_time: float


def tools_condition(state: "AgentState") -> Literal["tools", "agent", "end"]:
    """Choose the node that follows ``agent``, based on the reward it just set.

    Args:
        state: Current graph state. Only its ``"reward"`` entry is read.

    Returns:
        str: ``"agent"`` when the reward is negative (query the LLM again),
        ``"tools"`` when a reward is present and non-negative, and ``"end"``
        when no reward is present. ``"end"`` is a path-map key for
        ``add_conditional_edges``, not LangGraph's ``END`` sentinel
        (``"__end__"``) itself.
    """
    if "reward" not in state:
        return "end"
    return "agent" if state["reward"] < 0 else "tools"


class LangGraphEngineWrapper(BaseEngineWrapper):
    """Rollout engine that runs rLLM agents and environments through LangGraph.

    :meth:`init_envs_and_agents` builds one agent, one environment and one
    compiled two-node graph (``agent`` <-> ``tools``) per task.
    :meth:`run_search_agent` streams a single graph. It converts the agent's
    chat history into prompt/response token tensors plus timing metrics.

    NOTE(review): ``tokenizer``, ``max_steps``, ``max_prompt_length``,
    ``max_response_length`` and ``server_addresses`` are read from ``self``
    but are not assigned in this class. Presumably ``BaseEngineWrapper`` sets
    them; confirm.
    """

    DEFAULT_MAX_PROMPT_LENGTH = 8192
    DEFAULT_MAX_RESPONSE_LENGTH = 16384
    DEFAULT_N_PARALLEL_AGENTS = 8
    DEFAULT_MAX_STEPS = 128
    DEFAULT_ENV_CREATION_WORKERS = 64
    DEFAULT_AGENT_CREATION_WORKERS = 64
    DEFAULT_MAX_WORKERS = 64

    def __init__(
        self,
        agent_name: str,
        tokenizer: Any,
        sampling_params: Optional[Dict[str, Any]] = None,
        max_prompt_length: int = DEFAULT_MAX_PROMPT_LENGTH,
        max_response_length: int = DEFAULT_MAX_RESPONSE_LENGTH,
        n_parallel_agents: int = DEFAULT_N_PARALLEL_AGENTS,
        max_steps: int = DEFAULT_MAX_STEPS,
    ) -> None:
        """Create the wrapper and resolve the agent/env classes for ``agent_name``.

        Args:
            agent_name: Name of the agent type to use (looked up via
                ``get_agent_by_name``).
            tokenizer: Tokenizer for processing messages.
            sampling_params: Parameters for model sampling.
            max_prompt_length: Maximum allowed prompt length in tokens.
            max_response_length: Maximum allowed response length in tokens.
            n_parallel_agents: Number of parallel agents to run.
            max_steps: Maximum number of steps per trajectory.
        """
        # Used by run_search_agent to run the blocking env.reset() off the
        # event loop.
        self.executor = ThreadPoolExecutor(max_workers=self.DEFAULT_MAX_WORKERS)
        super().__init__(
            agent_name,
            tokenizer,
            sampling_params,
            max_prompt_length,
            max_response_length,
            n_parallel_agents,
            max_steps,
        )

        agent_config = get_agent_by_name(agent_name)
        self.agent_class = agent_config.agent_class
        self.env_class = agent_config.env_class
        self.agent_args = agent_config.agent_args
        self.env_args = agent_config.env_args

        # Filled per task (same index) by init_envs_and_agents().
        self.graphs = []
        self.agents = []
        self.envs = []

    def initialize(self):
        """Set the trajectory timeout and build the chat-template parser.

        NOTE(review): not called within this module. Presumably it is invoked
        by ``BaseEngineWrapper``; confirm.
        """
        # Effectively "no timeout" (seconds, compared against total_time).
        self.trajectory_timeout = 1e9
        self.chat_parser = ChatTemplateParser.get_parser(
            self.tokenizer, disable_thinking=False
        )

    async def run_search_agent(self, idx: int, max_truns: int = 5):
        """Run one agent/environment pair to completion and tokenize the result.

        Args:
            idx: Index into ``self.agents`` / ``self.envs`` / ``self.graphs``.
            max_truns: Bounds the LangGraph recursion limit, which is set to
                ``max_truns * 2 + 5`` supersteps (an agent step plus a tools
                step per turn, with slack).

        Returns:
            Dict: ``prompt_tokens``, ``response_tokens`` and ``response_masks``
            (``torch.long`` tensors), ``trajectory_reward``, ``idx``,
            ``chat_completions`` and a ``metrics`` dict.

        Raises:
            Exception: If the initial prompt already exceeds
                ``self.max_prompt_length``.
        """
        agent = self.agents[idx]
        env = self.envs[idx]
        termination_reason = None
        response_token_len = 0
        response_tokens: List[int] = []
        response_masks: List[int] = []
        total_time = 0.0
        reward_time = None  # not measured here; reported as-is in metrics
        llm_time = 0.0
        env_time = 0.0

        loop = asyncio.get_running_loop()
        observation, info = await loop.run_in_executor(self.executor, env.reset)
        info["max_steps"] = self.max_steps

        agent.reset()
        agent.update_from_env(
            observation=observation,
            reward=0.0,
            done=False,
            info=info,
        )

        messages = agent.chat_completions
        prompt_tokens, _ = convert_messages_to_tokens_and_masks(
            messages,
            tokenizer=self.tokenizer,
            parser=self.chat_parser,
            contains_first_msg=True,
            contains_generation_msg=True,
        )
        prompt_token_len = len(prompt_tokens)

        if prompt_token_len > self.max_prompt_length:
            agent.reset()
            raise Exception(
                f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying"
            )

        async for step in self.graphs[idx].astream(
            {"messages": messages, "agent": agent, "env": env},
            {"recursion_limit": max_truns * 2 + 5},
        ):
            node_name, ctx = step.popitem()
            is_agent_step = node_name == "agent"

            # Only the agent node's update owns llm_time. The tools node returns
            # the full state, which still carries the previous agent step's
            # llm_time, so counting it again would double-count that LLM call.
            step_llm_time = ctx["llm_time"] if is_agent_step else 0.0
            llm_time += step_llm_time
            env_time += ctx["env_delta_time"]
            total_time += step_llm_time + ctx["env_delta_time"]

            # tools_condition routes a non-negative agent step to the tools
            # node. Unless the env reported done, finalize the turn only after
            # that node has run.
            if is_agent_step and not ctx["done"] and ctx["reward"] >= 0:
                continue

            # NOTE(review): after a tools step this is the tool reward alone;
            # the preceding agent step's reward is not added. Confirm intended.
            reward = ctx["reward"]
            done = ctx["done"]
            info = ctx["info"]
            info["max_steps"] = self.max_steps
            info["cur_tokens"] = response_token_len
            agent.update_from_env(
                observation=ctx["observation"],
                reward=reward,
                done=done,
                info=info,
            )
            cur_step = agent.get_current_state()
            cur_step.reward = reward
            cur_step.done = done
            cur_step.info.update(info)

            (
                assistant_msg_tokens,
                assistant_msg_masks,
                env_msg_tokens,
                env_msg_masks,
            ) = self._tokenize_latest_turn(agent.chat_completions)

            response_token_len += len(assistant_msg_tokens) + len(env_msg_tokens)

            if response_token_len >= self.max_response_length:
                # Keep only the part of this turn that still fits the budget.
                turn_tokens = assistant_msg_tokens + env_msg_tokens
                turn_masks = assistant_msg_masks + env_msg_masks
                overflow = response_token_len - self.max_response_length
                if overflow > 0:
                    turn_tokens = turn_tokens[:-overflow]
                    turn_masks = turn_masks[:-overflow]
                response_tokens.extend(turn_tokens)
                response_masks.extend(turn_masks)

                cur_step = agent.get_current_state()
                # Zero the reward when the assistant message alone overran.
                if response_token_len - len(env_msg_tokens) > self.max_response_length:
                    cur_step.reward = 0.0
                cur_step.done = True
                termination_reason = "TRUNCATION"
                break

            response_tokens.extend(assistant_msg_tokens)
            response_masks.extend(assistant_msg_masks)

            if total_time >= self.trajectory_timeout:
                termination_reason = "TIMEOUT"
                done = True
                agent.get_current_state().done = done
                break

            if done:
                termination_reason = "ENV_DONE"
                break

            response_tokens.extend(env_msg_tokens)
            response_masks.extend(env_msg_masks)

            # ">=" rather than "==": this check runs only on finalized turns,
            # and the env may advance step_count more than once between checks
            # (its increment logic is not visible here).
            if env.step_count >= self.max_steps - 1:
                termination_reason = "MAX_STEPS"
                break

        trajectory: Trajectory = agent.trajectory
        compute_trajectory_reward(trajectory)
        compute_mc_return(trajectory, gamma=0.2)
        print(
            f"{GREEN}Trajectory {idx} completed due to: {termination_reason}. Reward is {trajectory.reward}. \n{RESET}"
        )
        token_result = {
            "prompt_tokens": torch.tensor(prompt_tokens, dtype=torch.long),
            "response_tokens": torch.tensor(response_tokens, dtype=torch.long),
            "response_masks": torch.tensor(response_masks, dtype=torch.long),
            "trajectory_reward": trajectory.reward,
            "idx": env.idx,
            "chat_completions": agent.chat_completions,
            "metrics": {
                "steps": len(trajectory.steps),
                "reward_time": reward_time,
                "env_time": env_time,
                "llm_time": llm_time,
                "total_time": total_time,
                "res_reward": trajectory.res_reward,
                "toolcall_reward": trajectory.toolcall_reward,
            },
        }

        return token_result

    def _tokenize_latest_turn(self, chat_completions):
        """Tokenize the most recent assistant message and its env messages.

        The split into (assistant message, env messages) comes from
        ``get_recent_assistant_user_messages``. Env messages are tokenized with
        a trailing generation prompt; the assistant message is not.

        Returns:
            Tuple ``(assistant_tokens, assistant_masks, env_tokens, env_masks)``.
            A pair is empty when the corresponding messages are absent.
        """
        assistant_message, env_messages = get_recent_assistant_user_messages(
            chat_completions
        )

        assistant_tokens, assistant_masks = [], []
        env_tokens, env_masks = [], []

        if assistant_message:
            assistant_tokens, assistant_masks = convert_messages_to_tokens_and_masks(
                [assistant_message],
                tokenizer=self.tokenizer,
                parser=self.chat_parser,
                contains_first_msg=False,
                contains_generation_msg=False,
            )

        if env_messages:
            env_tokens, env_masks = convert_messages_to_tokens_and_masks(
                env_messages,
                tokenizer=self.tokenizer,
                parser=self.chat_parser,
                contains_first_msg=False,
                contains_generation_msg=True,
            )

        return assistant_tokens, assistant_masks, env_tokens, env_masks

    def init_envs_and_agents(self, tasks: List[dict]):
        """Create an environment, an agent and a compiled LangGraph per task.

        Task ``i`` is served by ``self.server_addresses[i % len(...)]`` (round
        robin). ``self.envs``, ``self.agents`` and ``self.graphs`` are replaced
        with lists aligned to ``tasks``.

        Args:
            tasks: List of task dictionaries to initialize.

        Raises:
            ValueError: If ``tasks`` is non-empty but no server addresses are
                configured.
        """
        task_num = len(tasks)
        logger.info("Initializing %s environments and agents...", task_num)

        search_url = self.env_args.get("search_url", "")
        server_addresses = self.server_addresses
        if task_num and not server_addresses:
            raise ValueError(
                "No server addresses available; cannot assign an LLM endpoint to any task."
            )
        address_num = len(server_addresses)
        self.envs = [None] * task_num
        self.agents = [None] * task_num
        self.graphs = [None] * task_num
        for i, task in enumerate(tasks):
            self.agents[i] = self.agent_class(**self.agent_args)

            env_args_copy = self.env_args.copy()
            env_args_copy["task"] = task
            env_args_copy["max_steps"] = self.max_steps
            env = self.env_class.from_dict(env_args_copy)
            env.idx = i
            self.envs[i] = env

            address = server_addresses[i % address_num]
            llm_model = init_chat_model(
                MODEL,
                model_provider="openai",
                base_url="http://" + address + "/v1",
                api_key="EMPTY",
            )

            retriever_tool = env.to_langchain_tool(
                server_url=search_url, tokenizer=self.tokenizer, max_tool_length=8192
            )

            self.graphs[i] = self._build_graph(llm_model, [retriever_tool])

        logger.info(
            "Successfully initialized %s agents of %s tasks.",
            len(self.graphs),
            task_num,
        )

    @classmethod
    def _build_graph(cls, llm_model, tools):
        """Compile the ``agent`` <-> ``tools`` workflow for a single task.

        Edges: ``START -> agent``, then ``agent -> {tools, agent, END}`` as chosen
        by ``tools_condition``, and ``tools -> agent``.
        """
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", cls._make_agent_node(llm_model))
        workflow.add_node("tools", cls._make_tool_node(tools))

        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            "agent",
            tools_condition,
            {
                "tools": "tools",
                "agent": "agent",
                "end": END,
            },
        )
        workflow.add_edge("tools", "agent")

        return workflow.compile()

    @staticmethod
    def _make_agent_node(llm_model):
        """Return the async ``agent`` node bound to ``llm_model``.

        The node queries the LLM with the agent's chat history and parses a
        tool call out of the text response (if any). It then feeds the
        response to the agent and scores the resulting action with
        ``env.calculate_llm_reward``.
        """

        async def agent_step(state: AgentState) -> Dict[str, Any]:
            agent = state["agent"]
            env = state["env"]
            prompt_messages = agent.chat_completions.copy()
            start_time = time.time()
            response = await llm_model.ainvoke(prompt_messages)
            state["llm_time"] = time.time() - start_time

            try:
                tool_call = agent.tool_parser.parse(response.content)
            except Exception as e:
                logger.error("Failed to parse tool calls from string response: %s", e)
                tool_call = None

            if tool_call:
                # Re-wrap the response with a structured tool call so that
                # ToolNode can execute the call parsed out of the raw text.
                response = AIMessage(
                    content=response.content,
                    tool_calls=[
                        {
                            "name": tool_call.name,
                            "args": tool_call.arguments,
                            "id": str(uuid.uuid4()),
                            "type": "tool_call",
                        }
                    ],
                )

            parsed_action: Action = agent.update_from_model(response.content)
            start_time = time.time()
            (
                state["observation"],
                state["reward"],
                state["done"],
                state["info"],
            ) = env.calculate_llm_reward(parsed_action.action)
            state["env_delta_time"] = time.time() - start_time
            state["messages"] = [response]
            state["tool_calls"] = response.tool_calls
            return state

        return agent_step

    @staticmethod
    def _make_tool_node(tools):
        """Return the async ``tools`` node wrapping a LangGraph ``ToolNode``.

        After the tools run, only the first resulting message's content is
        scored with ``env.calculate_tool_reward``. ``env_delta_time`` covers
        both the tool execution and the scoring.
        """
        tool_node = ToolNode(tools)

        async def wrapped_node(state: Dict[str, Any]) -> Dict[str, Any]:
            start_time = time.time()
            result = await tool_node.ainvoke(state)
            state.update(result)
            env = state["env"]
            (
                state["observation"],
                state["reward"],
                state["done"],
                state["info"],
            ) = env.calculate_tool_reward(
                state["messages"][0].content, state["tool_calls"]
            )
            state["env_delta_time"] = time.time() - start_time
            return state

        return wrapped_node

    def generate_agent_trajectories_async(self, tasks: List[dict]):
        """Build envs/agents for ``tasks`` and roll them all out concurrently.

        Despite the name, this method is synchronous: it drives the coroutines
        with ``asyncio.run`` and therefore cannot be called from inside a
        running event loop.

        Args:
            tasks: List of task dictionaries to process.

        Returns:
            List: Generated trajectories, in completion order.
        """
        self.init_envs_and_agents(tasks)
        return asyncio.run(self._generate_agent_trajectories_async(tasks))

    async def _generate_agent_trajectories_async(self, tasks: List[dict]):
        """Run every prepared agent concurrently and collect the trajectories.

        Args:
            tasks: Unused; the work list is taken from ``self.envs``.

        Returns:
            List: ``AgenticRlTrajectory`` objects in completion order (not task
            order). Each one carries its task index via the ``idx`` field.

        Raises:
            Exception: The first trajectory failure is re-raised as-is.
        """
        trajectories = []

        async def launch_one_trajectory_task(env_idx: int):
            try:
                return await self.run_search_agent(
                    env_idx, self.env_args.get("max_steps", 5)
                )
            except Exception:
                logger.exception(
                    "Trajectory %s trajectory generation failed.", env_idx
                )
                raise

        tasks_to_run = [launch_one_trajectory_task(i) for i in range(len(self.envs))]

        tasks_completed = 0
        for future in asyncio.as_completed(tasks_to_run):
            try:
                result = await future
                tasks_completed += 1
                print(
                    f"{GREEN}Number of Trajectories {tasks_completed}/{len(self.envs)} completed{RESET}"
                )
                trajectories.append(AgenticRlTrajectory(**result))
            except Exception:
                logger.error(
                    "Trajectory generation failed. %s trajectories have been generated now.",
                    tasks_completed,
                )
                raise

        return trajectories