"""
Copyright 2026 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import copy
import uuid
from typing import Any

from rllm.agents.agent import Action, BaseAgent, Trajectory, Step

from examples.agents.websearcher.websearcher_tool_parser import WebSearcherToolParser
from examples.agents.websearcher.websearcher_tools import websearcher_tools

from agentic_rl import MemorySummary


class WebSearcherAgent(BaseAgent):
    """
    An agent that can perform web searches using predefined tools.

    The conversation history is held in a ``MemorySummary`` instance, and every
    model response is recorded as a ``Step`` on the agent's ``Trajectory``.
    """
    def __init__(
            self,
            memory_config: dict = None
    ):
        """
        Initialize the WebSearcherAgent with specified memory configuration.

        Args:
            memory_config (dict): Configuration passed to ``MemorySummary``.
                When ``None``, a default configuration is built in place.

        Raises:
            TypeError: If ``memory_config`` is neither ``None`` nor a dict.
        """
        if memory_config is not None and not isinstance(memory_config, dict):
            raise TypeError("memory_config must be a dictionary if not None.")

        if memory_config is None:
            # Imported here so `openai` is only loaded when the default
            # configuration is actually needed.
            from openai import OpenAI
            # NOTE(review): the base URL and tokenizer path below look like
            # placeholders; callers presumably need to pass a real config.
            client = OpenAI(base_url="/your/oai_model_url")
            memory_config = {
                "simplify_thinking": False,
                "use_summary": False,
                "max_summary_length": 1024,
                "max_prompt_length": 8192,
                "before_raw_message": 2,
                "end_raw_message": -2,
                "train_model_tokenizer_path": "/path/to/tokenizer",
                "oai_client": client,
                "oai_model_name": "qwen-2.5-7b-instruct"
            }

        self.tool_parser = WebSearcherToolParser()
        self.tools_prompt = self.tool_parser.get_tool_prompt(str(websearcher_tools))
        self.memory = MemorySummary(config=memory_config)

        # initial state
        self._trajectory = Trajectory()
        self._current_observation = None
        self.reset()

    @property
    def trajectory(self) -> Trajectory:
        """
        Returns the current trajectory of the agent.

        Returns:
            Trajectory: The agent's trajectory.
        """
        return self._trajectory

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """
        Returns the current message history of the agent.

        Returns:
            list[dict[str, str]]: The prompt messages provided by the agent's memory.
        """
        return self.memory.get_prompt_messages()

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict) -> None:
        """
        Updates the agent's internal state based on environment feedback.

        The observation is converted to chat messages, stored in memory together
        with the reward, and remembered as the current observation.

        Args:
            observation (Any): The observation received from the environment.
            reward (float): The reward received from the environment.
            done (bool): Whether the episode is done. Currently unused.
            info (dict): Additional information from the environment. Currently unused.

        Raises:
            ValueError: If ``observation`` is ``None``.
        """
        obs_messages = self._format_observation(observation)
        # NOTE(review): a single metadata entry is passed even when the
        # observation expands to several (or zero) messages - confirm that
        # MemorySummary.add_message accepts this.
        self.memory.add_message(obs_messages, metadata=[{"reward": reward}])
        self._current_observation = observation

    def update_from_model(self, response: str) -> Action:
        """
        Updates the agent's internal state based on the model's response.

        The response is parsed for a tool call. If none is found, the response
        is wrapped in a ``finish`` call carrying the raw text. The response is
        then appended to memory as an assistant message and a new ``Step`` is
        added to the trajectory.

        Args:
            response (str): The response generated by the model.

        Returns:
            Action: An Action object containing the parsed tool calls or final response.

        Raises:
            TypeError: If ``response`` is not a string.
        """
        if not isinstance(response, str):
            raise TypeError("response must be a string.")

        tool_call = self.tool_parser.parse(response)
        if tool_call:
            function_payload = tool_call.to_dict()
        else:
            # Nothing was parsed as a tool call: hand the raw response to "finish".
            function_payload = {
                "name": "finish",
                "arguments": {"response": response}
            }

        # Both cases share the same envelope; only the function payload differs.
        formatted_tool_call = {
            "id": str(uuid.uuid4()),
            "type": "function",
            "function": function_payload
        }

        assistant_message = {"role": "assistant", "content": response}
        self.memory.add_message(assistant_message)

        new_step = Step(
            # Deep copy so later memory updates do not alter the recorded step.
            chat_completions=copy.deepcopy(self.chat_completions),
            action=formatted_tool_call,
            model_response=response,
            observation=self._current_observation,
        )
        self._trajectory.steps.append(new_step)

        return Action(action=formatted_tool_call)

    def reset(self) -> None:
        """
        Resets the agent's internal state and memory for a new episode.

        A fresh trajectory is created, the last observation is dropped, and the
        memory is cleared and re-seeded with the tools prompt as the system message.
        """
        self._trajectory = Trajectory()
        # Drop the previous episode's observation so it cannot be recorded in
        # the first step of the next episode.
        self._current_observation = None
        self.memory.clear_memory("system", self.tools_prompt)

    def _format_observation(self, observation: Any) -> list[dict]:
        """
        Formats an observation into a list of message dictionaries.

        A dict with a ``"problem"`` key becomes one user message; otherwise a
        dict with a ``"tool_output"`` key becomes one tool message per
        ``tool_call_id -> result`` entry. A dict with neither key yields an
        empty list. A string becomes a user message as-is, and any other
        non-``None`` value is converted with ``str`` first.

        Args:
            observation (Any): The observation to format.

        Returns:
            list[dict]: A list of formatted message dictionaries, each containing a role and content field.

        Raises:
            ValueError: If ``observation`` is ``None``.
        """
        messages = []
        if isinstance(observation, dict):
            if "problem" in observation:
                messages.append({"role": "user", "content": observation["problem"]})
            elif "tool_output" in observation:
                for tool_call_id, tool_result in observation["tool_output"].items():
                    messages.append({
                        "role": "tool",
                        "content": tool_result,
                        "tool_call_id": tool_call_id
                    })
        elif isinstance(observation, str):
            messages.append({"role": "user", "content": observation})
        elif observation is not None:
            messages.append({"role": "user", "content": str(observation)})
        else:
            raise ValueError("Empty observation received.")

        return messages