from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_thermo_qwen import ThermoQwenConfig


@dataclass
class ThermoQwenForecastOutput(ModelOutput):
    """Output container returned by ``ThermoQwenForForecasting.forward``.

    Field order matters: it is the order in which values appear when the
    output is converted to a tuple, so do not reorder the fields.
    """

    # Huber loss between ``logits`` and the labels; stays ``None`` when the
    # forward pass is called without labels.
    loss: Optional[torch.FloatTensor] = None
    # Raw output of the regression head, before denormalization.
    logits: torch.FloatTensor = None
    # ``logits`` cast to float32 and mapped through ``denormalize_depth``.
    pred_depths: Optional[torch.FloatTensor] = None


def _config_from_dict(config_dict):
    """Rebuild a backbone config object from its serialized dictionary.

    The ``"model_type"`` entry selects the config class through
    ``AutoConfig.for_model``; every other entry is forwarded to it as a
    keyword argument. The input mapping itself is left unmodified.
    """
    backbone_kwargs = {
        key: value for key, value in config_dict.items() if key != "model_type"
    }
    return AutoConfig.for_model(config_dict.get("model_type"), **backbone_kwargs)


def _pool_hidden_states(hidden_states, attention_mask):
    if attention_mask is None:
        mean_pool = hidden_states.mean(dim=1)
        last_pool = hidden_states[:, -1, :]
        return torch.cat([last_pool, mean_pool], dim=-1)

    mask = attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
    mean_pool = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    seq_lengths = attention_mask.sum(dim=1).clamp(min=1) - 1
    last_pool = hidden_states[
        torch.arange(hidden_states.size(0), device=hidden_states.device),
        seq_lengths,
    ]
    return torch.cat([last_pool, mean_pool], dim=-1)


class ThermoQwenForForecasting(PreTrainedModel):
    """Causal-LM backbone with LoRA adapters and an MLP regression head.

    The backbone is built from the config (weights are not downloaded here),
    LoRA layers are injected with PEFT, and the final-layer hidden states are
    pooled by ``_pool_hidden_states`` and fed to ``self.regressor``, which
    emits ``config.horizon`` values per sequence.
    """

    config_class = ThermoQwenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"model\.lm_head\.weight"]

    def __init__(self, config: ThermoQwenConfig):
        """Build the LoRA-wrapped backbone and the regression head.

        Args:
            config: Must provide either ``backbone_config`` (a serialized
                config dict) or ``base_model_name_or_path``, plus the LoRA
                and regressor hyper-parameters read below.

        Raises:
            ValueError: If neither backbone source is set on ``config``.
        """
        super().__init__(config)

        # NOTE(review): trust_remote_code=True is hard-coded below, which lets
        # Transformers run modeling/config code shipped with the checkpoint
        # repository. Only point this at sources you trust.
        if config.backbone_config is not None:
            backbone_config = _config_from_dict(config.backbone_config)
        elif config.base_model_name_or_path:
            backbone_config = AutoConfig.from_pretrained(
                config.base_model_name_or_path,
                trust_remote_code=True,
            )
        else:
            raise ValueError(
                "ThermoQwenConfig requires either backbone_config or "
                "base_model_name_or_path."
            )

        # from_config creates the architecture only; weights are expected to
        # be filled in later (e.g. when this wrapper's checkpoint is loaded).
        base_model = AutoModelForCausalLM.from_config(
            backbone_config,
            trust_remote_code=True,
        )
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules,
            bias="none",
            task_type="SEQ_CLS",
        )
        # Keep only the `.model` attribute of the PEFT wrapper; presumably this
        # is the backbone with LoRA layers injected -- confirm against the
        # installed PEFT version.
        self.model = get_peft_model(base_model, lora_config).model
        # The forecasting wrapper never uses lm_head logits. Keep the outer
        # tied-weight list empty so Transformers does not look for a top-level
        # lm_head on this wrapper.
        self.all_tied_weights_keys = {}

        hidden_size = backbone_config.hidden_size
        # Pooling concatenates last-token and mean features, hence 2x width.
        self.regressor = self._build_regressor(
            hidden_size * 2,
            config.regressor_hidden_sizes,
            config.regressor_dropout,
            config.horizon,
        )

    @staticmethod
    def _build_regressor(input_size, hidden_sizes, dropouts, output_size):
        """Assemble the MLP head: LayerNorm -> [Linear, GELU, Dropout?]* -> Linear.

        Args:
            input_size: Width of the pooled feature vector.
            hidden_sizes: Widths of the hidden layers; ``None`` or empty
                yields LayerNorm followed by a single Linear.
            dropouts: Dropout probability per hidden layer. May be a
                sequence (missing trailing entries mean no dropout), a single
                number applied to every hidden layer, or ``None``.
            output_size: Number of regression outputs.

        Returns:
            An ``nn.Sequential`` implementing the head.
        """
        hidden_sizes = tuple(hidden_sizes or ())
        if dropouts is None:
            dropouts = ()
        elif isinstance(dropouts, (int, float)):
            # A scalar has no len()/indexing; broadcast it to every layer.
            dropouts = (float(dropouts),) * len(hidden_sizes)

        layers = [nn.LayerNorm(input_size)]
        in_features = input_size
        for idx, hidden_size in enumerate(hidden_sizes):
            layers.extend([
                nn.Linear(in_features, hidden_size),
                nn.GELU(),
            ])
            dropout = dropouts[idx] if idx < len(dropouts) else 0.0
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = hidden_size
        layers.append(nn.Linear(in_features, output_size))
        return nn.Sequential(*layers)

    def denormalize_depth(self, logits):
        """Linearly map normalized outputs to depth units.

        ``-1`` maps to ``config.depth_min`` and ``+1`` to ``config.depth_max``;
        values outside that interval are extrapolated, not clipped.
        """
        depth_range = self.config.depth_max - self.config.depth_min
        return (logits + 1.0) / 2.0 * depth_range + self.config.depth_min

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        output_hidden_states=True,
        return_dict=True,
        **kwargs,
    ):
        """Run the backbone, pool its last hidden layer and regress depths.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Mask forwarded to the backbone and used for pooling.
            labels: Optional regression targets in the same normalized space
                as ``logits``. When given, a Huber loss (``delta=0.2``) is
                computed.
            output_hidden_states: Accepted for interface compatibility only;
                the backbone is always asked for hidden states because the
                head needs them.
            return_dict: When falsy, return a tuple
                ``(loss?, logits, pred_depths)`` instead of a
                ``ThermoQwenForecastOutput``.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            ``ThermoQwenForecastOutput`` or the equivalent tuple.
        """
        backbone_kwargs = dict(kwargs)
        # Only hidden states are consumed, so skip building a KV cache unless
        # the caller explicitly asks for one.
        backbone_kwargs.setdefault("use_cache", False)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            # `.hidden_states` below needs a ModelOutput even if the backbone
            # config defaults to tuple outputs.
            return_dict=True,
            **backbone_kwargs,
        )
        pooled = _pool_hidden_states(outputs.hidden_states[-1], attention_mask)
        # Match the head's parameter dtype (backbone may run in lower precision).
        logits = self.regressor(pooled.to(self.regressor[0].weight.dtype))

        loss = None
        if labels is not None:
            targets = labels.to(logits.device)
            # (batch,) targets against (batch, 1) logits would silently
            # broadcast to (batch, batch) inside huber_loss; align them first.
            if targets.dim() == logits.dim() - 1 and logits.size(-1) == 1:
                targets = targets.unsqueeze(-1)
            loss = F.huber_loss(logits.float(), targets.float(), delta=0.2)

        pred_depths = self.denormalize_depth(logits.float())
        if not return_dict:
            output: Tuple[torch.Tensor, ...] = (logits, pred_depths)
            return ((loss,) + output) if loss is not None else output

        return ThermoQwenForecastOutput(
            loss=loss,
            logits=logits,
            pred_depths=pred_depths,
        )