import json
import logging
import os
from typing import Dict, Optional

import numpy as np
import torch
from transformers.initialization import no_init_weights

from ..core.model_source_security import normalize_model_source
from ..layers.attention import AttentionTensorCast
from ..layers.quant_linear import TensorCastQuantLinear
from ..model_config import (
    DiffusersConfig,
    DiffusersTransformerConfig,
    DiffusersVaeConfig,
    RemoteSource,
)
from ..parallel_group import ParallelGroup
from ..transformers.model import ModelWrapperBase
from ..transformers.transformations import quantize_linear, quantize_model, wrap_model
from ..transformers.utils import init_on_device_without_buffers
from .cache_agent import CacheConfig, CacheState
from .cache_agent.dit_block_cache import DiTBlockCache
from .diffusers_utils import get_diffusers_transformer_module
from .dit_cache_registry import get_dit_block_cache_spec, replace_blocks_in_range
from .model_resolver import resolve_diffusers_model_path

logger = logging.getLogger(__name__)


def build_diffusers_transformer_model(
    model_id: str,
    parallel_config: None,
    quant_config: None,
    dtype: torch.dtype,
    remote_source: str = RemoteSource.huggingface,
    resolved_model_path: str | None = None,
):
    """Create the transformer wrapper for a Diffusers model along with its loaded config.

    Args:
        model_id: Model identifier; passed to the path resolver and stored on the wrapper.
        parallel_config: Parallel settings, forwarded unchanged to ``load_config_from_file``.
        quant_config: Quantization settings, forwarded unchanged to ``load_config_from_file``.
        dtype: Torch dtype forwarded to ``load_config_from_file``.
        remote_source: Source used to resolve ``model_id`` when no resolved path is supplied.
        resolved_model_path: Already-resolved model directory. When ``None`` the path is
            obtained from ``resolve_diffusers_model_path(model_id, remote_source)``.

    Returns:
        A ``(model, model_config)`` tuple: the ``DiffusersTransformerModel`` and the
        ``DiffusersConfig`` it was built from.
    """
    # The conditional expression is lazy, so the resolver only runs when no path was given.
    model_dir = (
        resolved_model_path
        if resolved_model_path is not None
        else resolve_diffusers_model_path(model_id, remote_source)
    )

    diffusers_config = load_config_from_file(
        model_path=model_dir,
        parallel_config=parallel_config,
        quant_config=quant_config,
        quant_linear_cls=TensorCastQuantLinear,
        attention_cls=AttentionTensorCast,
        dtype=dtype,
    )

    transformer = DiffusersTransformerModel(model_id, diffusers_config.transformer_config)
    return transformer, diffusers_config


def load_config_from_file(
    model_path: str,
    parallel_config: None,
    quant_config: None,
    quant_linear_cls: None,
    attention_cls: None,
    dtype: torch.dtype,
):
    """Load the transformer (and optional VAE) config from a local Diffusers model directory.

    The directory tree under the model path is walked and every ``config.json`` is indexed by
    the name of the folder that contains it. The transformer config is read from the
    ``transformer`` folder; when no such folder exists, the single config whose
    ``_class_name`` contains ``"Transformer"`` is used instead. A ``vae`` folder is optional.

    Args:
        model_path: Path handed to ``normalize_model_source``; must resolve to a local directory.
        parallel_config: Stored as-is on the transformer and VAE configs.
        quant_config: Stored as-is on the transformer and VAE configs.
        quant_linear_cls: Stored as-is on the transformer config.
        attention_cls: Stored as-is on the transformer config.
        dtype: Stored as-is on the transformer and VAE configs.

    Returns:
        A ``DiffusersConfig`` with ``model_path`` and ``transformer_config`` set, plus
        ``vae_config`` when a ``vae/config.json`` file is present.

    Raises:
        ValueError: If the resolved path is not a local directory, if no transformer config
            can be located, or if the fallback search finds more than one candidate.

    NOTE(review): several parameters are annotated ``None``; these look like placeholder
    annotations rather than the real accepted types -- confirm and tighten.
    """

    def _load_config(config_path: str) -> Dict:
        with open(config_path, encoding="utf-8") as f:
            return json.load(f)

    def _looks_like_transformer_config(cfg: object) -> bool:
        # Guard against config.json files whose top-level JSON value is not an object.
        if not isinstance(cfg, dict):
            return False
        class_name = cfg.get("_class_name")
        return isinstance(class_name, str) and "Transformer" in class_name

    # TODO add separate parallel_config and quant_config(atten_cls is needed?) for vae and text
    source_info = normalize_model_source(model_path, RemoteSource.huggingface)
    resolved_model_path = source_info.model_id
    if not source_info.is_local_path or not os.path.isdir(resolved_model_path):
        raise ValueError(f"Input args.model_id should be dir, but got {resolved_model_path}")

    # Index every config.json by the basename of its folder. If two folders share a
    # basename, the one visited later by os.walk overwrites the earlier entry.
    config_path_dict: Dict[str, str] = {
        os.path.basename(root): os.path.abspath(os.path.join(root, "config.json"))
        for root, _, files in os.walk(resolved_model_path)
        if "config.json" in files
    }

    transformer_config_json_path = config_path_dict.get("transformer")
    if transformer_config_json_path is not None:
        transformer_config = _load_config(transformer_config_json_path)
    else:
        # Fallback: accept exactly one config whose _class_name looks like a transformer.
        transformer_candidates: Dict[str, tuple[str, Dict]] = {}
        for folder_name, config_path in config_path_dict.items():
            config = _load_config(config_path)
            if _looks_like_transformer_config(config):
                transformer_candidates[folder_name] = (config_path, config)

        if not transformer_candidates:
            raise ValueError(
                "No transformer/config.json found in input model path. "
                "Expect a Diffusers-style model directory that contains transformer/config.json."
            )
        if len(transformer_candidates) > 1:
            # Report ambiguity explicitly instead of claiming that nothing was found.
            raise ValueError(
                "No transformer/config.json found in input model path, and multiple "
                f"transformer-like configs were found in folders {sorted(transformer_candidates)}. "
                "Expect a Diffusers-style model directory that contains transformer/config.json."
            )
        transformer_config_json_path, transformer_config = next(iter(transformer_candidates.values()))

    vae_config_json_path = config_path_dict.get("vae")

    model_config = DiffusersConfig()
    model_config.model_path = resolved_model_path
    model_config.transformer_config = DiffusersTransformerConfig(
        parallel_config=parallel_config,
        quant_config=quant_config,
        config_json=transformer_config_json_path,
        model_config=transformer_config,
        quant_linear_cls=quant_linear_cls,
        attention_cls=attention_cls,
        dtype=dtype,
    )
    # The VAE is optional; the isfile check also skips entries that are not regular files.
    if vae_config_json_path is not None and os.path.isfile(vae_config_json_path):
        model_config.vae_config = DiffusersVaeConfig(
            parallel_config=parallel_config,
            quant_config=quant_config,
            config_json=vae_config_json_path,
            model_config=_load_config(vae_config_json_path),
            dtype=dtype,
        )
    return model_config


class DiffusersTransformerModel(ModelWrapperBase):
    """Wrapper around a Diffusers transformer instantiated from its config.

    The inner model is created with ``model_class.from_config`` under
    ``init_on_device_without_buffers("meta")`` and ``no_init_weights()``, cast to the
    configured dtype, put in eval mode, and then passed through ``wrap_model``,
    ``quantize_model`` and ``quantize_linear``.
    """

    def __init__(
        self,
        model_id: str,
        model_config: DiffusersTransformerConfig,
    ):
        """Build the inner transformer described by ``model_config``.

        Args:
            model_id: Identifier stored on the wrapper as ``self.model_id``.
            model_config: Transformer config; its ``config_json``, ``model_config``,
                ``parallel_config`` and ``dtype`` attributes are read here.

        Raises:
            ValueError: If ``model_config.config_json`` or ``model_config.model_config``
                is ``None``.
        """
        super().__init__(None)
        self.model_id = model_id
        self.model_config = model_config

        # Validate the config first so the explicit errors below surface before any
        # other attribute of the config is dereferenced.
        hf_config_json = self.model_config.config_json
        if hf_config_json is None:
            raise ValueError("hf_config_json should not be None.")
        hf_config = self.model_config.model_config
        if hf_config is None:
            raise ValueError("transformer model_config should not be None.")

        self.sp_group = get_sp_group(
            world_size=self.model_config.parallel_config.world_size,
            ulysses_size=self.model_config.parallel_config.ulysses_size,
        )

        model_class = get_diffusers_transformer_module(hf_config)

        with init_on_device_without_buffers("meta"), no_init_weights():
            self._inner = model_class.from_config(hf_config).to(model_config.dtype)
        self._inner.eval()
        wrap_model(self)
        quantize_model(self)
        quantize_linear(self)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_images: Optional[torch.Tensor] = None,
        return_dict=False,
        **kwargs: object,
    ):
        """Run the inner transformer and return the first element of its output.

        Args:
            hidden_states: Forwarded to the inner model as ``hidden_states``.
            timestep: Forwarded to the inner model as ``timestep``.
            encoder_hidden_states: Forwarded to the inner model as ``encoder_hidden_states``.
            encoder_hidden_states_images: Accepted but not forwarded to the inner model.
            return_dict: Forwarded to the inner model as ``return_dict``.
            **kwargs: Extra keyword arguments forwarded unchanged to the inner model.

        Returns:
            Element ``[0]`` of whatever the inner model returns.
        """
        # NOTE(review): ``encoder_hidden_states_images`` is never passed on, so a value
        # supplied under this name is silently ignored -- confirm whether the inner
        # model expects it (possibly under a different keyword).
        hidden_states = self._inner(
            hidden_states=hidden_states,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )[0]
        return hidden_states

    def enable_dit_block_cache(self, cache_config: CacheConfig) -> Optional[CacheState]:
        """
        Enable DiT block cache (dit_block_cache).

        Replace blocks in the configured cache range with cache-aware wrappers.
        Step scheduling (update/reuse/bypass) is driven externally by the caller.

        Args:
            cache_config: Supplies ``block_start`` and ``block_end``; ``block_end`` is
                clamped to the number of blocks reported by the model's cache spec.

        Returns:
            The ``CacheState`` shared by all wrapped blocks, or ``None`` when the model
            has no cache spec, the spec yields no blocks, or no block was replaced.
        """
        model_config = self.model_config.model_config or {}
        class_name = model_config.get("_class_name")
        spec = get_dit_block_cache_spec(class_name)
        if spec is None:
            logger.warning("dit_block_cache is not implemented for model %r.", class_name)
            return None

        blocks_with_setters = list(spec.get_blocks_with_setters(self._inner))
        if not blocks_with_setters:
            logger.warning("dit_block_cache found no blocks to wrap for model %r.", class_name)
            return None
        blocks_count = len(blocks_with_setters)

        # Never let the configured range run past the blocks that actually exist.
        bounded_block_end = min(cache_config.block_end, blocks_count)

        cache_state = CacheState()
        replaced = replace_blocks_in_range(
            blocks_with_setters,
            cache_config.block_start,
            bounded_block_end,
            lambda block, flat_idx: DiTBlockCache(
                block=block,
                state=cache_state,
                block_index=flat_idx,
                block_start=cache_config.block_start,
                block_end=bounded_block_end,
                make_wrapped_forward=spec.make_wrapped_forward,
            ),
        )

        if replaced <= 0:
            # Nothing was wrapped, so do not report the cache as enabled.
            logger.warning(
                "dit_block_cache replaced no blocks for %s: range [%d, %d) selects none of %d blocks.",
                spec.model_type,
                cache_config.block_start,
                bounded_block_end,
                blocks_count,
            )
            return None

        logger.info(
            "Enabled dit_block_cache for %s: replaced %d blocks in range [%d, %d) out of %d.",
            spec.model_type,
            replaced,
            cache_config.block_start,
            bounded_block_end,
            blocks_count,
        )
        return cache_state


def get_sp_group(world_size: int, ulysses_size: int) -> ParallelGroup:
    """Partition ranks ``0..world_size-1`` into groups and wrap them in a ``ParallelGroup``.

    Args:
        world_size: Total number of ranks; also passed on as ``global_world_size``.
        ulysses_size: Number of ranks per group. When positive, ranks are split into
            consecutive groups of this size; otherwise all ranks form a single group.

    Returns:
        A ``ParallelGroup`` built with ``rank=0``, the computed rank groups and
        ``global_world_size=world_size``.

    Raises:
        ValueError: If ``ulysses_size`` is positive and does not evenly divide the
            number of ranks.
    """
    all_ranks = np.arange(world_size)
    if ulysses_size > 0:
        # Fail with a clear message instead of numpy's generic reshape error.
        if all_ranks.size % ulysses_size != 0:
            raise ValueError(
                f"world_size ({world_size}) must be divisible by ulysses_size ({ulysses_size})."
            )
        rank_groups = all_ranks.reshape(-1, ulysses_size)
    else:
        rank_groups = all_ranks.reshape(1, -1)

    # NOTE(review): the rank is hard-coded to 0 -- confirm this is intended for all callers.
    return ParallelGroup(
        rank=0,
        rank_groups=[group.tolist() for group in rank_groups],
        global_world_size=world_size,
    )