# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import torch
import torch.distributed
import torch.nn as nn
from megatron.core import mpu
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig

from bridge.models.conversion.utils import get_module_and_param_from_name, remove_non_pickleables


# Type variable for the external-format weights a mapping consumes: constrained
# to either a single tensor or a dict of named tensors.
WeightType = TypeVar("WeightType", torch.Tensor, Dict[str, torch.Tensor])

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_pg_size(group=None):
    """Return the number of ranks in ``group``.

    Args:
        group: Process group to query. ``None`` means "no group": the result
            is 1 and the default process group is *not* consulted.

    Returns:
        int: ``group.size()`` when torch.distributed is initialized and a
        group is given; otherwise 1.
    """
    distributed_ready = torch.distributed.is_initialized()
    if distributed_ready and group is not None:
        return group.size()
    return 1


def get_pg_rank(group=None):
    """Return this process's rank inside ``group``.

    Args:
        group: Process group to query. ``None`` means "no group": the result
            is 0 and the default process group is *not* consulted.

    Returns:
        int: ``group.rank()`` when torch.distributed is initialized and a
        group is given; otherwise 0.
    """
    distributed_ready = torch.distributed.is_initialized()
    if distributed_ready and group is not None:
        return group.rank()
    return 0


class MegatronParamMapping(ABC, Generic[WeightType]):
    """
    Abstract base class for weight conversion between Megatron and external formats.

    This class provides the foundation for all weight mappings, handling the complex
    conversions between Megatron-Core's distributed tensor formats and standard
    (typically HuggingFace) formats. Each concrete mapping implements specific
    transformation logic while inheriting common parallel communication patterns.

    Key responsibilities:
    - Format transformation (e.g., QKV merging/splitting, gated MLP handling)
    - Tensor parallel (TP) distribution and gathering across GPUs
    - Pipeline parallel (PP) broadcasting between pipeline stages
    - Wildcard pattern resolution for layer-wise mappings

    The mapping abstraction ensures that higher-level code doesn't need to know
    about the parallel topology or format differences - it just requests a
    conversion and the mapping handles all the complexity.

    Public helper methods for subclasses:
    - broadcast_from_pp_rank: Broadcast tensors across pipeline stages
    - broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks
    - broadcast_tensor_to_tp_ranks: Broadcast within TP group
    - scatter_to_tp_ranks: Distribute tensor shards to TP ranks
    - gather_from_tp_ranks: Collect tensor shards from TP ranks

    Example:
        .. code-block:: python

            class MyCustomMapping(MegatronParamMapping[torch.Tensor]):
                def hf_to_megatron(self, hf_weights, megatron_module):
                    # Custom transformation logic
                    transformed = hf_weights.t()  # Example: transpose
                    # Use helpers for distribution
                    return self.scatter_to_tp_ranks(...)

                def megatron_to_hf(self, megatron_weights, megatron_module):
                    # Broadcast from owning PP rank
                    weight = self.broadcast_from_pp_rank(megatron_weights)
                    # Gather from TP ranks and transform
                    gathered = self.gather_from_tp_ranks(weight)
                    return {"custom_weight": gathered[0].t()}
    """

    def __init__(self, megatron_param: str, hf_param: Union[str, Dict[str, str]]):
        """Store the name patterns, validate them and capture the parallel groups.

        Args:
            megatron_param (str): Megatron parameter name pattern (may contain
                ``*`` / ``**`` wildcards).
            hf_param (Union[str, Dict[str, str]]): External-format name
                pattern, or a dict of such patterns.

        Raises:
            ValueError: If the wildcard counts of the patterns disagree
                (raised by ``_validate_patterns``).
        """
        self.megatron_param = megatron_param
        self.hf_param = hf_param
        # Fail fast on inconsistent patterns before any other state is set up.
        self._validate_patterns()

        # Per-instance caches used by the PP broadcast helpers.
        self._broadcast_obj_cache = {}
        self._tensor_spec_output_cache = {}

        # Snapshot the process groups once; all four stay None when Megatron's
        # parallel state has not been initialized.
        if mpu.is_initialized():
            groups = (
                mpu.get_pipeline_model_parallel_group(),
                mpu.get_expert_model_parallel_group(),
                mpu.get_tensor_model_parallel_group(),
                mpu.get_expert_tensor_parallel_group(),
            )
        else:
            groups = (None, None, None, None)
        self.pp_group, self.ep_group, self._tp_group, self._etp_group = groups

        # A mapping that receives a modified HF weight name from
        # maybe_modify_loaded_hf_weight should set this to True to bypass a
        # check in `build_conversion_tasks`.
        self.allow_hf_name_mismatch = False

    @property
    def tp_group(self):
        """Process group used for tensor-parallel collectives of this mapping.

        Expert parameters (``is_expert``) use the expert tensor-parallel
        group; every other parameter uses the regular tensor-parallel group.
        """
        return self._etp_group if self.is_expert else self._tp_group

    @property
    def tp_rank(self) -> int:
        """Rank of this process inside the group returned by ``tp_group``."""
        group = self.tp_group
        return get_pg_rank(group)

    @property
    def tp_size(self) -> int:
        """Number of ranks in the group returned by ``tp_group``."""
        group = self.tp_group
        return get_pg_size(group)

    @property
    def pp_rank(self) -> int:
        """Rank of this process inside the pipeline-parallel group."""
        group = self.pp_group
        return get_pg_rank(group)

    @property
    def pp_size(self) -> int:
        """Number of ranks in the pipeline-parallel group."""
        group = self.pp_group
        return get_pg_size(group)

    @property
    def ep_rank(self) -> int:
        """Rank of this process inside the expert-parallel group."""
        group = self.ep_group
        return get_pg_rank(group)

    @property
    def ep_size(self) -> int:
        """Number of ranks in the expert-parallel group."""
        group = self.ep_group
        return get_pg_size(group)

    @property
    def etp_rank(self) -> int:
        """Rank of this process inside the expert tensor-parallel group.

        Reads ``self._etp_group``, the attribute ``__init__`` actually sets;
        the class defines no public ``etp_group`` attribute.
        """
        return get_pg_rank(self._etp_group)

    @property
    def etp_size(self) -> int:
        """Number of ranks in the expert tensor-parallel group.

        Reads ``self._etp_group``, the attribute ``__init__`` actually sets;
        the class defines no public ``etp_group`` attribute.
        """
        return get_pg_size(self._etp_group)

    @property
    def is_expert(self) -> bool:
        """Whether ``megatron_param`` names an expert (MoE) parameter.

        True when the name contains either the grouped-expert marker
        (``.mlp.experts.linear_fc``, e.g. TEGroupedMLP) or the per-expert
        marker (``.mlp.experts.local_experts.``, e.g. SequentialMLP).
        """
        name = self.megatron_param
        expert_markers = (".mlp.experts.linear_fc", ".mlp.experts.local_experts.")
        return any(marker in name for marker in expert_markers)

    def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]:
        """Resolve wildcard patterns with captured values.

        Handles both ** (any characters) and * (digits) wildcards in order.
        ** patterns are processed before * patterns to avoid conflicts.
        """
        resolved_megatron_param = self.megatron_param
        capture_index = 0

        # First pass: resolve ** wildcards
        while "**" in resolved_megatron_param and capture_index < len(captures):
            resolved_megatron_param = resolved_megatron_param.replace("**", captures[capture_index], 1)
            capture_index += 1

        # Second pass: resolve * wildcards
        while "*" in resolved_megatron_param and capture_index < len(captures):
            resolved_megatron_param = resolved_megatron_param.replace("*", captures[capture_index], 1)
            capture_index += 1

        if isinstance(self.hf_param, str):
            resolved_hf_param = self.hf_param
            capture_index = 0

            # First pass: resolve ** wildcards
            while "**" in resolved_hf_param and capture_index < len(captures):
                resolved_hf_param = resolved_hf_param.replace("**", captures[capture_index], 1)
                capture_index += 1

            # Second pass: resolve * wildcards
            while "*" in resolved_hf_param and capture_index < len(captures):
                resolved_hf_param = resolved_hf_param.replace("*", captures[capture_index], 1)
                capture_index += 1
        else:
            resolved_hf_param = {}
            for k, v in self.hf_param.items():
                resolved_v = v
                capture_index = 0

                # First pass: resolve ** wildcards
                while "**" in resolved_v and capture_index < len(captures):
                    resolved_v = resolved_v.replace("**", captures[capture_index], 1)
                    capture_index += 1

                # Second pass: resolve * wildcards
                while "*" in resolved_v and capture_index < len(captures):
                    resolved_v = resolved_v.replace("*", captures[capture_index], 1)
                    capture_index += 1

                resolved_hf_param[k] = resolved_v

        return resolved_megatron_param, resolved_hf_param

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Build a concrete mapping by filling in the wildcard captures.

        The new object is created by calling this instance's own class with
        the two resolved names, so this default only suits mappings whose
        constructor accepts ``(megatron_param, hf_param)``.

        Args:
            captures (Tuple[str, ...]): Captured wildcard values.

        Returns:
            MegatronParamMapping: A new instance of ``type(self)`` holding the
            resolved names.
        """
        mapping_cls = type(self)
        resolved_names = self._resolve_names(captures)
        return mapping_cls(*resolved_names)

    @abstractmethod
    def hf_to_megatron(
        self,
        hf_weights: WeightType,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Convert hf_weights TO Megatron format.

        Abstract: every concrete mapping must override this method. An
        implementation is expected to handle:
        1. Format transformation (if needed)
        2. Tensor parallel distribution (if self.tp_size > 1)

        Args:
            hf_weights (WeightType): Source hf_weights in external format
                (a tensor, or a dict of named tensors, depending on the
                mapping's ``WeightType``).
            megatron_module (nn.Module): Target Megatron module (for config
                access).

        Returns:
            torch.Tensor: Weight tensor ready for the current TP rank.
        """
        # No default behaviour: the body is intentionally empty.
        ...

    def broadcast_from_pp_rank(
        self, tensor: Optional[torch.Tensor], cache_key: Optional[str] = None
    ) -> Optional[torch.Tensor]:
        """Broadcast a tensor from the pipeline-parallel rank that owns it.

        Broadcasts to **all** PP ranks and keeps the tensor-parallel metadata
        (`tensor_model_parallel`, `partition_dim`) consistent on the ranks
        that have to allocate a receive buffer.

        Args:
            tensor (Optional[torch.Tensor]): The local tensor if the current PP
                rank owns it. ``None`` otherwise.
            cache_key (Optional[str]): When given, the gathered per-rank tensor
                specs are stored under this key and reused by later calls with
                the same key, skipping the spec all-gather. A cache hit issues
                no spec collective, so all PP ranks must use the key
                consistently. With ``None`` nothing is cached.

        Returns:
            Optional[torch.Tensor]: The broadcast tensor on every PP rank. When
            ``pp_size == 1`` the input is returned unchanged (and may therefore
            be ``None``).

        Raises:
            ValueError: If no PP rank owns the tensor, or if more than one does.
        """

        # Fast-path when we are not using pipeline parallelism.
        if self.pp_size == 1:
            return tensor

        # ------------------------------------------------------------------
        # 1.  Gather (shape, dtype, tensor_parallel flag, partition_dim) from
        #     every PP rank so that we can find the source rank.
        # ------------------------------------------------------------------
        if cache_key is not None and cache_key in self._tensor_spec_output_cache:
            tensor_spec_output = self._tensor_spec_output_cache[cache_key]
        else:
            if tensor is not None:
                tensor_spec = (
                    tensor.shape,
                    tensor.dtype,
                    getattr(tensor, "tensor_model_parallel", None),
                    getattr(tensor, "partition_dim", None),
                )
            else:
                tensor_spec = None

            tensor_spec_output: list[Optional[tuple]] = [None] * self.pp_size
            torch.distributed.all_gather_object(tensor_spec_output, tensor_spec, group=self.pp_group)
            # Only remember the specs when the caller asked for caching; an
            # entry stored under ``None`` could never be read back.
            if cache_key is not None:
                self._tensor_spec_output_cache[cache_key] = tensor_spec_output

        # ------------------------------------------------------------------
        # 2.  Identify the owning rank (the only rank with a non-None spec).
        # ------------------------------------------------------------------
        target_tensor_spec = None
        src_rank = None  # Rank *inside* the PP group.
        for rank, spec in enumerate(tensor_spec_output):
            if spec is not None:
                if target_tensor_spec is not None:
                    raise ValueError(f"Tensor exists on more than one PP rank. Found on ranks {src_rank} and {rank}.")
                target_tensor_spec = spec
                src_rank = rank

        if target_tensor_spec is None:
            # No rank had the tensor – this is an error in the caller.
            raise ValueError("Object must exist on at least one PP rank")

        # ------------------------------------------------------------------
        # 3.  Ensure every rank has an allocated tensor with the right shape
        #     and dtype before the broadcast.
        # ------------------------------------------------------------------
        if tensor is None:
            shape, dtype, tensor_parallel, partition_dim = target_tensor_spec
            # Receive buffer lives on CUDA when available, otherwise on CPU.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            tensor = torch.empty(shape, dtype=dtype, device=device)
            if tensor_parallel is not None:
                tensor.tensor_model_parallel = tensor_parallel
            if partition_dim is not None:
                tensor.partition_dim = partition_dim

        # ------------------------------------------------------------------
        # 4.  Broadcast from the source PP rank to all other PP ranks.
        #     The collective expects a global rank, so translate first.
        # ------------------------------------------------------------------
        global_src = torch.distributed.get_global_rank(group=self.pp_group, group_rank=src_rank)
        torch.distributed.broadcast(tensor, src=global_src, group=self.pp_group)

        return tensor

    def broadcast_obj_from_pp_rank(self, obj: Optional[Any], cache_key: Optional[str] = None) -> Any:
        """Broadcast any Python object from a PP rank that holds it.

        Useful for sharing configuration values or other metadata across
        pipeline-parallel ranks. Several ranks may pass a non-``None`` object;
        the highest-numbered such rank (inside the PP group) is the source.

        Args:
            obj (Optional[Any]): Object to broadcast (``None`` on ranks that do
                not hold it).
            cache_key (Optional[str]): Optional cache key. When given, the
                result is stored and later calls with the same key return it
                without any communication, so all PP ranks must use the key
                consistently. When ``None``, no caching is performed.

        Returns:
            Any: The broadcast object on all ranks. With ``pp_size == 1`` the
            input is returned unchanged.

        Raises:
            ValueError: If no PP rank holds the object.
        """
        if self.pp_size == 1:
            return obj

        # Check if we already have a cached result (only if cache_key is provided)
        if cache_key is not None and cache_key in self._broadcast_obj_cache:
            return self._broadcast_obj_cache[cache_key]

        # ------------------------------------------------------------------
        # 1. Gather presence flags from all PP ranks to find the source rank
        # ------------------------------------------------------------------
        obj_flags = [None] * self.pp_size
        torch.distributed.all_gather_object(obj_flags, obj is not None, group=self.pp_group)

        # ------------------------------------------------------------------
        # 2. Pick the source: the last rank (inside the PP group) whose flag
        #    is set. Multiple holders are tolerated on purpose.
        # ------------------------------------------------------------------
        src_rank = None
        for rank, flag in enumerate(obj_flags):
            if flag:
                src_rank = rank

        if src_rank is None:
            raise ValueError("Object must exist on at least one PP rank")

        # ------------------------------------------------------------------
        # 3. Broadcast the object from the source rank to all ranks.
        #    The collective expects a global rank, so translate first.
        # ------------------------------------------------------------------
        obj_list = [obj]
        pp_ranks = torch.distributed.get_process_group_ranks(self.pp_group)
        global_src = pp_ranks[src_rank]
        torch.distributed.broadcast_object_list(obj_list, src=global_src, group=self.pp_group)

        result = obj_list[0]

        # Cache the result for future calls (only if cache_key is provided)
        if cache_key is not None:
            self._broadcast_obj_cache[cache_key] = result

        return result

    def clear_broadcast_cache(self):
        """Empty the cache filled by ``broadcast_obj_from_pp_rank``.

        Handy in tests, or whenever the broadcast objects may change during
        the lifetime of the mapping. The dict is cleared in place.
        """
        cached_objects = self._broadcast_obj_cache
        cached_objects.clear()

    def clear_tensor_spec_output_cache(self):
        """Empty the tensor-spec cache filled by ``broadcast_from_pp_rank``.

        Handy in tests, or whenever the gathered tensor specs may change
        during the lifetime of the mapping. The dict is cleared in place.
        """
        cached_specs = self._tensor_spec_output_cache
        cached_specs.clear()

    def broadcast_tensor_to_tp_ranks(self, tensor: torch.Tensor, src_rank: int = 0) -> torch.Tensor:
        """Broadcast ``tensor`` from one TP rank to the whole TP group.

        Args:
            tensor (torch.Tensor): Tensor to send on the source rank; on the
                other ranks it is passed to the broadcast as the destination.
            src_rank (int, optional): Source rank *within* the TP group.
                Defaults to 0.

        Returns:
            torch.Tensor: The same ``tensor`` object. With ``tp_size == 1`` no
            communication takes place.
        """
        if self.tp_size != 1:
            # The collective wants a global rank, so translate the group-local one.
            source = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)
            torch.distributed.broadcast(tensor, src=source, group=self.tp_group)
        return tensor

    def scatter_to_tp_ranks(
        self,
        splits: Optional[List[torch.Tensor]],
        output_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        src_rank: int = 0,
    ) -> torch.Tensor:
        """Hand one shard to each TP rank.

        Args:
            splits (Optional[List[torch.Tensor]]): Shards to distribute, one
                per TP rank. Only consulted on rank ``src_rank``.
            output_shape (torch.Size): Shape of the shard received on each rank.
            dtype (torch.dtype): Dtype of the received shard; shards are cast
                to it before sending.
            device (torch.device): Device of the received shard; shards are
                moved to it before sending.
            src_rank (int, optional): Source rank within the TP group.
                Defaults to 0.

        Returns:
            torch.Tensor: The shard for the current rank. With ``tp_size == 1``
            this is ``splits[0]`` converted to ``device``/``dtype``, or ``None``
            when ``splits`` is empty or ``None``.
        """
        if self.tp_size == 1:
            if not splits:
                return None
            return splits[0].to(device=device, dtype=dtype)

        received = torch.empty(output_shape, dtype=dtype, device=device)
        # The collective wants a global rank, so translate the group-local one.
        root = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)

        # Only the source rank supplies a send list; every other rank passes None.
        if self.tp_rank == src_rank and splits:
            outgoing = [shard.to(device=device, dtype=dtype) for shard in splits]
        else:
            outgoing = None

        torch.distributed.scatter(received, outgoing, src=root, group=self.tp_group)
        return received

    def gather_from_tp_ranks(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """Collect the shard of every TP rank.

        Args:
            tensor (torch.Tensor): The shard held by the current rank.

        Returns:
            List[torch.Tensor]: One tensor per TP rank. With ``tp_size == 1``
            the list holds the input tensor itself (no copy).
        """
        world = self.tp_size
        if world == 1:
            return [tensor]

        # Receive buffers mirror the local shard's shape, dtype and device.
        shards = [torch.empty_like(tensor) for _ in range(world)]
        torch.distributed.all_gather(shards, tensor, group=self.tp_group)
        return shards

    def _count_wildcard_groups(self, pattern: str) -> int:
        """Count the number of wildcard capture groups in a pattern.

        Args:
            pattern: Pattern string with * and ** wildcards

        Returns:
            Number of capture groups that will be generated

        Note:
            ** counts as 1 group, * counts as 1 group
            ** must be counted before * to avoid double-counting
        """
        count = 0
        remaining = pattern

        # Count ** patterns first
        while "**" in remaining:
            count += 1
            remaining = remaining.replace("**", "", 1)

        # Count remaining * patterns
        count += remaining.count("*")

        return count

    def _validate_patterns(self):
        """Validate wildcard consistency between patterns."""
        megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param)
        if isinstance(self.hf_param, str):
            hf_param_wildcards = self._count_wildcard_groups(self.hf_param)
            if megatron_param_wildcards != hf_param_wildcards:
                raise ValueError(
                    f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has "
                    f"{megatron_param_wildcards} wildcards, hf_param='{self.hf_param}' has {hf_param_wildcards}"
                )
        else:
            for key, pattern in self.hf_param.items():
                hf_param_wildcards = self._count_wildcard_groups(pattern)
                if megatron_param_wildcards != hf_param_wildcards:
                    raise ValueError(
                        f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has "
                        f"{megatron_param_wildcards} wildcards, hf_param['{key}']='{pattern}' has {hf_param_wildcards}"
                    )

    def _normalize_expert_param_name(self, param_name: str) -> str:
        """Normalize expert parameter name by replacing trailing numbers with 0.
        e.g. experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0

        Args:
            param_name (str): Parameter name that may end with a number.

        Returns:
            str: Parameter name with trailing number replaced by 0.
        """
        # Use regex to replace any trailing number with 0
        return re.sub(r"\d+$", "0", param_name)

    def _get_config(self, module: nn.Module) -> Any:
        """Return the first ``config`` attribute found starting from ``module``.

        The search starts at ``module`` itself. While the current object has no
        ``config`` attribute, it follows the object's ``_parent`` attribute
        when present; otherwise it scans ``module.modules()`` for a module that
        lists the current object among its ``named_children()``.

        Args:
            module (nn.Module): Module at which to start the search.

        Returns:
            Any: The ``config`` attribute of the first object in the chain
            that has one.

        Raises:
            ValueError: If the search ends without finding a ``config``
                attribute.
        """
        current = module
        while current is not None:
            if hasattr(current, "config"):
                return current.config
            # Try parent module
            if hasattr(current, "_parent"):
                current = current._parent
            else:
                # Walk up the module tree
                # NOTE(review): the scan iterates `module.modules()` on the
                # *starting* module; presumably that only yields the module
                # and its submodules, in which case a true parent can never
                # be found here and this branch always ends in
                # `current = None` -- confirm.
                for parent_module in module.modules():
                    for _, child_module in parent_module.named_children():
                        if child_module is current:
                            current = parent_module
                            break
                    else:
                        # Inner loop finished without a match: try the next candidate.
                        continue
                    # Inner loop found the parent: stop scanning candidates.
                    break
                else:
                    # No candidate owns `current`: end the while loop.
                    current = None

        raise ValueError(
            f"Could not find config in module hierarchy for {module.__class__.__name__}. "
            f"Ensure the module or its parent has a 'config' attribute."
        )

    def gather_from_ep_ranks(
        self,
        megatron_weights: Optional[torch.Tensor],
        megatron_module: Optional[MegatronModule],
        hf_param_name: Optional[str],
    ) -> Dict[str, torch.Tensor]:
        """Handle expert parallel weight gathering for MoE models.

        This method gathers expert weights across expert-parallel (EP) ranks and
        returns a mapping from HF parameter names to the corresponding tensors
        from each EP rank. Call this only for confirmed expert parameters
        (self.is_expert is True), typically after TP gathering/concatenation in
        the export path (Megatron → HF).

        Behavior and notation:
        - Let E be the total number of experts (e.g., config.num_moe_experts) and
          S be the expert-parallel size (ep_size). We assume E % S == 0.
        - Each EP rank owns E/S experts. For a given parameter name, we infer a
          local expert index L (0 ≤ L < E/S) on the current EP rank from the
          global expert id embedded in the name (works for both .weight and .bias).
        - The set of global expert ids that correspond to this local index L
          across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}.

        Communication and outputs:
        - We perform an all_gather over the EP group to collect the tensor from
          every EP rank into a list ordered by EP rank id.
        - For each EP rank k, we construct the HF parameter name by replacing the
          expert id in `hf_param_name` with (L + k * (E/S)), preserving the rest
          of the path, and map that name to the gathered tensor from rank k.

        Example:
        - E = 8, S = 2 → E/S = 4. Experts are distributed as:
          Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7].
          If the local index L = 0 (derived from the param name), this returns:
          {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1}

        Args:
            megatron_weights (Optional[torch.Tensor]): The local expert weight tensor
                (after any TP handling) on this EP rank.
            megatron_module (Optional[MegatronModule]): The Megatron module containing
                configuration (used to determine E and E/S). Can be None on non-owning PP
                ranks; values will be broadcast across PP.
            hf_param_name (Optional[str]): HF parameter name template for the current
                (local) expert on this rank. The expert id within this string is replaced
                with the appropriate global expert ids for each EP rank.

        Returns:
            Dict[str, torch.Tensor]: Mapping from HF parameter names (one per EP rank)
            to the corresponding expert tensors gathered from each EP rank.
        """
        if megatron_module is None:
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank")
        else:
            model_config = self._get_config(megatron_module)
            num_experts = model_config.num_moe_experts
            num_experts_per_rank = num_experts // self.ep_size
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")

        # Extract local expert number from parameter name
        # Handle both .weight and .bias suffixes
        local_expert_number = None
        for key in (".weight", ".bias"):
            if key in self.megatron_param:
                global_expert_number = int(self.megatron_param.split(key)[-1])
                local_expert_number = global_expert_number % num_experts_per_rank

        # Compute global expert numbers for all EP ranks
        # use regex to replace the local expert number with the global expert number
        gathered_expert_param_names = [
            re.sub(
                r"experts\.(\d+)", f"experts.{int(local_expert_number) + num_experts_per_rank * i}", str(hf_param_name)
            )
            for i in range(self.ep_size)
        ]
        if str(hf_param_name) not in gathered_expert_param_names:
            raise ValueError(
                f"Parameter name '{hf_param_name}' not found in expert parameter names. "
                f"Available names: {gathered_expert_param_names}"
            )

        # Gather weights from all EP ranks
        gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)]
        torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group)

        # this should be in the right order because of the all-gather
        weights_dict = {}
        for i, param_name in enumerate(gathered_expert_param_names):
            if param_name in weights_dict:
                weights_dict[param_name] = torch.cat(
                    [weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0
                )
            else:
                weights_dict[param_name] = gathered_weights[i].unsqueeze(0)
        for param_name in weights_dict:
            weights_dict[param_name] = weights_dict[param_name].squeeze()
        return weights_dict

    def maybe_dequantize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Dequantize FP8 tensor if needed.

        This base implementation performs no conversion: it returns the very
        tensor it was given.
        """
        unchanged = tensor
        return unchanged


class DirectMapping(MegatronParamMapping[torch.Tensor]):
    """Identity mapping: the HF tensor is used as the Megatron tensor as-is.

    No format transformation and no tensor-parallel splitting is applied.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Return ``hf_weights`` unchanged; ``megatron_module`` is not used."""
        passthrough = hf_weights
        return passthrough


class ColumnParallelMapping(MegatronParamMapping[torch.Tensor]):
    """
    Mapping for weights that are sharded along dimension 0 across TP ranks.

    Column-parallel layers split the *output* dimension of a weight between
    tensor-parallel ranks, so each rank keeps a slice of the rows and all of
    the columns. 1D tensors such as biases are split along their only
    dimension in the same way.

    **Sharding pattern**
    -   Full weight: `[output_features, input_features]`
    -   Every TP rank: `[output_features/tp_size, input_features]`

    **HuggingFace → Megatron (`hf_to_megatron`)**
    1.  With `tp_size == 1` the HF tensor is returned unchanged.
    2.  Otherwise TP rank 0 casts the tensor to the target parameter's dtype
        when they differ (logging a warning), checks that dim 0 is divisible
        by `tp_size`, and chunks it along dim 0 into `tp_size` pieces.
    3.  The pieces are scattered so each TP rank receives one, allocated with
        the shape, dtype and device of its local target parameter.

    Example:
        .. code-block:: python

            # For a weight of shape [4096, 1024] with tp_size=4:
            mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight")
            megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module)
            # megatron_weights.shape = [1024, 1024] on each rank
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Chunk ``hf_weights`` along dim 0 and scatter the chunks to the TP ranks.

        Args:
            hf_weights (torch.Tensor): Full HF tensor. Must not be ``None`` on
                TP rank 0; it is not read on the other ranks.
            megatron_module (nn.Module): Module in which the target parameter
                is looked up for its shape, dtype and device.

        Returns:
            torch.Tensor: ``hf_weights`` itself when ``tp_size == 1``,
            otherwise the shard received by the current TP rank.

        Raises:
            ValueError: On TP rank 0, if ``hf_weights`` is ``None`` or its
                dim 0 is not divisible by ``tp_size``.
        """
        if self.tp_size == 1:
            return hf_weights

        # Names may carry a global expert number (e.g. experts.weight15).
        # Only shape/dtype/device of the target are used, never its values,
        # so looking up the "...0" parameter instead is safe.
        lookup_name = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, lookup_name)

        # Only rank 0 prepares the send list; everyone else just receives.
        splits = None
        if self.tp_rank == 0:
            if hf_weights is None:
                raise ValueError("hf_weights should not be None on rank 0")

            # Example: MCore MambaMixer initializes A_log in FP32 while saved
            # checkpoints (including those on HF) hold it in BF16. Scattering
            # mismatched dtypes would fail, hence the cast before the scatter.
            if hf_weights.dtype != target_param.dtype:
                logger.warning(
                    f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. "
                    f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. "
                    f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. "
                )
                hf_weights = hf_weights.to(target_param.dtype)

            # Both 2D weights and 1D biases are split along dim 0.
            full_size = hf_weights.shape[0]
            if full_size % self.tp_size != 0:
                raise ValueError(f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks")
            splits = torch.chunk(hf_weights, self.tp_size, dim=0)

        # Each rank allocates its receive buffer from its own local parameter.
        return self.scatter_to_tp_ranks(
            splits,
            target_param.shape,
            target_param.dtype,
            target_param.device,
        )


class RowParallelMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping for **row-parallel** linear weights.

    Megatron shards row-parallel tensors along **dimension 1** (the *input*
    dimension of a linear layer).

    **Forward path (external → Megatron)**
    1.  Rank 0 validates that the *second* dimension is divisible by `tp_size`.
    2.  Rank 0 splits the tensor with `torch.chunk(..., dim=1)` producing
        `tp_size` equally-sized shards. 1D tensors (biases) are not split;
        the same tensor is handed to every TP rank instead.
    3.  The shards are **scattered** so that every TP rank receives exactly one
        shard matching the shape of its local Megatron parameter.

    **Reverse path (Megatron → external)**
    1.  The local Megatron parameter (which may live on any PP rank) is
        broadcast to all PP ranks so that the gather step can be collective.
    2.  All TP ranks **gather** their shard.
    3.  Rank 0 concatenates the gathered list along dim 1 to reconstruct the
        original unsharded weight and emits it under the external (HF) name.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Split weight along dim 1 and distribute to TP ranks.

        Args:
            hf_weights (torch.Tensor): Full (unsharded) tensor. Only read on
                TP rank 0; must not be None there.
            megatron_module (nn.Module): Module holding the target parameter,
                used to look up the local shard's shape, dtype and device.

        Returns:
            torch.Tensor: `hf_weights` unchanged when `tp_size == 1`, otherwise
            the result of `scatter_to_tp_ranks` for this rank.

        Raises:
            ValueError: On TP rank 0 if `hf_weights` is None, is neither 1D nor
                2D, or its dimension 1 is not divisible by `tp_size`.
        """
        if self.tp_size == 1:
            return hf_weights

        # Some parameters are named with a global expert number, e.g. experts.weight15;
        # normalize the name to experts.weight0. Only the shape, dtype and device of the
        # looked-up parameter are used, not its values, so this is safe.
        normalized_param = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

        # On rank 0, check for divisibility and split
        if self.tp_rank == 0:
            if hf_weights is None:
                raise ValueError("hf_weights should not be None on rank 0")

            # bias (1D) is replicated across tp ranks
            # For weight (2D), we split along dim 1
            if hf_weights.ndim == 1:
                splits = [hf_weights] * self.tp_size
            else:
                if hf_weights.ndim != 2:
                    raise ValueError(
                        f"Expected 2-dimensional weight tensor, but got {hf_weights.ndim}-dimensional "
                        f"tensor with shape {hf_weights.shape}"
                    )

                full_size = hf_weights.shape[1]
                if full_size % self.tp_size != 0:
                    # The split axis is dim 1, so report dim 1 in the message.
                    raise ValueError(
                        f"Cannot evenly split dimension 1 size {full_size} across {self.tp_size} TP ranks"
                    )
                splits = torch.chunk(hf_weights, self.tp_size, dim=1)

        else:
            splits = None
        # Scatter to all ranks. Each rank gets its sharded shape from its module.
        return self.scatter_to_tp_ranks(
            splits,
            target_param.shape,
            target_param.dtype,
            target_param.device,
        )


class WeightReshapeMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping that flattens a weight to 2D (used for Qwen3-VL weight conversion).

    All leading dimensions of the incoming tensor are collapsed into one while
    the last dimension is kept as-is. No tensor-parallel splitting or
    communication is performed by this mapping.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Move the weight to the module's device and view it as `(-1, last_dim)`.

        Args:
            hf_weights (torch.Tensor): Source tensor to reshape.
            megatron_module (nn.Module): Module whose parameter device is the
                destination device.

        Returns:
            torch.Tensor: 2D view of the moved tensor.
        """
        # The parameter may not be called "weight"; in that case take the device
        # of the first parameter the module exposes.
        if not hasattr(megatron_module, "weight"):
            destination = next(megatron_module.parameters()).device
        else:
            destination = megatron_module.weight.device

        moved = hf_weights.to(device=destination)
        trailing_dim = moved.shape[-1]
        return moved.view(-1, trailing_dim)


class ReplicatedMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping for weights that are **fully replicated** across TP ranks.

    Examples: layer-norm scales, biases, router weights in MoE, etc.

    Every TP rank holds the same tensor, so loading (HF → Megatron) amounts to
    a broadcast from TP rank 0 to the other TP ranks.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Replicate weight to all TP ranks.

        Args:
            hf_weights (torch.Tensor): Tensor to replicate. It is dereferenced
                on every rank (its shape/dtype seed the receive buffer).
            megatron_module (nn.Module): Module whose parameter device decides
                where the tensor is placed.

        Returns:
            torch.Tensor: The tensor on the target device when `tp_size == 1`,
            otherwise the result of the TP broadcast from rank 0.
        """
        # The parameter may not be called "weight"; fall back to the first parameter.
        if hasattr(megatron_module, "weight"):
            device = megatron_module.weight.device
        else:
            device = next(megatron_module.parameters()).device

        local = hf_weights.to(device=device)
        if self.tp_size == 1:
            return local

        # TODO(yuya): router.weight is on device cpu, need to check.
        cuda_index = torch.cuda.current_device()
        if device.index != cuda_index:
            local = local.to(cuda_index)

        # Non-source ranks only need a buffer of the right shape/dtype/device;
        # its contents are overwritten by the broadcast below.
        if self.tp_rank > 0:
            local = torch.empty_like(local)

        # Broadcast from rank 0 to all TP ranks
        return self.broadcast_tensor_to_tp_ranks(local, src_rank=0)


class AutoMapping(MegatronParamMapping[torch.Tensor]):
    """
    Mapping that picks the tensor-parallel strategy from the Megatron module itself.

    Callers do not have to state whether a layer is column-parallel,
    row-parallel or replicated: on first use the Megatron module is inspected
    and the work is delegated to `ColumnParallelMapping`, `RowParallelMapping`
    or `ReplicatedMapping`. The chosen delegate is cached on the instance.

    **Detection strategy**
    1. `TELayerNormColumnParallelLinear` is special-cased on the parameter name
    2. Look the module class name up in the registry of known types
    3. Otherwise inspect module attributes (`tensor_model_parallel`, `partition_dim`)
    4. Otherwise treat any class whose name contains "Norm" as replicated
    5. Otherwise, for `TELinear`, use its `parallel_mode`

    **Built-in module recognition**
    -   Column-parallel: `ColumnParallelLinear`, `VocabParallelEmbedding`, etc.
    -   Row-parallel: `RowParallelLinear`, `TERowParallelLinear`, etc.
    -   Replicated: `LayerNorm`, `RMSNorm`, other normalization layers and a
        few non-parallel modules

    **Dimension permutation**
    An optional `permute_dims` tuple reorders the tensor dimensions (e.g. a
    transpose). In `hf_to_megatron` the permutation is applied on TP rank 0
    before the tensor is handed to the delegate.

    Example:
        .. code-block:: python

            # Automatically handles any weight type
            mapping = AutoMapping(
                megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
                hf_param="model.layers.*.mlp.gate_proj.weight"
            )

            # Works with column-parallel layers
            megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module)

            # Also works with normalization layers
            norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module)

            # With dimension permutation (e.g., transpose)
            transpose_mapping = AutoMapping(
                megatron_param="vision_projection.weight",
                hf_param="multi_modal_projector.weight",
                permute_dims=(1, 0)  # Transpose dimensions
            )

            # Register custom module types
            AutoMapping.register_module_type("MyCustomLinear", "column")

    Note:
        If the parallelism type cannot be determined, a `ValueError` is raised
        that lists the known module types and how to register a new one.
    """

    # Class-name registry, keyed by parallelism type. Shared by all instances
    # and extended through `register_module_type`.
    _MODULE_TYPE_REGISTRY: Dict[str, set] = {
        "column": {
            "ColumnParallelLinear",
            "TEColumnParallelLinear",
            "TELayerNormColumnParallelLinear",
            "TEColumnParallelGroupedLinear",
            "VocabParallelEmbedding",
            "DotProductAttention",  # for attention sink only
            "TEDotProductAttention",  # for attention sink only
        },
        "row": {
            "RowParallelLinear",
            "TERowParallelLinear",
            "TERowParallelGroupedLinear",
        },
        "replicated": {
            # Wrapper / head modules
            "FSDPCheckpointWrapper",
            "FSDPHead",
            # Normalization layers
            "TENorm",
            "FusedLayerNorm",
            "WrappedTorchNorm",
            "LayerNorm",
            "RMSNorm",
            "L2Norm",
            # Other non-parallel modules
            "IdentityOp",
            "TopKRouter",
            "Conv3d",
            "Linear",
            "WanDiTBlock",
            "Head",
        },
    }

    @classmethod
    def register_module_type(cls, module_name: str, parallelism_type: str):
        """Add a module class name to the automatic-detection registry.

        Args:
            module_name (str): The name of the module class (e.g.,
                'MyColumnLinear').
            parallelism_type (str): One of 'column', 'row', or 'replicated'.

        Raises:
            ValueError: If `parallelism_type` is not a registry key.
        """
        registry = cls._MODULE_TYPE_REGISTRY
        if parallelism_type not in registry:
            raise ValueError(
                f"Invalid parallelism_type '{parallelism_type}'. "
                f"Must be one of {list(cls._MODULE_TYPE_REGISTRY.keys())}"
            )
        registry[parallelism_type].add(module_name)

    def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None):
        """Initialize TP-aware mapping.

        Args:
            megatron_param (str): Megatron parameter name pattern.
            hf_param (str): HuggingFace parameter name pattern.
            permute_dims (Optional[Tuple[int, ...]]): Dimension permutation to apply.
                If provided, the tensor will be permuted and made contiguous during conversion.
        """
        super().__init__(megatron_param, hf_param)

        # Filled in lazily by the first `hf_to_megatron` call.
        self._detected_type: Optional[str] = None
        self._mapping: Optional[MegatronParamMapping[torch.Tensor]] = None

        # Permutation settings
        self.permute_dims = permute_dims

    def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]:
        """Build the delegate mapping for `parallelism_type`.

        Raises:
            ValueError: If `parallelism_type` is not 'column', 'row' or 'replicated'.
        """
        delegates = {
            "column": ColumnParallelMapping,
            "row": RowParallelMapping,
            "replicated": ReplicatedMapping,
        }
        delegate_cls = delegates.get(parallelism_type)
        if delegate_cls is None:
            raise ValueError(f"Unknown parallelism type: {parallelism_type}")
        return delegate_cls(self.megatron_param, self.hf_param)

    def _detect_parallelism_type(self, module: nn.Module) -> str:
        """Return 'column', 'row' or 'replicated' for `module`.

        Raises:
            ValueError: If none of the detection steps matches.
        """
        module_type = type(module).__name__

        # Fused modules such as TELayerNormColumnParallelLinear carry both
        # column-parallel parameters (weight, bias) and replicated layer-norm
        # parameters (layer_norm_weight, layer_norm_bias), so the parameter name
        # decides which strategy applies.
        if module_type == "TELayerNormColumnParallelLinear":
            param_name = self.megatron_param
            if param_name and param_name.endswith(("layer_norm_weight", "layer_norm_bias")):
                return "replicated"
            return "column"

        # Registry lookup by class name.
        for kind, class_names in self._MODULE_TYPE_REGISTRY.items():
            if module_type in class_names:
                return kind

        # Fall back to attributes set on the module.
        if hasattr(module, "tensor_model_parallel"):
            if not module.tensor_model_parallel:
                return "replicated"

            shard_dim = getattr(module, "partition_dim", None)
            if shard_dim == 0:
                return "column"
            if shard_dim == 1:
                return "row"

        # Anything that looks like a normalization layer is replicated.
        if any(token in module_type for token in ("Norm", "Normalization")):
            return "replicated"

        # TELinear states its mode explicitly.
        if module_type == "TELinear":
            mode = module.parallel_mode
            for candidate in ("column", "row"):
                if mode == candidate:
                    return candidate
            return "replicated"

        # Nothing matched: raise an error explaining how to resolve it.
        known_types = {kind: sorted(names) for kind, names in self._MODULE_TYPE_REGISTRY.items()}

        raise ValueError(
            f"Cannot determine parallelism type for module '{module_type}' "
            f"at weight '{self.megatron_param}'.\n"
            f"Please use an explicit mapping type (e.g., ColumnParallelMapping) "
            f"or register the module type using:\n"
            f"  AutoMapping.register_module_type('{module_type}', 'column|row|replicated')\n\n"
            f"Currently known module types:\n{json.dumps(known_types, indent=2)}"
        )

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Permute (if configured) and delegate to the detected mapping.

        Args:
            hf_weights (torch.Tensor): Source tensor.
            megatron_module (nn.Module): Target module; also used to detect the
                parallelism type on the first call.

        Returns:
            torch.Tensor: Whatever the delegate mapping returns.
        """
        # The permutation happens before distribution and only on TP rank 0.
        should_permute = self.permute_dims is not None and self.tp_rank == 0
        if should_permute:
            hf_weights = torch.permute(hf_weights, self.permute_dims).contiguous()

        # Detect the type and create the delegate once; later calls reuse it.
        if self._mapping is None:
            self._detected_type = self._detect_parallelism_type(megatron_module)
            self._mapping = self._get_or_create_mapping(self._detected_type)

        return self._mapping.hf_to_megatron(hf_weights, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Create a new mapping with resolved wildcards, preserving permute_dims."""
        megatron_name, hf_name = self._resolve_names(captures)
        return type(self)(megatron_name, hf_name, self.permute_dims)


class QKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping for interleaved Query/Key/Value attention projection weights.

    The external side provides separate `q`, `k` and `v` tensors; they are
    merged into a single tensor on TP rank 0 and then distributed through an
    `AutoMapping` delegate.
    """

    def __init__(self, megatron_param: str, q: str, k: str, v: str):
        """Initialize QKV mapping.

        Args:
            megatron_param (str): Megatron QKV parameter name pattern.
            q (str): Query weight name pattern.
            k (str): Key weight name pattern.
            v (str): Value weight name pattern.
        """
        super().__init__(megatron_param, {"q": q, "k": k, "v": v})
        # Format handling (merge/split) lives here; tensor-parallel distribution
        # is left to AutoMapping so this class does not hard-code the assumption
        # that QKV projections are column-parallel.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Merge Q, K, V into interleaved format and distribute.

        Args:
            hf_weights (Dict[str, torch.Tensor]): Tensors under the keys
                "q", "k" and "v". Only read on TP rank 0.
            megatron_module (nn.Module): Target Megatron module.

        Returns:
            torch.Tensor: Result of the delegated TP distribution.
        """
        merged = None
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            query, key, value = hf_weights["q"], hf_weights["k"], hf_weights["v"]

            # 1D tensors are biases, everything else is treated as a weight matrix.
            merge_fn = merge_qkv_biases if query.ndim == 1 else merge_qkv_weights
            merged = merge_fn(config, query, key, value)

        # Delegate the actual sharding/broadcasting to the TP-aware mapping.
        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* QKVMapping instance."""
        megatron_name, hf_names = self._resolve_names(captures)
        return type(self)(megatron_name, hf_names["q"], hf_names["k"], hf_names["v"])


class KVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping for interleaved Key/Value projection weights.

    This mapping converts between separate K and V tensors used in external
    checkpoints and Megatron's interleaved KV format following grouped-query
    attention semantics.

    External format (HF)
    - Separate tensors: k_proj, v_proj
    - Shapes mirror QKV mappings but without Q

    Megatron format
    - Single interleaved tensor with order: [k1, v1, k2, v2, ...]
      where index corresponds to query-group id

    Tensor-parallel distribution is delegated to AutoMapping.
    """

    def __init__(self, megatron_param: str, k: str, v: str):
        """Initialize KV mapping.

        Args:
            megatron_param (str): Megatron KV parameter name pattern.
            k (str): Key weight name pattern.
            v (str): Value weight name pattern.
        """
        super().__init__(megatron_param, {"k": k, "v": v})
        # Delegate TP sharding/broadcasting
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Merge K and V into interleaved format and distribute across TP.

        Args:
            hf_weights (Dict[str, torch.Tensor]): Tensors under the keys "k"
                and "v". Only read on TP rank 0.
            megatron_module (nn.Module): Target Megatron module.

        Returns:
            torch.Tensor: Result of the delegated TP distribution.
        """
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)

            # 1D tensors are biases, everything else is treated as a weight matrix.
            if hf_weights["k"].ndim == 1:
                merged = merge_kv_biases(config, hf_weights["k"], hf_weights["v"])
            else:
                merged = merge_kv_weights(config, hf_weights["k"], hf_weights["v"])
        else:
            merged = None

        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* KVMapping instance.

        The constructor takes `k` and `v` as separate arguments, so the
        resolved HF name dict has to be unpacked explicitly (same pattern as
        the other dict-based mappings in this file).
        """
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)

        return type(self)(
            resolved_megatron_param,
            resolved_hf_param["k"],
            resolved_hf_param["v"],
        )


class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping for a **concatenated** Query/Key/Value tensor to Megatron's interleaved QKV.

    Some transformers models store Q, K and V in one tensor, concatenated one
    after the other. This mapping splits that tensor back into Q, K and V and
    merges them into Megatron's interleaved layout.

    **External format (HuggingFace)**
    -   One tensor `qkv` with Q, K and V concatenated along dim 0, i.e. dim 0 has
        size `head_dim * num_heads + 2 * head_dim * num_query_groups`
        (1D for biases, with an additional trailing dimension for weights)

    **Megatron format**
    -   Single interleaved tensor following grouped query attention (GQA) pattern
    -   Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]`
    -   Where `n = num_attention_heads / num_query_groups`

    **Key features**
    1.  Format conversion: Handles splitting and merging with proper interleaving
    2.  Grouped Query Attention: Supports different numbers of Q and KV heads
    3.  Tensor parallelism: Delegates to AutoMapping for distribution

    Example:
        .. code-block:: python

            # Create mapping for attention weights
            mapping = ConcatenatedQKVMapping(
                megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
                hf_param="model.layers.*.self_attn.qkv.weight",
            )

            # Convert from HuggingFace to Megatron
            megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module)

    Note:
        Both regular multi-head attention (same number of Q, K, V heads) and
        grouped query attention (fewer KV heads than Q heads) are handled,
        based on the model configuration.
    """

    def __init__(self, megatron_param: str, hf_param: str):
        """Initialize concatenated-QKV mapping.

        Args:
            megatron_param (str): Megatron interleaved QKV parameter name pattern.
            hf_param (str): HF concatenated QKV parameter name pattern.
        """
        super().__init__(megatron_param, hf_param)
        # Delegate all tensor-parallel logic to the smart TP-aware mapping so we
        # do not hard-code the assumption that QKV projections are column-parallel.
        # This keeps the format-handling (merge/split) concerns separate from
        # TP/PP distribution mechanics.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Split the concatenated tensor, merge into interleaved format and distribute.

        Args:
            hf_weights (torch.Tensor): Concatenated QKV tensor. Only read on
                TP rank 0.
            megatron_module (nn.Module): Target Megatron module.

        Returns:
            torch.Tensor: Result of the delegated TP distribution.
        """
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            head_num = config.num_attention_heads
            # Same fallback as merge_qkv_biases / merge_qkv_weights: derive the
            # head size from hidden_size when kv_channels is not set.
            head_size = config.kv_channels or (config.hidden_size // head_num)
            num_query_groups = config.num_query_groups
            q, k, v = hf_weights.split(
                [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
            )
            # Check if we're dealing with biases (1D tensors) or weights (2D tensors)
            if q.ndim == 1:
                # For biases, use the bias-specific merge function
                merged = merge_qkv_biases(config, q, k, v)
            else:
                # For weights, use the standard merge function
                merged = merge_qkv_weights(config, q, k, v)
        else:
            merged = None

        # Delegate the actual sharding/broadcasting to the TP-aware mapping.
        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* ConcatenatedQKVMapping instance."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)

        return type(self)(resolved_megatron_param, resolved_hf_param)

class VisionEncoderQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping for vision encoder (ViT/ViT-like) attention projection weights.

    The HuggingFace side provides a single fused QKV tensor which is regrouped
    per attention head before being handed to Megatron.

    **HuggingFace tensor, as viewed by this mapping**
    - Weights (ndim != 1): `[3, num_heads, -1, head_dim, hidden_size]`
    - Biases (ndim == 1): `[3, num_heads, -1]`

    **Tensor passed on to Megatron**
    - Weights: 2D, `[-1, hidden_size]`, with each head's Q, K and V rows adjacent
    - Biases: 1D, flattened in the same per-head order

    Note:
        This is specifically for vision encoders (like ViT) and uses different
        reshaping logic compared to text transformer QKV mappings.
    """

    def __init__(self, megatron_param: str, hf_param: str):
        """Initialize vision encoder QKV mapping.

        Args:
            megatron_param (str): Megatron vision encoder QKV parameter name pattern.
            hf_param (str): HF vision encoder QKV parameter name pattern.
        """
        super().__init__(megatron_param, hf_param)

        # Delegate tensor-parallel logic to AutoMapping
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Convert HF vision encoder QKV weights to Megatron format.

        Args:
            hf_weights (torch.Tensor): Fused QKV weight or bias. Only read on
                TP rank 0.
            megatron_module (nn.Module): Target Megatron module; its config
                supplies `num_attention_heads` and `hidden_size`.

        Returns:
            torch.Tensor: Result of the delegated TP distribution.
        """
        qkv = None
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            num_heads = config.num_attention_heads
            hidden = config.hidden_size
            head_dim = hidden // num_heads

            # 1D input is a bias, anything else is handled as a weight.
            is_bias = hf_weights.ndim == 1
            if is_bias:
                grouped = hf_weights.view(3, num_heads, -1)
            else:
                grouped = hf_weights.view(3, num_heads, -1, head_dim, hidden)

            # Iterating the leading axis yields Q, K and V in that order; bring
            # each to (num_heads, head_dim, -1).
            per_head = [part.view(num_heads, head_dim, -1) for part in grouped]

            # Place Q, K and V of the same head next to each other.
            stacked = torch.cat(per_head, dim=1)
            qkv = stacked.flatten() if is_bias else stacked.reshape(-1, hidden)

        # Delegate TP distribution to AutoMapping
        return self._tp_mapping.hf_to_megatron(qkv, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new resolved VisionEncoderQKVMapping instance."""
        megatron_name, hf_name = self._resolve_names(captures)
        return type(self)(megatron_name, hf_name)


class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU).

    Checkpoint formats expose two independent matrices:

    -   **G** – gate projection
    -   **U** – up projection

    Megatron concatenates them row-wise (`[G; U]`) so that a single GEMM can
    produce both activations.

    **Responsibilities handled by this mapping**
    1.  **Concatenate** – build `[G; U]` (Megatron) from the separate
        `{G, U}` matrices (external).
    2.  **Tensor-parallel distribution** – gate and up are split separately
        and the matching shards are concatenated, so each TP rank receives
        `[gate_shard; up_shard]`.

    **TP Distribution Strategy**
    - Split gate and up separately along the output dimension (dim 0)
    - Concatenate corresponding shards: [gate_shard_i; up_shard_i] for rank i
    - Scatter one concatenated shard to each TP rank
    """

    def __init__(self, megatron_param: str, gate: str, up: str):
        """Initialize gated MLP mapping.

        Args:
            megatron_param (str): Megatron MLP parameter name pattern.
            gate (str): Gate projection weight name pattern.
            up (str): Up projection weight name pattern.
        """
        super().__init__(megatron_param, {"gate": gate, "up": up})

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Split gate and up separately, then concatenate corresponding shards.

        Args:
            hf_weights (Dict[str, torch.Tensor]): Tensors under the keys
                "gate" and "up". With `tp_size > 1` they are only read on TP rank 0.
            megatron_module (nn.Module): Module holding the target parameter.

        Returns:
            torch.Tensor: `[gate; up]` when `tp_size == 1`, otherwise the
            result of `scatter_to_tp_ranks` for this rank.

        Raises:
            ValueError: On TP rank 0 if gate and up differ in shape or dim 0 is
                not divisible by `tp_size`.
        """
        # Without tensor parallelism there is nothing to shard.
        if self.tp_size == 1:
            return torch.cat([hf_weights["gate"], hf_weights["up"]], dim=0)

        # Some parameters are named with a global expert number, e.g. experts.weight15;
        # normalize the name to experts.weight0. Only shape, dtype and device of the
        # looked-up parameter are used, not its values, so this is safe.
        param_name = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, param_name)

        splits = None
        if self.tp_rank == 0:
            gate, up = hf_weights["gate"], hf_weights["up"]

            if gate.shape != up.shape:
                raise ValueError(
                    f"Gate and up weights must have the same shape. Gate shape: {gate.shape}, Up shape: {up.shape}")

            gate_output_size = gate.shape[0]
            if gate_output_size % self.tp_size != 0:
                raise ValueError(
                    f"Cannot evenly split gate dimension 0 size {gate_output_size} across {self.tp_size} TP ranks"
                )

            # Chunking along dim 0 works for both bias (1D) and weight (2D) tensors.
            gate_shards = torch.chunk(gate, self.tp_size, dim=0)
            up_shards = torch.chunk(up, self.tp_size, dim=0)

            # One [gate_shard_i; up_shard_i] tensor per TP rank i.
            splits = []
            for rank in range(self.tp_size):
                splits.append(torch.cat([gate_shards[rank], up_shards[rank]], dim=0))

        # Scatter the concatenated shards to each rank
        return self.scatter_to_tp_ranks(
            splits,
            target_param.shape,
            target_param.dtype,
            target_param.device,
        )

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* GatedMLPMapping instance."""
        megatron_name, hf_names = self._resolve_names(captures)
        return type(self)(megatron_name, hf_names["gate"], hf_names["up"])


class RMSNorm2ZeroCenteredRMSNormMapping(AutoMapping):
    """
    Mapping that shifts an RMSNorm weight by -1 when loading it into Megatron.

    The HF weight has 1 subtracted before the regular `AutoMapping` path runs
    (per the class name, the Megatron side is a zero-centered RMSNorm).
    """

    def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor:
        """Subtract 1 from the HF weight and delegate to `AutoMapping`.

        Args:
            hf_weights (torch.Tensor): RMSNorm weight from the HF checkpoint.
                It is left unmodified.
            megatron_module (nn.Module): Target Megatron module.

        Returns:
            torch.Tensor: Result of `AutoMapping.hf_to_megatron` on the shifted weight.
        """
        # Shift out-of-place: rebinding `hf_weights.data` would alter the caller's
        # tensor, so a tensor that is reused (or converted twice) would be
        # shifted again on every call.
        shifted = hf_weights.detach() - 1
        return super().hf_to_megatron(shifted, megatron_module)


def merge_qkv_biases(config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Interleave separate Q, K, V bias vectors into Megatron's fused QKV layout.

    For each query group the output holds that group's query heads, then (only
    when `config.attention_output_gate` is set) the gate halves of those heads,
    then the group's key head and value head.

    Args:
        config (TransformerConfig): Transformer configuration. Reads
            `num_attention_heads`, `num_query_groups`, `kv_channels`,
            `hidden_size` and, if present, `attention_output_gate`.
        q (torch.Tensor): Query projection biases, 1D.
        k (torch.Tensor): Key projection biases, 1D.
        v (torch.Tensor): Value projection biases, 1D.

    Returns:
        torch.Tensor: Interleaved QKV biases in Megatron format as 1D tensor.
    """
    num_heads = config.num_attention_heads
    num_groups = config.num_query_groups
    group_width = num_heads // num_groups
    # kv_channels wins when set; otherwise derive the head size from hidden_size.
    head_dim = config.kv_channels or (config.hidden_size // num_heads)
    gated = getattr(config, "attention_output_gate", False)

    # Expose the head dimension. With an output gate every query head carries
    # two halves: the query itself followed by its gate.
    if gated:
        q, z = torch.chunk(q.view(num_heads, head_dim * 2), 2, dim=-1)
    else:
        q = q.view(num_heads, head_dim)
    k = k.view(num_groups, head_dim)
    v = v.view(num_groups, head_dim)

    # Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...]
    pieces = []
    for group in range(num_groups):
        first_head = group * group_width
        last_head = first_head + group_width
        pieces.append(q[first_head:last_head, :])
        if gated:
            pieces.append(z[first_head:last_head, :])
        pieces.extend((k[group: group + 1, :], v[group: group + 1, :]))

    # Stack all rows and flatten back to 1D
    return torch.cat(pieces).flatten()


def merge_qkv_weights(provider: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fuse separate Q, K and V projections into one interleaved QKV tensor.

    Handles both 2D weights and 1D biases; the mode is chosen from ``q.ndim``.
    The head size is ``provider.kv_channels`` when set, otherwise
    ``hidden_size // num_attention_heads``. For every query group the output
    holds that group's query heads, then its key head, then its value head.
    When ``provider.attention_output_gate`` is truthy, ``q`` carries
    ``2 * head_size`` rows per head; the second half of each head is split off
    and emitted right after the group's query heads.

    Args:
        provider (TransformerConfig): Model configuration provider.
        q (torch.Tensor): Query projection weights
            [num_attention_heads * q_head_size, hidden_size] or bias
            [num_attention_heads * q_head_size].
        k (torch.Tensor): Key projection weights
            [num_query_groups * head_size, hidden_size] or bias
            [num_query_groups * head_size].
        v (torch.Tensor): Value projection weights
            [num_query_groups * head_size, hidden_size] or bias
            [num_query_groups * head_size].

    Returns:
        torch.Tensor: Interleaved QKV tensor, 1D for biases and
        [-1, hidden_size] for weights.

    Raises:
        AssertionError: If the merged tensor does not contain exactly as many
            elements as ``q``, ``k`` and ``v`` together.
    """
    n_heads = provider.num_attention_heads
    n_groups = provider.num_query_groups
    q_heads_per_group = n_heads // n_groups
    dim_per_head = provider.kv_channels or (provider.hidden_size // n_heads)
    model_dim = provider.hidden_size
    # The gate flag does not change while merging, so read it once.
    gated = getattr(provider, "attention_output_gate", False)
    q_dim_per_head = dim_per_head * 2 if gated else dim_per_head
    is_bias = q.ndim == 1

    # Weights keep a trailing hidden dimension; biases have none.
    trailing_dims = () if is_bias else (model_dim,)
    q_heads = q.view(n_heads, q_dim_per_head, *trailing_dims)
    k_heads = k.view(n_groups, dim_per_head, *trailing_dims)
    v_heads = v.view(n_groups, dim_per_head, *trailing_dims)
    gate_heads = None
    if gated:
        q_heads, gate_heads = torch.chunk(q_heads, 2, dim=1)

    # Per group: [q heads..., (gate heads...), k, v]
    parts = []
    for group in range(n_groups):
        group_heads = slice(group * q_heads_per_group, (group + 1) * q_heads_per_group)
        parts.append(q_heads[group_heads])
        if gated:
            parts.append(gate_heads[group_heads])
        parts.append(k_heads[group : group + 1])
        parts.append(v_heads[group : group + 1])

    qkv = torch.cat(parts, dim=0)

    # Sanity check: merging must neither drop nor duplicate elements.
    if qkv.numel() != q.numel() + k.numel() + v.numel():
        raise AssertionError(
            f"QKV weights are not correctly merged, {q.shape=}, {k.shape=}, {v.shape=}, {qkv.shape=}"
        )

    return qkv.reshape(-1) if is_bias else qkv.reshape([-1, model_dim])


def merge_kv_biases(config: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fuse separate K and V bias vectors into one interleaved 1D KV bias.

    The head size is ``config.kv_channels`` when set, otherwise
    ``hidden_size // num_attention_heads``. The output alternates per query
    group: [k1, v1, k2, v2, ...].

    Args:
        config (TransformerConfig): Transformer configuration.
        k (torch.Tensor): Key bias, viewed as [num_query_groups, head_size].
        v (torch.Tensor): Value bias, viewed as [num_query_groups, head_size].

    Returns:
        torch.Tensor: Interleaved KV bias flattened to 1D.
    """
    n_groups = config.num_query_groups
    dim_per_head = config.kv_channels or (config.hidden_size // config.num_attention_heads)

    k_rows = k.view(n_groups, dim_per_head)
    v_rows = v.view(n_groups, dim_per_head)

    # For each group emit its key row followed by its value row.
    interleaved = [
        rows[group : group + 1, :]
        for group in range(n_groups)
        for rows in (k_rows, v_rows)
    ]

    return torch.cat(interleaved, dim=0).reshape(-1)


def merge_kv_weights(provider: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fuse separate K and V weight matrices into one interleaved 2D KV weight.

    The head size is ``provider.kv_channels`` when set, otherwise
    ``hidden_size // num_attention_heads``. The output alternates per query
    group: [k1, v1, k2, v2, ...], each entry spanning ``head_size`` rows.

    Args:
        provider (TransformerConfig): Model configuration provider.
        k (torch.Tensor): Key weights, viewed as
            [num_query_groups, head_size, hidden_size].
        v (torch.Tensor): Value weights, viewed as
            [num_query_groups, head_size, hidden_size].

    Returns:
        torch.Tensor: Interleaved KV weights of shape [-1, hidden_size].
    """
    n_groups = provider.num_query_groups
    dim_per_head = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads)
    model_dim = provider.hidden_size

    k_heads = k.view(n_groups, dim_per_head, model_dim)
    v_heads = v.view(n_groups, dim_per_head, model_dim)

    # For each group emit its key head followed by its value head.
    interleaved = [
        heads[group : group + 1]
        for group in range(n_groups)
        for heads in (k_heads, v_heads)
    ]

    merged = torch.cat(interleaved, dim=0)
    return merged.view(-1, model_dim)