import torch

from ..custom_model_registry import (
    ModelProfile,
    register_model_profile,
)


def _patch_hf_config_for_kimi_k25(config):
    """Fix HuggingFace config and import environment for Kimi K2.5.

    These patches modify the Transformers *environment* (not model classes)
    and must run BEFORE the model is loaded so the downstream loading code
    picks up the corrected settings.

    This function does NOT require ``model_id`` — it operates purely on the
    HF config object and global import state.

    Args:
        config: HF config object.  Only ``model_type``,
            ``_attn_implementation`` and ``vision_config`` (plus a few
            attributes of the latter) are read or written.

    Returns:
        bool: ``True`` if at least one patch was applied; ``False`` if the
        config is not a ``kimi_k25`` config or no patch was applied.
    """
    # Guard first so that unrelated model types return immediately without
    # importing transformers internals.
    if getattr(config, "model_type", None) != "kimi_k25":
        return False

    import importlib.util
    import logging

    import transformers.utils.import_utils as import_utils

    logger = logging.getLogger(__name__)
    patched = False

    # Resolved once; both Patch 2 and Patch 3 treat a missing and a ``None``
    # vision config the same way (nothing to patch).
    vision_config = getattr(config, "vision_config", None)

    # ----------------------------------------------------------------
    # Patch 1: Restore is_torch_fx_available
    # ----------------------------------------------------------------
    if not hasattr(import_utils, "is_torch_fx_available"):

        def is_torch_fx_available():
            return importlib.util.find_spec("torch.fx") is not None

        import_utils.is_torch_fx_available = is_torch_fx_available
        patched = True

    # ----------------------------------------------------------------
    # Patch 2: Downgrade flash_attention_2 → tensor_cast (PRE-LOAD)
    # ----------------------------------------------------------------
    # WHY:   Kimi K2.5's config.json specifies "_attn_implementation": "flash_attention_2".
    # transformers 5.x enforces flash_attn availability during PreTrainedModel.__init__()
    # via _flash_attn_2_can_dispatch() — if flash_attn is not installed, ImportError
    # is raised BEFORE the model instance is returned.
    #
    # The _attn_implementation reassignment in model.py L206 runs AFTER load_model()
    # returns, so it cannot prevent that early failure.  This patch intercepts the
    # config BEFORE loading and downgrades to "tensor_cast", letting the HF loader
    # skip the flash_attn check.
    # Only downgrades when flash_attn is genuinely absent to respect environments
    # that do have it installed.
    # ----------------------------------------------------------------
    def _flash_attn_missing():
        try:
            return importlib.util.find_spec("flash_attn") is None
        except (ImportError, ValueError):
            # find_spec() raises ValueError when sys.modules already holds a
            # "flash_attn" entry whose __spec__ is None or unset (e.g. a
            # hand-made stub module).  Such an entry is not a real install,
            # so treat it as missing instead of crashing the pre-load patch.
            return True

    def _downgrade_attn_implementation(cfg):
        if getattr(cfg, "_attn_implementation", None) != "flash_attention_2":
            return False
        if not _flash_attn_missing():
            return False
        logger.warning(
            "Flash Attention 2 is requested but not installed. "
            "Falling back to 'tensor_cast' attention implementation for simulation."
        )
        cfg._attn_implementation = "tensor_cast"
        return True

    if _downgrade_attn_implementation(config):
        patched = True
    if vision_config is not None and _downgrade_attn_implementation(vision_config):
        patched = True

    # ----------------------------------------------------------------
    # Patch 3: Bridge vision config attributes for input generator
    # ----------------------------------------------------------------
    # WHY:   Kimi K2.5 vision config uses different attribute names
    #        (``merge_kernel_size``) or omits attributes altogether
    #        (``temporal_patch_size``, ``in_channels``).  The generic
    #        image-input generator expects these attributes and fails
    #        with AttributeError inside transformers v5.x due to
    #        hasattr/__getattribute__ mismatch.
    # WITHOUT: AttributeError on spatial_merge_size /
    #          temporal_patch_size / in_channels.
    if vision_config is not None:
        if hasattr(vision_config, "merge_kernel_size"):
            merge_kernel = vision_config.merge_kernel_size
            # A sequence contributes its first entry; an empty sequence has
            # no usable entry and is handled like ``None`` below.
            if isinstance(merge_kernel, (list, tuple)):
                merge_kernel = merge_kernel[0] if merge_kernel else None

            if merge_kernel is None:
                # Do not copy a meaningless value into spatial_merge_size
                # (and do not raise IndexError on an empty sequence).
                logger.warning(
                    "vision_config.merge_kernel_size is empty or None; "
                    "not deriving spatial_merge_size from it."
                )
            else:
                vision_config.spatial_merge_size = merge_kernel
                patched = True

        if not hasattr(vision_config, "temporal_patch_size"):
            vision_config.temporal_patch_size = 1
            patched = True

        if not hasattr(vision_config, "in_channels"):
            vision_config.in_channels = 3
            patched = True

    return patched


def _patch_model_classes_for_kimi_k25(config, model_id):
    """Monkey-patch *remote* model classes before model instantiation.

    These patches modify **class-level methods** (not instances), so they
    MUST run before the HF loader constructs model objects from the dynamic
    module.  Once the model is loaded, class monkey-patches have no effect
    on already-instantiated objects.

    Requires ``model_id`` to locate and import the remote modeling files.
    """
    import logging
    import sys
    from transformers.dynamic_module_utils import get_class_from_dynamic_module
    from typing import Optional, Tuple

    logger = logging.getLogger(__name__)

    model_type = getattr(config, "model_type", None)
    if model_type != "kimi_k25" or model_id is None:
        return False

    patched = False

    # ----------------------------------------------------------------
    # Patch 4a: Windows SIGALRM — resolve trust_remote_code without alarm
    # ----------------------------------------------------------------
    # WHY:   Windows lacks signal.SIGALRM, which breaks transformers'
    #        trust_remote_code interactive prompt.  Default
    #        trust_remote_code=True on platforms without SIGALRM so
    #        that headless simulation never blocks on stdin.
    #        This was previously a global monkey-patch in utils.py;
    #        moved here to limit its scope to Kimi K2.5 only.
    # WITHOUT: Blocking on stdin / AttributeError from signal.SIGALRM
    #          when loading remote model code.
    import signal as _signal

    if not hasattr(_signal, "SIGALRM"):
        import transformers.dynamic_module_utils

        _orig_resolve = transformers.dynamic_module_utils.resolve_trust_remote_code
        if not getattr(_orig_resolve, "_tensor_cast_patched", False):

            def _patched_resolve(trust_remote_code, *args, **kwargs):
                if trust_remote_code is None:
                    trust_remote_code = True
                return _orig_resolve(trust_remote_code, *args, **kwargs)

            _patched_resolve._tensor_cast_patched = True
            transformers.dynamic_module_utils.resolve_trust_remote_code = _patched_resolve

    try:
        # ----------------------------------------------------------------
        # Patch 4: Filter KimiK25ForConditionalGeneration.forward kwargs
        # ----------------------------------------------------------------
        # WHY:   TensorCast injects extra kwargs (attention_meta,
        #        kv_cache_by_layers, etc.) via model_runner, but the
        #        original VL forward only accepts standard HF keys.
        #        Passing unexpected kwargs causes TypeError.
        # WITHOUT: TypeError from unexpected keyword arguments.
        class_ref_vl = "modeling_kimi_k25.KimiK25ForConditionalGeneration"
        vl_cls = get_class_from_dynamic_module(class_ref_vl, model_id, force_download=False)

        if not hasattr(vl_cls, "_original_vl_forward"):
            vl_cls._original_vl_forward = vl_cls.forward

        _STANDARD_VL_FORWARD_KEYS = frozenset(
            {
                "input_ids",
                "pixel_values",
                "grid_thws",
                "attention_mask",
                "position_ids",
                "past_key_values",
                "inputs_embeds",
                "labels",
                "use_cache",
                "output_attentions",
                "output_hidden_states",
                "return_dict",
            }
        )

        def patched_vl_forward(self, *args, **kwargs):
            # Inject TC kwargs into attention layers BEFORE calling the
            # original forward (which filters them out).  The decoder
            # (patched by P10) reads them back from _extra_forward_kwargs.
            from tensor_cast.transformers.model import _EXTRA_TC_KWARGS_KEYS

            _tc_extra = {k: kwargs[k] for k in _EXTRA_TC_KWARGS_KEYS if k in kwargs and kwargs[k] is not None}
            _injected = []
            if _tc_extra:
                try:
                    for layer in self.language_model.model.layers:
                        if hasattr(layer, 'self_attn'):
                            layer.self_attn._extra_forward_kwargs = _tc_extra
                            _injected.append(layer.self_attn)
                except AttributeError as e:
                    logger.warning(
                        "Failed to inject TC kwargs into attention layers: %s. "
                        "This may affect tensor casting for Kimi K2.5 vision-language model.",
                        e,
                    )

            hf_kwargs = {k: v for k, v in kwargs.items() if k in _STANDARD_VL_FORWARD_KEYS}
            # The generic input generator uses "image_grid_thw" but
            # Kimi K2.5's forward expects "grid_thws".
            if "grid_thws" not in hf_kwargs and "image_grid_thw" in kwargs:
                hf_kwargs["grid_thws"] = kwargs["image_grid_thw"]
            return vl_cls._original_vl_forward(self, *args, **hf_kwargs)

        vl_cls.forward = patched_vl_forward

        # ----------------------------------------------------------------
        # Patch 5: _merge_input_ids_with_image_features (meta device)
        # ----------------------------------------------------------------
        # WHY:   During torch.compile graph capture, input_ids live on
        #        the 'meta' device.  The original merge function calls
        #        embedding layers which raise on meta tensors.  This
        #        patch returns a correctly-shaped meta embedding to
        #        keep the graph tracer happy.
        # WITHOUT: RuntimeError from operations on meta tensors.
        if not hasattr(vl_cls, "_original_merge_input_ids_with_image_features"):
            vl_cls._original_merge_input_ids_with_image_features = vl_cls._merge_input_ids_with_image_features

        def patched_merge_input_ids_with_image_features(
            self,
            image_features,
            feature_lens,
            input_ids,
            attention_mask=None,
            position_ids=None,
            labels=None,
        ):
            batch_size, sequence_length = input_ids.shape
            if input_ids.device.type == 'meta':
                embed_dim = (
                    image_features[0].shape[-1] if len(image_features) > 0 else self.config.text_config.hidden_size
                )
                return (
                    torch.empty(batch_size, sequence_length, embed_dim, device='meta', dtype=self.dtype),
                    attention_mask,
                    labels,
                    position_ids,
                )

            return vl_cls._original_merge_input_ids_with_image_features(
                self,
                image_features,
                feature_lens,
                input_ids,
                attention_mask,
                position_ids,
                labels,
            )

        vl_cls._merge_input_ids_with_image_features = patched_merge_input_ids_with_image_features
        patched = True

    except Exception as e:
        logger.warning(f"Could not patch remote VL / attention class attributes: {e}")

    try:
        # ----------------------------------------------------------------
        # Patch 6: MoonViT3dEncoder — add deterministic attn flag & adapter
        # ----------------------------------------------------------------
        # WHY:   The remote encoder checks ``self.use_deterministic_attn``
        #        but never defines it.  We must add the attribute so the
        #        check doesn't fail.  Additionally, we register a
        #        'tensor_cast' attention adapter that handles meta tensors
        #        and avoids O(n²) computation for very long sequences.
        # WITHOUT: AttributeError for missing use_deterministic_attn;
        #          KeyError for missing "tensor_cast" attention backend.
        class_ref_enc = "modeling_kimi_k25.MoonViT3dEncoder"
        encoder_cls = get_class_from_dynamic_module(class_ref_enc, model_id, force_download=False)
        if not hasattr(encoder_cls, "use_deterministic_attn"):
            setattr(encoder_cls, "use_deterministic_attn", False)
            patched = True

            for name, module in sys.modules.items():
                if "moonshotai" in name and "modeling_kimi_k25" in name:
                    if hasattr(module, "VL_VISION_ATTENTION_FUNCTIONS"):

                        def visual_tc_adapter(
                            q,
                            k,
                            v,
                            q_cu_seqlens,
                            k_cu_seqlens,
                            max_seqlen_q,
                            max_seqlen_k,
                            deterministic=False,
                        ):
                            import math

                            seq_length = q.shape[0]
                            num_heads = q.shape[1]
                            head_dim = q.shape[-1]

                            if q.device.type == 'meta':
                                # -------------------------------------------------------
                                # Call the fused tensor_cast.attention op so that
                                # `tensor_cast.attention.default` appears in the chrome
                                # trace, enabling accurate analytic performance modeling.
                                #
                                # Shape mapping (varlen → tensor_cast convention):
                                #   q: (seq_len, num_heads, head_dim)
                                #      → query: (seq_len, num_heads * head_dim)
                                #   k: (seq_len, num_heads, head_dim)
                                #      → key:   (seq_len, num_heads, head_dim)
                                #   v: (seq_len, num_heads, head_dim)
                                #      → value: (seq_len, num_heads, head_dim)
                                #
                                # Metadata is passed as None — matching the standard
                                # visual attention path in flash_attention_forward()
                                # (attention.py L80: attention_meta = None).  The
                                # performance model falls back to deriving seq_lens
                                # and query_lens from query.shape, avoiding
                                # .item() on meta tensors.
                                # -------------------------------------------------------
                                query = q.reshape(seq_length, num_heads * head_dim)
                                return torch.ops.tensor_cast.attention(
                                    query,
                                    k,
                                    v,
                                    None,  # attention_mask
                                    None,  # block_table
                                    None,  # query_start_loc
                                    None,  # seq_lens
                                    None,  # query_lens
                                )

                            if seq_length > 4096:
                                logger.warning(
                                    "Visual attention sequence length %d exceeds safe "
                                    "threshold. Skipping O(n²) attention to avoid OOM.",
                                    seq_length,
                                )
                                return torch.zeros(
                                    seq_length,
                                    num_heads * head_dim,
                                    device=q.device,
                                    dtype=q.dtype,
                                )

                            # Build causal-like attention mask: allow attention
                            # only within each image chunk (diagonal blocks).
                            # Using -inf for masked positions (correct additive mask)
                            # instead of boolean True/False (which would add 1.0/0.0).
                            attention_mask = torch.full(
                                [1, seq_length, seq_length],
                                float('-inf'),
                                device=q.device,
                                dtype=q.dtype,
                            )

                            q_cu_seqlens_list = q_cu_seqlens.tolist()
                            for i in range(1, len(q_cu_seqlens_list)):
                                start = q_cu_seqlens_list[i - 1]
                                end = q_cu_seqlens_list[i]
                                attention_mask[..., start:end, start:end] = 0.0

                            q = q.transpose(0, 1)
                            k = k.transpose(0, 1)
                            v = v.transpose(0, 1)

                            attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
                            attn_weight += attention_mask
                            attn_weight = torch.softmax(
                                attn_weight,
                                dim=-1,
                                dtype=torch.float32,
                            ).to(q.dtype)

                            attn_output = attn_weight @ v
                            attn_output = attn_output.transpose(0, 1)
                            attn_output = attn_output.reshape(seq_length, -1)
                            return attn_output

                        module.VL_VISION_ATTENTION_FUNCTIONS["tensor_cast"] = visual_tc_adapter
                        module.VL_VISION_ATTENTION_FUNCTIONS["eager"] = visual_tc_adapter
                        break

        # ----------------------------------------------------------------
        # Patch 7: DeepseekV3MoE — stub forward for graph tracing
        # ----------------------------------------------------------------
        # WHY:   The real MoE forward contains dynamic dispatch logic
        #        (expert selection + token routing + expert combine)
        #        that torch.compile cannot trace — the control flow
        #        depends on runtime token-to-expert assignments.
        #        Additionally the real computations (matmul across all
        #        experts) would OOM during graph capture.  This stub
        #        returns correct shapes without executing any experts,
        #        allowing compile to proceed.  The actual performance
        #        modeling is handled later by transformations.patch_moe()
        #        which wraps these modules with fused MoELayer wrappers.
        # WITHOUT: torch.compile failure (untraceable dynamic dispatch)
        #          or OOM during graph capture.
        class_ref_moe = "modeling_deepseek.DeepseekV3MoE"
        moe_cls = get_class_from_dynamic_module(class_ref_moe, model_id, force_download=False)

        def patched_forward(_self, hidden_states):
            return torch.zeros_like(hidden_states)

        def patched_moe_infer(_self, x, _topk_ids, _topk_weight):
            return torch.zeros_like(x)

        if not hasattr(moe_cls, "_original_forward"):
            moe_cls._original_forward = moe_cls.forward
        moe_cls.forward = patched_forward

        if not hasattr(moe_cls, "_original_moe_infer"):
            moe_cls._original_moe_infer = moe_cls.moe_infer
        moe_cls.moe_infer = patched_moe_infer

        # ----------------------------------------------------------------
        # Patch 8: MoEGate — deterministic routing for simulation
        # ----------------------------------------------------------------
        # WHY:   The real gate performs top-k softmax + random sampling
        #        which is non-deterministic and untraceable.  We replace
        #        it with equal-weight routing to produce deterministic
        #        shapes during graph capture.
        # WITHOUT: Non-deterministic / un-traceable routing logic during
        #          torch.compile; shape mismatches downstream.
        class_ref_gate = "modeling_deepseek.MoEGate"
        gate_cls = get_class_from_dynamic_module(class_ref_gate, model_id, force_download=False)

        def patched_gate_forward(self, hidden_states, **kwargs):
            if hidden_states.dim() == 3:
                bsz, seq_len, _ = hidden_states.shape
            else:
                bsz = hidden_states.shape[0]
                seq_len = 1
            device = hidden_states.device
            dtype = hidden_states.dtype
            top_k = self.top_k
            topk_idx = torch.zeros(bsz * seq_len, top_k, dtype=torch.long, device=device)
            topk_weight = torch.ones(bsz * seq_len, top_k, dtype=dtype, device=device) / top_k
            return topk_idx, topk_weight

        gate_cls.forward = patched_gate_forward

        # ----------------------------------------------------------------
        # Patch 9: Monkey-patch _resolve_position_embeddings onto MLA
        # ----------------------------------------------------------------
        # WHY:   Kimi K2.5's decoder only passes ``position_ids``, not
        #        pre-computed RoPE (cos, sin) tensors.  TensorCast MLA
        #        needs explicit position_embeddings.  This method
        #        computes them from position_ids via the rotary_emb
        #        cache.  Moved here from layers/mla.py to keep the
        #        generic MLA layer free of model-specific logic.
        # WITHOUT: Missing RoPE → simulation results inaccurate
        #          (cos=1, sin=0 fallback).
        from tensor_cast.layers.mla import MultiheadLatentAttentionTensorCast

        if not hasattr(MultiheadLatentAttentionTensorCast, "_patched_rope_resolve"):

            def _patched_resolve_position_embeddings(
                self,
                hidden_states: torch.Tensor,
                position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
                **kwargs,
            ) -> tuple[torch.Tensor, torch.Tensor]:
                """Compute position_embeddings from position_ids when not explicitly provided.

                This provides compatibility when the caller (e.g. patched decoder forward)
                only passes ``position_ids`` instead of pre-computed RoPE tensors.
                The resolved (cos, sin) tuple is always returned.
                """
                if position_embeddings is not None:
                    return position_embeddings

                position_ids = kwargs.get("position_ids", None)

                if position_ids is not None and self._has_rotary_emb and hidden_states.device.type != 'meta':
                    max_pos = position_ids.max().item() + 1
                    if hasattr(self.rotary_emb, "cos_cached"):
                        if self.rotary_emb.cos_cached.shape[0] < max_pos:
                            self.rotary_emb._update_cos_sin_tables(max_pos, hidden_states.device, hidden_states.dtype)
                    cos = self.rotary_emb.cos_cached[position_ids].to(hidden_states.dtype)
                    sin = self.rotary_emb.sin_cached[position_ids].to(hidden_states.dtype)
                    return (cos, sin)

                # No position info available → neutral RoPE (identity rotation).
                if self._has_rotary_emb:
                    import warnings

                    warnings.warn(
                        "position_embeddings was not provided and position_ids is unavailable; "
                        "RoPE will be disabled (cos=1, sin=0). If this model uses RoPE-based "
                        "attention, simulation results may be inaccurate.",
                        RuntimeWarning,
                        stacklevel=2,
                    )

                seq_len = hidden_states.shape[1]
                dim = self.qk_rope_head_dim
                cos = torch.ones(seq_len, dim, device=hidden_states.device, dtype=hidden_states.dtype)
                sin = torch.zeros(seq_len, dim, device=hidden_states.device, dtype=hidden_states.dtype)
                return (cos, sin)

            MultiheadLatentAttentionTensorCast._resolve_position_embeddings = _patched_resolve_position_embeddings
            MultiheadLatentAttentionTensorCast._patched_rope_resolve = True
            patched = True

        # ----------------------------------------------------------------
        # Patch 10: DeepseekV3DecoderLayer — bridge HF ↔ TensorCast MLA
        # ----------------------------------------------------------------
        # WHY:   (a) The original decoder unpacks 3 values from self_attn
        #            but the TensorCast MLA wrapper returns 2 (no attn
        #            weights).  This patch handles both return conventions.
        #        (b) Kimi K2.5 computes RoPE internally but TensorCast MLA
        #            needs explicit (cos, sin) position_embeddings.
        #        (c) The patched VL forward filters out tensor_cast-
        #            specific kwargs; we recover them from
        #            _extra_forward_kwargs (injected by model_runner).
        # WITHOUT: ValueError from tuple unpacking; missing RoPE;
        #          missing attention_meta leading to broken KV cache ops.
        class_ref_decoder = "modeling_deepseek.DeepseekV3DecoderLayer"
        decoder_cls = get_class_from_dynamic_module(class_ref_decoder, model_id, force_download=False)

        # ----------------------------------------------------------------
        # Patch 10a: Register 'tensor_cast' in ATTENTION_CLASSES
        # ----------------------------------------------------------------
        # WHY:   Patch 2 downgrades config._attn_implementation from
        #        'flash_attention_2' to 'tensor_cast'.  Later, MTP block
        #        creation calls ATTENTION_CLASSES[config._attn_implementation]
        #        which only knows 'eager' / 'sdpa' / 'flash_attention_2'.
        # WITHOUT: KeyError: 'tensor_cast' during MTP block construction.
        import sys

        remote_module = sys.modules.get(decoder_cls.__module__)
        if remote_module is not None and hasattr(remote_module, 'ATTENTION_CLASSES'):
            if 'tensor_cast' not in remote_module.ATTENTION_CLASSES:
                fallback = remote_module.ATTENTION_CLASSES.get('sdpa') or remote_module.ATTENTION_CLASSES.get('eager')
                if fallback is None:
                    raise ValueError(
                        f"ATTENTION_CLASSES lacks 'sdpa' or 'eager' fallback. "
                        f"Available: {list(remote_module.ATTENTION_CLASSES.keys())}"
                    )
                remote_module.ATTENTION_CLASSES['tensor_cast'] = fallback

        if not hasattr(decoder_cls, "_original_decoder_forward"):
            decoder_cls._original_decoder_forward = decoder_cls.forward

        def patched_decoder_forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            **kwargs,
        ):
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

            # Resolve position_embeddings (cos, sin) for TensorCast MLA.
            position_embeddings = kwargs.pop("position_embeddings", None)
            if position_embeddings is None:
                # Lazy-initialize _has_rotary_emb (moved from mla.py __init__
                # to avoid polluting the generic MLA layer).
                if not hasattr(self.self_attn, '_has_rotary_emb'):
                    self.self_attn._has_rotary_emb = hasattr(self.self_attn._inner, "rotary_emb")
                    if not self.self_attn._has_rotary_emb:
                        import warnings

                        warnings.warn(
                            f"MLA module '{type(self.self_attn._inner).__name__}' "
                            "lacks 'rotary_emb'.  If position_embeddings is not "
                            "provided at forward time, RoPE will be disabled "
                            "(cos=1, sin=0), producing incorrect results for "
                            "RoPE-dependent models.",
                            RuntimeWarning,
                            stacklevel=2,
                        )
                position_embeddings = self.self_attn._resolve_position_embeddings(
                    hidden_states, None, position_ids=position_ids, **kwargs
                )

            # Recover tensor_cast-specific kwargs filtered by the VL forward.
            if "attention_meta" not in kwargs:
                extra_kwargs = getattr(self.self_attn, '_extra_forward_kwargs', None)
                if extra_kwargs is not None and extra_kwargs.get('attention_meta') is not None:
                    for k, v in extra_kwargs.items():
                        if k not in kwargs and v is not None:
                            kwargs[k] = v

            attn_result = self.self_attn(
                hidden_states=hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            if isinstance(attn_result, tuple) and len(attn_result) == 2:
                hidden_states, present_key_value = attn_result
                self_attn_weights = None
            else:
                hidden_states, self_attn_weights, present_key_value = attn_result

            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states

            outputs = (hidden_states,)
            if output_attentions:
                outputs += (self_attn_weights,)
            if use_cache:
                outputs += (present_key_value,)
            return outputs

        decoder_cls.forward = patched_decoder_forward
        patched = True

        # ----------------------------------------------------------------
        # Patch 11: MoonVision3dPatchEmbed — support 2D (flattened) input
        # ----------------------------------------------------------------
        # WHY:   During simulation, vision tokens may arrive as a flat 2D
        #        tensor (total_tokens, channels) rather than the original
        #        3D patches.  The original Conv2d projection expects 4D
        #        input.  This patch reshapes 2D input back to 4D chunks
        #        and uses linear projection instead.
        # WITHOUT: RuntimeError from Conv2d receiving 2D input.
        class_ref_patch_embed = "modeling_kimi_k25.MoonVision3dPatchEmbed"
        patch_embed_cls = get_class_from_dynamic_module(
            class_ref_patch_embed,
            model_id,
            force_download=False,
        )

        if not hasattr(patch_embed_cls, "_original_patch_embed_forward"):
            patch_embed_cls._original_patch_embed_forward = patch_embed_cls.forward

        def patched_patch_embed_forward(
            self,
            x: torch.Tensor,
            grid_thws: torch.Tensor,
        ) -> torch.Tensor:
            if x.dim() == 2:
                hidden_dim = x.shape[1]
                total_tokens = 0
                reshaped_parts = []
                out_dim, in_channels, kH, kW = self.proj.weight.shape
                expected_hidden_dim = in_channels * kH * kW
                if hidden_dim != expected_hidden_dim:
                    raise ValueError(
                        f"Hidden dim mismatch: input has {hidden_dim}, "
                        f"but proj expects {expected_hidden_dim} "
                        f"(in_channels={in_channels}, kernel_size=({kH}, {kW}))"
                    )

                for t, h, w in grid_thws.tolist():
                    num_tokens = t * h * w
                    part = x[total_tokens : total_tokens + num_tokens]
                    part = part.view(num_tokens, in_channels, kH, kW)
                    linear_weight = self.proj.weight.view(
                        out_dim,
                        in_channels * kH * kW,
                    )
                    projected = torch.nn.functional.linear(
                        part.reshape(num_tokens, -1),
                        linear_weight,
                        self.proj.bias,
                    )
                    reshaped_parts.append(projected)
                    total_tokens += num_tokens

                x = torch.cat(reshaped_parts, dim=0)
            else:
                x = self.proj(x).view(x.size(0), -1)
            x = self.pos_emb(x, grid_thws)
            return x

        patch_embed_cls.forward = patched_patch_embed_forward
        patched = True

        # ----------------------------------------------------------------
        # Patch 12: Fix expert counts on root config
        # ----------------------------------------------------------------
        # WHY:   Kimi K2.5 stores ``n_routed_experts`` and
        #        ``n_shared_experts`` inside ``text_config`` rather than
        #        at the root level.  The downstream MoE patching logic
        #        (transformations.patch_moe) reads them from the root
        #        config object.
        # WITHOUT: AttributeError when patch_moe tries to read
        #          ``config.n_routed_experts`` / ``config.n_shared_experts``.
        if config is not None:
            if hasattr(config, "text_config") and hasattr(config.text_config, "n_routed_experts"):
                setattr(
                    config,
                    "n_routed_experts",
                    config.text_config.n_routed_experts,
                )

            if not hasattr(config, "n_routed_experts"):
                setattr(config, "n_routed_experts", 384)
                logger.warning(
                    "n_routed_experts not found in config or text_config; "
                    "falling back to default value 384. "
                    "Verify that the model's expert count matches this default."
                )

            if not hasattr(config, "n_shared_experts"):
                if hasattr(config, "text_config") and hasattr(config.text_config, "n_shared_experts"):
                    setattr(
                        config,
                        "n_shared_experts",
                        config.text_config.n_shared_experts,
                    )
                else:
                    setattr(config, "n_shared_experts", 1)
                    logger.warning(
                        "n_shared_experts not found in config or text_config; "
                        "falling back to default value 1. "
                        "Verify that the model's shared expert count matches this default."
                    )

        # ----------------------------------------------------------------
        # Patch 13: ModelWrapper — add output_intermediate_hidden_states for MTP
        #           and apply selected_token_indices for prefill token pruning
        # ----------------------------------------------------------------
        # WHY:   (a) The generic ``ModelWrapper`` only returns a single tensor
        #        from ``forward()``, but ``MtpWrapper`` expects
        #        ``(logits, hidden_states)`` when MTP is enabled.
        #        (b) ``ModelWrapper`` delegates directly to the HF model
        #        which has its own internal ``lm_head`` — it cannot apply
        #        ``selected_token_indices`` *before* the lm_head like
        #        ``CausalLmWrapper`` can.  Instead we apply it *after* the
        #        HF model's forward to select only the desired logit rows.
        #        (c) For the MTP branch we prune logits only; intermediate
        #        hidden_states stay full-width for rotary_emb and proposal row selection.
        #
        #        The patch is additive and backwards-compatible: the default
        #        path (no MTP, no selected_indices) is identical to the
        #        original behaviour.
        # WITHOUT: (a) ``AssertionError: Can't unpack a tensor of 1 rows
        #          into a tuple of 2 elements`` in ``MtpWrapper.forward()``.
        #          (b) 42000×7168×163840 lm_head matmul instead of
        #          12×7168×163840 during prefill, inflating compute cost
        #          ~3500×.
        from tensor_cast.transformers.model import ModelWrapper
        from tensor_cast.layers.sampler import _has_explicit_selected_token_indices, select_lm_head_hidden_states

        if not hasattr(ModelWrapper, "_patched_for_mtp"):
            _original_mw_forward = ModelWrapper.forward

            def patched_mw_forward(
                self,
                input_ids: Optional[torch.Tensor],
                position_ids: torch.Tensor,
                inputs_embeds: Optional[torch.Tensor] = None,
                output_intermediate_hidden_states: bool = False,
                **kwargs: object,
            ):
                # Extract sampling_metadata from generate_inputs(); spec decode uses it for target row selection.
                sampling_metadata = kwargs.get("sampling_metadata")

                if output_intermediate_hidden_states:
                    has_image_input = kwargs.get("pixel_values") is not None or kwargs.get("image_grid_thw") is not None
                    if not has_image_input and inputs_embeds is None and hasattr(self._inner, "language_model"):
                        # MTP text path: keep full intermediate hidden states for rotary/proposal
                        # selection, but prune target rows before the internal lm_head.
                        from tensor_cast.transformers.model import _EXTRA_TC_KWARGS_KEYS

                        lm = self._inner.language_model
                        _tc_extra = {
                            k: kwargs[k] for k in _EXTRA_TC_KWARGS_KEYS if k in kwargs and kwargs[k] is not None
                        }
                        if _tc_extra:
                            for layer in lm.model.layers:
                                if hasattr(layer, "self_attn"):
                                    layer.self_attn._extra_forward_kwargs = _tc_extra

                        body_outputs = lm.model(
                            input_ids=input_ids,
                            position_ids=position_ids,
                            use_cache=False,
                            return_dict=True,
                        )
                        intermediate_hidden_states = body_outputs.last_hidden_state
                        hidden_states = select_lm_head_hidden_states(
                            intermediate_hidden_states,
                            sampling_metadata,
                            mode="target",
                        )
                        logits = lm.lm_head(hidden_states)
                        return logits, intermediate_hidden_states

                    # Fallback for image / embedding paths that require the full VL forward.
                    kwargs_with_hidden = {**kwargs, "output_hidden_states": True}
                    outputs = self._inner(
                        input_ids=input_ids,
                        use_cache=False,
                        position_ids=position_ids,
                        inputs_embeds=inputs_embeds,
                        return_dict=False,
                        **kwargs_with_hidden,
                    )
                    logits = outputs[0]
                    intermediate_hidden_states = outputs[1][-1]
                    logits = select_lm_head_hidden_states(logits, sampling_metadata, mode="target")
                    return logits, intermediate_hidden_states

                selected_indices = sampling_metadata.selected_token_indices if sampling_metadata is not None else None

                # Non-MTP path
                if (
                    _has_explicit_selected_token_indices(selected_indices)
                    and sampling_metadata.spec_decode_metadata is None
                    and inputs_embeds is None
                ):
                    # ------------------------------------------------------------
                    # Fix: Check whether image inputs are present.  If the user
                    # supplied pixel_values / image_grid_thw, we must route
                    # through the full VL forward (KimiK25ForConditionalGeneration)
                    # so that the visual encoder is executed and image features
                    # are merged with text embeddings.  Bypassing the VL forward
                    # would silently drop the image, producing wrong results
                    # and a misleading trace (no visual ops).
                    # ------------------------------------------------------------
                    has_image_input = kwargs.get("pixel_values") is not None or kwargs.get("image_grid_thw") is not None
                    if not has_image_input:
                        # Optimization: prune hidden_states BEFORE lm_head.
                        # Bypass the HF VL forward (KimiK25ForConditionalGeneration)
                        # and directly call the language model's transformer body,
                        # then apply lm_head on only the selected tokens.
                        # This avoids computing lm_head on all tokens.
                        #
                        # We must inject tensor_cast kwargs (attention_meta,
                        # kv_cache_by_layers, etc.) into each attention layer's
                        # _extra_forward_kwargs side-channel, replicating what
                        # Patch 4 does for the normal VL forward path.  Without
                        # this the MLA layers see None kv_cache and the
                        # performance estimator crashes.
                        from tensor_cast.transformers.model import _EXTRA_TC_KWARGS_KEYS

                        lm = self._inner.language_model
                        _tc_extra = {
                            k: kwargs[k] for k in _EXTRA_TC_KWARGS_KEYS if k in kwargs and kwargs[k] is not None
                        }
                        if _tc_extra:
                            for layer in lm.model.layers:
                                if hasattr(layer, "self_attn"):
                                    layer.self_attn._extra_forward_kwargs = _tc_extra

                        body_outputs = lm.model(
                            input_ids=input_ids,
                            position_ids=position_ids,
                            use_cache=False,
                            return_dict=True,
                        )
                        hidden_states = body_outputs.last_hidden_state
                        hidden_states = hidden_states.index_select(1, selected_indices)
                        logits = lm.lm_head(hidden_states)
                        return logits

                # Default / fallback path
                logits = _original_mw_forward(self, input_ids, position_ids, inputs_embeds, **kwargs)
                return select_lm_head_hidden_states(logits, sampling_metadata, mode="target")

            ModelWrapper.forward = patched_mw_forward
            ModelWrapper._patched_for_mtp = True
            patched = True

        # ----------------------------------------------------------------
        # Patch 14: DeepseekV3RotaryEmbedding — handle position_ids as seq_len
        # ----------------------------------------------------------------
        # WHY:   ``maybe_enable_mtp`` (line 228) runs BEFORE
        #        ``patch_rotary_emb`` (line 231), so ``MtpWrapper.__init__``
        #        captures the inner ``DeepseekV3RotaryEmbedding`` (not the
        #        ``CachingRotaryEmb`` wrapper that is applied later).
        #        The inner ``forward(x, seq_len)`` expects an integer
        #        ``seq_len``, but ``MtpWrapper`` passes ``position_ids``
        #        (a tensor).  This patch makes the inner rotary embedding
        #        tolerate a tensor ``seq_len`` by extracting its maximum
        #        value.
        # WITHOUT: ``TypeError: arange() received an invalid combination
        #          of arguments - got (Tensor, ...)`` at the
        #          ``rotary_emb`` call in ``MtpWrapper.forward()``.
        class_ref_rotary = "modeling_deepseek.DeepseekV3RotaryEmbedding"
        rotary_cls = get_class_from_dynamic_module(class_ref_rotary, model_id, force_download=False)

        if not hasattr(rotary_cls, "_patched_for_kimi_k25"):
            _original_rotary_forward = rotary_cls.forward

            def patched_rotary_forward(self, x, seq_len=None):
                if isinstance(seq_len, torch.Tensor):
                    # MtpWrapper passes position_ids (tensor) as seq_len.
                    # Determine the sequence-length integer for arange/slicing.
                    if seq_len.device.type == "meta":
                        # TorchDynamo tracing on meta: use config value as a
                        # safe upper bound. The cache will be rebuilt with the
                        # real max position at runtime.
                        max_pos = self.max_position_embeddings
                    else:
                        # Runtime (eager, after graph-break resume).
                        # +1 because position_ids are 0-based (e.g. [0..N-1]).
                        max_pos = int(seq_len.max().item()) + 1
                    if self.max_seq_len_cached is None or max_pos > self.max_seq_len_cached:
                        self._set_cos_sin_cache(
                            seq_len=max_pos,
                            device=x.device,
                            dtype=x.dtype,
                        )
                    return (
                        self.cos_cached[:max_pos].to(dtype=x.dtype),
                        self.sin_cached[:max_pos].to(dtype=x.dtype),
                    )
                return _original_rotary_forward(self, x, seq_len)

            rotary_cls.forward = patched_rotary_forward
            rotary_cls._patched_for_kimi_k25 = True
            patched = True

    except Exception as e:
        logger.warning(f"Could not patch remote modules: {e}")

    return patched


def _shard_lm_head_for_kimi_vl(model):
    """Manually apply ``ColumnParallelLinear`` to the nested lm_head modules.

    Kimi K2.5 is a VL model where ``lm_head`` lives inside the
    ``language_model`` submodule (``_inner.language_model.lm_head``),
    not at the top level.  The standard ``shard_model_by_tp`` uses a
    fnmatch pattern ``"lm_head"`` which only matches a top-level
    (unprefixed) name.  After ``strip_module_name``, the nested path
    becomes ``"language_model.lm_head"`` → no match → lm_head stays as
    a raw ``nn.Linear`` and escapes TP sharding.

    This function is called AFTER ``shard_model`` in the custom
    pipeline and replaces every still-unsharded lm_head with a
    ``ColumnParallelLinear`` that gathers output across the TP group.

    Args:
        model: A ``TransformerModel`` whose ``_inner`` is a ``ModelWrapper``
               wrapping the Kimi HF model.  It must expose
               ``parallel_group_manager`` and ``_replace_module``.

    Returns:
        None.  Matching modules are replaced in place through
        ``model._replace_module``; nothing happens when the lm_head TP
        group has a world size of 1 or less.
    """
    pgm = model.parallel_group_manager
    lmhead_tp_group = pgm.lmhead_tp_group
    tp_group = pgm.tp_group

    if lmhead_tp_group.world_size <= 1:
        return  # No TP configured — nothing to do.

    # Imported only once we know sharding is actually required, so the
    # no-TP path does not pay for (or depend on) the parallel-linear module.
    from tensor_cast.layers.parallel_linear import ColumnParallelLinear

    # Two nested lm_head instances escape the standard
    # ``shard_model_by_tp`` fnmatch pattern ``"lm_head"``:
    # 1. VL model:  ``*language_model.lm_head``
    # 2. MTP block: ``*mtp.lm_head``
    lm_head_suffixes = ("language_model.lm_head", "mtp.lm_head")

    # ``named_modules()`` is a lazy generator over the live module tree.
    # Collect every match up front so that replacing modules below never
    # mutates a container that is still being iterated.
    targets = [
        (name, module)
        for name, module in model._inner.named_modules()
        if isinstance(module, torch.nn.Linear) and name.endswith(lm_head_suffixes)
    ]

    for name, module in targets:
        parallel_module = ColumnParallelLinear(
            module,
            tp_group=lmhead_tp_group,
            global_tp_group=tp_group,
            gather_output=True,
        )
        model._replace_module(name, parallel_module)


# Process-wide latch written by ``_hf_config_patch_for_kimi_k25``; while it
# is True that function returns right after its per-config phase and skips
# the class-level / global patch phases.
_patched_kimi_k25 = False
# Process-wide latch written by ``_patch_shard_model_for_kimi_vl`` after it
# has installed its ``shard_model`` wrapper, so the wrapper is never stacked
# on top of itself.
_shard_model_patched = False


def _patch_shard_model_for_kimi_vl():
    """Monkey-patch ``shard_model`` to automatically shard nested lm_head.

    Kimi K2.5's ``lm_head`` lives at ``language_model.lm_head`` (not top-level),
    so the standard ``shard_model_by_tp`` fnmatch pattern ``"lm_head"``
    misses it (``strip_module_name`` yields ``"language_model.lm_head"``).

    This patch wraps ``shard_model`` to call ``_shard_lm_head_for_kimi_vl``
    on the value returned by the standard sharding.  The wrapper forwards
    any extra positional / keyword arguments untouched and keeps the
    original function's metadata, so it stays a drop-in replacement.

    IMPORTANT:  Two references must be patched because ``model.py`` imports
    ``shard_model`` via ``from ... import shard_model``, creating a local
    binding that bypasses a module-attribute monkey-patch.

    The function is guarded by the module-level ``_shard_model_patched``
    flag and does nothing on any call after the first successful one.
    """
    global _shard_model_patched
    if _shard_model_patched:
        return

    import functools

    from tensor_cast.transformers import transformations as _t
    from tensor_cast.transformers import model as _model

    _original_shard_model = _t.shard_model

    @functools.wraps(_original_shard_model)
    def _patched_shard_model(model, *args, **kwargs):
        # Run the stock sharding first, then fix up the lm_head modules
        # that its name pattern could not reach.
        result = _original_shard_model(model, *args, **kwargs)
        _shard_lm_head_for_kimi_vl(result)
        return result

    # Patch both references:
    # 1. transformations.shard_model — for callers that use the module attribute
    # 2. model.shard_model         — for model.py's ``from ... import shard_model``
    _t.shard_model = _patched_shard_model
    _model.shard_model = _patched_shard_model
    _shard_model_patched = True


# ----------------------------------------------------------------
# Patch 16: resize_image — Kimi K2.5 specific image resize logic
# ----------------------------------------------------------------
# Process-wide latch written by ``_patch_resize_image_for_kimi_k25`` after it
# has replaced ``input_generator.resize_image``; prevents the replacement
# from wrapping itself on repeated calls.
_resize_image_patched = False


def _patch_resize_image_for_kimi_k25(model_id):
    """Monkey-patch ``resize_image`` to use Kimi K2.5's resize logic.

    WHY:   The generic ``resize_image`` in ``input_generator.py`` delegates
           to Qwen2-VL's ``smart_resize``, which relies on the image
           processor's ``size`` attribute for min/max pixel limits.
           Kimi K2.5's ``KimiK25VisionProcessor`` (from remote code) does
           NOT expose a standard ``size`` attribute — it uses
           ``media_proc_cfg["in_patch_limit"]`` instead.  Without this
           patch, ``smart_resize`` falls back to its hardcoded defaults
           (``max_pixels=1_003_520``), which are too restrictive for
           Kimi K2.5's larger images (e.g. 1080×1920 = 2 073 600 pixels),
           causing the image to be incorrectly downscaled.

    HOW:   When the per-call model id contains "kimi" (case-insensitive),
           the patched ``resize_image`` bypasses ``smart_resize`` entirely
           and computes resized dimensions directly by rounding the
           original image dimensions UP to multiples of
           ``patch_size * merge_size``.  No ``in_patch_limit`` clamp is
           applied here; the full resolution is always preserved.
           Every other model id (including a missing / non-string one) is
           forwarded unchanged to the original ``resize_image``.

    WITHOUT: Vision token count mismatch — e.g. 4888 tokens instead of
             the expected 10764 for a 1080×1920 image.

    Args:
        model_id: Accepted for symmetry with the other patch helpers; it is
            not read here.  The decision is made per call from the model id
            handed to the patched ``resize_image``.

    Returns:
        None.  ``input_generator.resize_image`` is replaced in place, once
        per process (guarded by ``_resize_image_patched``).
    """
    global _resize_image_patched
    if _resize_image_patched:
        return

    import logging

    logger = logging.getLogger(__name__)

    from tensor_cast.core import input_generator as _ig

    _original_resize_image = _ig.resize_image

    def _kimi_resize_image(
        mid,
        mtype,
        image_height,
        image_width,
        patch_size,
        merge_size,
        temporal_patch_size,
    ):
        # Only intercept Kimi K2.5 (model_id check is case-insensitive).
        # This replacement is installed globally, so it also sees calls made
        # for other models; anything that is not a Kimi id string — including
        # ``None`` — must reach the original implementation untouched rather
        # than fail on ``.lower()``.
        if not isinstance(mid, str) or "kimi" not in mid.lower():
            return _original_resize_image(
                mid,
                mtype,
                image_height,
                image_width,
                patch_size,
                merge_size,
                temporal_patch_size,
            )

        # Kimi K2.5 does NOT use Qwen2-VL's smart_resize.
        # MoonViT processes images at (near) full resolution: dimensions are
        # simply rounded to multiples of ``patch_size * merge_size``.
        #
        # The processor's ``media_proc_cfg["in_patch_limit"]`` defines the
        # maximum number of patches (typically 16384), which translates to a
        # generous pixel budget (16384 * 14 * 14 = 3 211 264 px for
        # patch_size=14).  Common resolutions like 1080×1920 (2 073 600 px)
        # fall well within this limit, so no downscaling is needed.
        factor = patch_size * merge_size
        # Ceiling division: round each side up to the next multiple of factor.
        resized_height = ((image_height + factor - 1) // factor) * factor
        resized_width = ((image_width + factor - 1) // factor) * factor
        logger.info(
            "Kimi K2.5 image resize: %dx%d -> %dx%d (factor=%d, bypassed Qwen2-VL smart_resize)",
            image_height,
            image_width,
            resized_height,
            resized_width,
            factor,
        )
        return resized_height, resized_width

    _ig.resize_image = _kimi_resize_image
    _resize_image_patched = True


def _hf_config_patch_for_kimi_k25(config, model_id=None):
    """Pre-load entry point: apply HF config fixes, then model class patches.

    Called by :func:`AutoModelConfigLoader._apply_hf_config_patches` BEFORE
    the HuggingFace model is instantiated.

    The patching is split into two tiers:

    * **Per-config patches** (always run):  ``_attn_implementation``
      downgrade, vision-config attribute bridging, environment checks
      (e.g. ``is_torch_fx_available``).  These operate on the config
      *object* and MUST execute for every new config instance, even
      when class-level monkey-patches have already been applied.

    * **Class-level patches** (run once):  model-class monkey-patching,
      ``shard_model`` wrapping, ``resize_image`` patching.  These
      modify global state (module attributes / function references)
      and are guarded by ``_patched_kimi_k25`` to avoid redundant work.

    The ``_patched_kimi_k25`` latch is set only when the class-level
    patching reports success.  A call in which just the config patches
    took effect (for instance because ``model_id`` was not supplied yet)
    leaves the latch open so a later call can still patch the classes.

    Args:
        config: HF config object; ignored unless its ``model_type`` is
            ``"kimi_k25"``.
        model_id: Model identifier forwarded to the class-level patch
            helpers.

    Returns:
        None.
    """
    import logging

    logger = logging.getLogger(__name__)

    model_type = getattr(config, "model_type", None)
    if model_type != "kimi_k25":
        return

    # ----------------------------------------------------------------
    # Phase 1 – config-level patches (always run for every new config)
    # ----------------------------------------------------------------
    config_patched = _patch_hf_config_for_kimi_k25(config)

    # ----------------------------------------------------------------
    # Phases 2-4 – class-level / global patches (run once per process)
    # ----------------------------------------------------------------
    # These modify module-level state (monkey-patches, function
    # references, etc.).  They are idempotent but expensive, so we
    # guard them with a global flag.
    global _patched_kimi_k25
    if _patched_kimi_k25:
        return

    # Phase 2 – model class monkey-patches (requires model_id).
    classes_patched = _patch_model_classes_for_kimi_k25(config, model_id)

    # Phase 3 – wrap shard_model to handle nested lm_head.
    _patch_shard_model_for_kimi_vl()

    # Phase 4 – patch resize_image for Kimi K2.5's image resize logic.
    _patch_resize_image_for_kimi_k25(model_id)

    if config_patched or classes_patched:
        logger.info("Patched transformers environment for Kimi-K2.5")

    # Latch on the class-level result only: Phase 1 succeeding must not
    # permanently disable Phase 2 when the class patches could not be
    # applied during this call.
    if classes_patched:
        _patched_kimi_k25 = True


# Register the Kimi K2.5 profile when this module is imported.  The profile
# ties ``model_type == "kimi_k25"`` to the module class names and attribute
# paths listed below and installs ``_hf_config_patch_for_kimi_k25`` as the
# config patch hook.
register_model_profile(
    ModelProfile(
        model_type="kimi_k25",
        moe_module_name="DeepseekV3MoE",
        mla_module_name="DeepseekV3Attention",
        mtp_block_module_name="DeepseekV3DecoderLayer",
        moe_num_experts_key="n_routed_experts",
        language_layers_path_str="language_model.model.layers",
        visual_module_path="vision_tower",
        language_module_path="language_model",
        # Both visual-layer fields intentionally carry the same path.
        visual_layers_module_path="vision_tower.encoder.blocks",
        visual_layers_path_str="vision_tower.encoder.blocks",
        custom_expert_module_type=None,
        # NOTE(review): presumably maps generic MLA field names (keys) to the
        # attribute names used by this model's attention module (values) —
        # confirm against ``ModelProfile``.
        mla_field_names_override={
            "q_proj": "q_a_proj",
            "qk_head_dim": "q_head_dim",
        },
        hf_config_patch_method=_hf_config_patch_for_kimi_k25,
        # When DP≠EP, route is executed after DP slicing to avoid performance bloat caused by routing all tokens
        moe_route_after_dp_transform=True,
    )
)