import json
import re
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import torch
import torch.distributed
import torch.nn as nn
from megatron.core import mpu
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from bridge.models.conversion.utils import get_module_and_param_from_name, remove_non_pickleables
logger = logging.getLogger(__name__)
# Payload type a mapping operates on: a single tensor, or a dict of name -> tensor.
WeightType = TypeVar("WeightType", torch.Tensor, Dict[str, torch.Tensor])
def get_pg_size(group=None):
    """Return the world size of a distributed process group.

    Args:
        group: Process group to query. ``None`` does NOT fall back to the
            default (world) group; it is treated as "no parallelism".

    Returns:
        int: ``1`` if ``torch.distributed`` is not initialized or ``group`` is
        ``None``, otherwise ``group.size()``.
    """
    # A missing group (e.g. Megatron parallel state not set up) means a
    # single participant.
    if not torch.distributed.is_initialized() or group is None:
        return 1
    return group.size()
def get_pg_rank(group=None):
    """Return this process's rank within a distributed process group.

    Args:
        group: Process group to query. ``None`` does NOT fall back to the
            default (world) group; it is treated as "no parallelism".

    Returns:
        int: ``0`` if ``torch.distributed`` is not initialized or ``group`` is
        ``None``, otherwise ``group.rank()``.
    """
    # A missing group (e.g. Megatron parallel state not set up) means this
    # process is the only participant, i.e. rank 0.
    if not torch.distributed.is_initialized() or group is None:
        return 0
    return group.rank()
class MegatronParamMapping(ABC, Generic[WeightType]):
    """
    Abstract base class for weight conversion between Megatron and external formats.

    This class provides the foundation for all weight mappings, handling the complex
    conversions between Megatron-Core's distributed tensor formats and standard
    (typically HuggingFace) formats. Each concrete mapping implements specific
    transformation logic while inheriting common parallel communication patterns.

    Key responsibilities:
    - Format transformation (e.g., QKV merging/splitting, gated MLP handling)
    - Tensor parallel (TP) distribution and gathering across GPUs
    - Pipeline parallel (PP) broadcasting between pipeline stages
    - Wildcard pattern resolution for layer-wise mappings

    The mapping abstraction ensures that higher-level code doesn't need to know
    about the parallel topology or format differences - it just requests a
    conversion and the mapping handles all the complexity.

    Public helper methods for subclasses:
    - broadcast_from_pp_rank: Broadcast tensors across pipeline stages
    - broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks
    - broadcast_tensor_to_tp_ranks: Broadcast within TP group
    - scatter_to_tp_ranks: Distribute tensor shards to TP ranks
    - gather_from_tp_ranks: Collect tensor shards from TP ranks

    Example:
        .. code-block:: python

            class MyCustomMapping(MegatronParamMapping[torch.Tensor]):
                def hf_to_megatron(self, hf_weights, megatron_module):
                    # Custom transformation logic
                    transformed = hf_weights.t()  # Example: transpose
                    # Use helpers for distribution
                    return self.scatter_to_tp_ranks(...)

                def megatron_to_hf(self, megatron_weights, megatron_module):
                    # Broadcast from owning PP rank
                    weight = self.broadcast_from_pp_rank(megatron_weights)
                    # Gather from TP ranks and transform
                    gathered = self.gather_from_tp_ranks(weight)
                    return {"custom_weight": gathered[0].t()}
    """

    def __init__(self, megatron_param: str, hf_param: Union[str, Dict[str, str]]):
        """Initialize the weight mapping.

        Args:
            megatron_param (str): Megatron parameter name pattern (supports *
                wildcards).
            hf_param (Union[str, Dict[str, str]]): External format name pattern(s).

        Raises:
            ValueError: If the wildcard count of ``megatron_param`` differs from
                that of ``hf_param`` (or of any value when ``hf_param`` is a dict).
        """
        self.megatron_param = megatron_param
        self.hf_param = hf_param
        self._validate_patterns()

        # Per-instance caches for PP collectives, keyed by caller-supplied keys.
        self._broadcast_obj_cache: Dict[str, Any] = {}
        self._tensor_spec_output_cache: Dict[str, List[Optional[tuple]]] = {}

        # Process groups come from Megatron's parallel state. Without it every
        # group is None and the rank/size properties report rank 0 / size 1.
        if mpu.is_initialized():
            self.pp_group = mpu.get_pipeline_model_parallel_group()
            self.ep_group = mpu.get_expert_model_parallel_group()
            self._tp_group = mpu.get_tensor_model_parallel_group()
            self._etp_group = mpu.get_expert_tensor_parallel_group()
        else:
            self.pp_group = None
            self.ep_group = None
            self._tp_group = None
            self._etp_group = None

        # NOTE(review): not read anywhere in this class; presumably consulted by
        # callers elsewhere — confirm before changing.
        self.allow_hf_name_mismatch = False

    @property
    def tp_group(self):
        """Get the tensor-parallel group (the expert-TP group for expert params)."""
        if self.is_expert:
            return self._etp_group
        return self._tp_group

    @property
    def tp_rank(self) -> int:
        """Get the tensor model parallel rank."""
        return get_pg_rank(self.tp_group)

    @property
    def tp_size(self) -> int:
        """Get the tensor model parallel size."""
        return get_pg_size(self.tp_group)

    @property
    def pp_rank(self) -> int:
        """Get the pipeline model parallel rank."""
        return get_pg_rank(self.pp_group)

    @property
    def pp_size(self) -> int:
        """Get the pipeline model parallel size."""
        return get_pg_size(self.pp_group)

    @property
    def ep_rank(self) -> int:
        """Get the expert model parallel rank."""
        return get_pg_rank(self.ep_group)

    @property
    def ep_size(self) -> int:
        """Get the expert model parallel size."""
        return get_pg_size(self.ep_group)

    @property
    def etp_rank(self) -> int:
        """Get the expert tensor parallel rank."""
        # The group is stored privately as ``_etp_group``; there is no public
        # ``etp_group`` attribute on this class.
        return get_pg_rank(self._etp_group)

    @property
    def etp_size(self) -> int:
        """Get the expert tensor parallel size."""
        return get_pg_size(self._etp_group)

    @property
    def is_expert(self) -> bool:
        """Check if this mapping is for an expert parameter.

        Matches both TEGroupedMLP (.mlp.experts.linear_fc) and
        SequentialMLP (.mlp.experts.local_experts.*.linear_fc) patterns.
        """
        return ".mlp.experts.linear_fc" in self.megatron_param or ".mlp.experts.local_experts." in self.megatron_param

    @staticmethod
    def _substitute_wildcards(pattern: str, captures: Tuple[str, ...]) -> str:
        """Fill the wildcards of a single pattern with captures, in order.

        All ``**`` occurrences are consumed first, then ``*``. Substitution
        stops once captures run out, leaving any remaining wildcards intact.
        """
        resolved = pattern
        capture_index = 0
        while "**" in resolved and capture_index < len(captures):
            resolved = resolved.replace("**", captures[capture_index], 1)
            capture_index += 1
        while "*" in resolved and capture_index < len(captures):
            resolved = resolved.replace("*", captures[capture_index], 1)
            capture_index += 1
        return resolved

    def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]:
        """Resolve wildcard patterns with captured values.

        Handles both ** (any characters) and * (digits) wildcards in order.
        ** patterns are processed before * patterns to avoid conflicts. Every
        pattern (the Megatron one and each HF one) consumes captures from the
        start independently.
        """
        resolved_megatron_param = self._substitute_wildcards(self.megatron_param, captures)
        if isinstance(self.hf_param, str):
            resolved_hf_param = self._substitute_wildcards(self.hf_param, captures)
        else:
            resolved_hf_param = {k: self._substitute_wildcards(v, captures) for k, v in self.hf_param.items()}
        return resolved_megatron_param, resolved_hf_param

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Create a new mapping with resolved wildcards.

        This default implementation works for mappings with a
        (megatron_param, hf_param) constructor.

        Args:
            captures (Tuple[str, ...]): Captured wildcard values.

        Returns:
            MegatronParamMapping: A new mapping instance with resolved names.
        """
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        return type(self)(resolved_megatron_param, resolved_hf_param)

    @abstractmethod
    def hf_to_megatron(
        self,
        hf_weights: WeightType,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Convert hf_weights TO Megatron format.

        This method handles:
        1. Format transformation (if needed)
        2. Tensor parallel distribution (if self.tp_size > 1)

        Args:
            hf_weights (WeightType): Source hf_weights in external format.
            megatron_module (nn.Module): Target Megatron module (for config
                access).

        Returns:
            torch.Tensor: Weight tensor ready for the current TP rank.
        """
        ...

    def broadcast_from_pp_rank(
        self, tensor: Optional[torch.Tensor], cache_key: Optional[str] = None
    ) -> Optional[torch.Tensor]:
        """Broadcast a tensor from the pipeline-parallel rank that owns it.

        Broadcasts to **all** PP ranks. This mirrors the behaviour of
        `broadcast_from_megatron_pp` in the original MMapping implementation and
        additionally keeps the tensor-parallel metadata (`tensor_model_parallel`,
        `partition_dim`) consistent on every rank.

        Args:
            tensor (Optional[torch.Tensor]): The local tensor if the current PP
                rank owns it. ``None`` otherwise.
            cache_key (Optional[str]): If given, the gathered per-rank tensor
                specs are cached under this key so later calls with the same key
                skip the ``all_gather_object``. No caching when ``None``.

        Returns:
            Optional[torch.Tensor]: The broadcasted tensor on every PP rank
            (unchanged input when ``pp_size == 1``). Receiving ranks get a newly
            allocated tensor on CUDA if available, else CPU.

        Raises:
            ValueError: If the tensor exists on more than one PP rank, or on none.
        """
        if self.pp_size == 1:
            return tensor

        if cache_key is not None and cache_key in self._tensor_spec_output_cache:
            tensor_spec_output = self._tensor_spec_output_cache[cache_key]
        else:
            if tensor is not None:
                shape = tensor.shape
                dtype = tensor.dtype
                tensor_parallel = getattr(tensor, "tensor_model_parallel", None)
                partition_dim = getattr(tensor, "partition_dim", None)
                tensor_spec = (shape, dtype, tensor_parallel, partition_dim)
            else:
                tensor_spec = None

            tensor_spec_output: list[Optional[tuple]] = [None] * self.pp_size
            torch.distributed.all_gather_object(tensor_spec_output, tensor_spec, group=self.pp_group)
            # Only cache when the caller asked for it; avoids a stray ``None`` key.
            if cache_key is not None:
                self._tensor_spec_output_cache[cache_key] = tensor_spec_output

        # Exactly one PP rank must report a spec; that rank is the source.
        target_tensor_spec = None
        src_rank = None
        for rank, spec in enumerate(tensor_spec_output):
            if spec is not None:
                if target_tensor_spec is not None:
                    raise ValueError(f"Tensor exists on more than one PP rank. Found on ranks {src_rank} and {rank}.")
                target_tensor_spec = spec
                src_rank = rank

        if target_tensor_spec is None:
            raise ValueError("Object must exist on at least one PP rank")

        if tensor is None:
            # Non-owning rank: allocate a receive buffer matching the owner's
            # spec and replicate its TP metadata attributes.
            shape, dtype, tensor_parallel, partition_dim = target_tensor_spec
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            tensor = torch.empty(shape, dtype=dtype, device=device)
            if tensor_parallel is not None:
                tensor.tensor_model_parallel = tensor_parallel
            if partition_dim is not None:
                tensor.partition_dim = partition_dim

        global_src = torch.distributed.get_global_rank(group=self.pp_group, group_rank=src_rank)
        torch.distributed.broadcast(tensor, src=global_src, group=self.pp_group)
        return tensor

    def broadcast_obj_from_pp_rank(self, obj: Optional[Any], cache_key: Optional[str] = None) -> Any:
        """Broadcast any Python object from the PP rank that owns it.

        This method is useful for broadcasting configuration objects or
        other metadata across pipeline parallel ranks.

        Args:
            obj (Optional[Any]): Object to broadcast (None on non-owning ranks).
            cache_key (Optional[str]): Optional cache key. If provided, the
                result is cached and later calls with the same key return it
                without communicating. If not provided, no caching is performed.

        Returns:
            Any: Broadcasted object on all ranks.

        Raises:
            ValueError: If no PP rank holds a non-None object.

        Note:
            If several PP ranks hold a non-None object, the copy on the highest
            PP rank is broadcast.
        """
        if self.pp_size == 1:
            return obj

        if cache_key is not None and cache_key in self._broadcast_obj_cache:
            return self._broadcast_obj_cache[cache_key]

        # Find out which PP ranks hold the object.
        has_obj = obj is not None
        obj_flags = [None] * self.pp_size
        torch.distributed.all_gather_object(obj_flags, has_obj, group=self.pp_group)

        src_rank = None
        for rank, flag in enumerate(obj_flags):
            if flag:
                src_rank = rank  # last owner wins

        if src_rank is None:
            raise ValueError("Object must exist on at least one PP rank")

        obj_list = [obj]
        pp_ranks = torch.distributed.get_process_group_ranks(self.pp_group)
        global_src = pp_ranks[src_rank]
        torch.distributed.broadcast_object_list(obj_list, src=global_src, group=self.pp_group)
        result = obj_list[0]

        if cache_key is not None:
            self._broadcast_obj_cache[cache_key] = result
        return result

    def clear_broadcast_cache(self):
        """Clear the broadcast object cache.

        This can be useful for testing or if the objects being broadcast
        might change during the lifetime of the mapping.
        """
        self._broadcast_obj_cache.clear()

    def clear_tensor_spec_output_cache(self):
        """Clear the tensor spec output cache.

        This can be useful for testing or if the tensor spec output
        might change during the lifetime of the mapping.
        """
        self._tensor_spec_output_cache.clear()

    def broadcast_tensor_to_tp_ranks(self, tensor: torch.Tensor, src_rank: int = 0) -> torch.Tensor:
        """Broadcast a tensor to all TP ranks (in place).

        Args:
            tensor (torch.Tensor): The tensor to broadcast. On non-source ranks
                it must already have the right shape/dtype to receive into.
            src_rank (int, optional): The source rank within the TP group.
                Defaults to 0.

        Returns:
            torch.Tensor: The broadcasted tensor.
        """
        if self.tp_size == 1:
            return tensor
        global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)
        torch.distributed.broadcast(tensor, src=global_src, group=self.tp_group)
        return tensor

    def scatter_to_tp_ranks(
        self,
        splits: Optional[List[torch.Tensor]],
        output_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        src_rank: int = 0,
    ) -> torch.Tensor:
        """Scatter tensor splits to TP ranks.

        Args:
            splits (Optional[List[torch.Tensor]]): A list of tensor shards to
                scatter. Only rank `src_rank` needs this.
            output_shape (torch.Size): The shape of the output tensor on each rank.
            dtype (torch.dtype): The data type of the output tensor.
            device (torch.device): The device for the output tensor.
            src_rank (int, optional): The source rank for the scatter operation.
                Defaults to 0.

        Returns:
            torch.Tensor: The scattered tensor shard on the current rank. When
            ``tp_size == 1`` this is ``splits[0]`` cast to ``device``/``dtype``,
            or ``None`` if ``splits`` is empty/None.
        """
        if self.tp_size == 1:
            return splits[0].to(device=device, dtype=dtype) if splits else None

        output = torch.empty(output_shape, dtype=dtype, device=device)
        global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank)

        scatter_list = None
        if self.tp_rank == src_rank and splits:
            scatter_list = [s.to(device=device, dtype=dtype) for s in splits]

        torch.distributed.scatter(
            output,
            scatter_list,
            src=global_src,
            group=self.tp_group,
        )
        return output

    def gather_from_tp_ranks(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """Gather tensors from all TP ranks.

        Args:
            tensor (torch.Tensor): The tensor shard to be gathered from the
                current rank.

        Returns:
            List[torch.Tensor]: A list of tensor shards from all TP ranks,
            ordered by TP rank.
        """
        if self.tp_size == 1:
            return [tensor]
        gathered = [torch.empty_like(tensor) for _ in range(self.tp_size)]
        torch.distributed.all_gather(gathered, tensor, group=self.tp_group)
        return gathered

    def _count_wildcard_groups(self, pattern: str) -> int:
        """Count the number of wildcard capture groups in a pattern.

        Args:
            pattern: Pattern string with * and ** wildcards

        Returns:
            Number of capture groups that will be generated

        Note:
            ** counts as 1 group, * counts as 1 group
            ** must be counted before * to avoid double-counting
        """
        count = 0
        remaining = pattern
        while "**" in remaining:
            count += 1
            remaining = remaining.replace("**", "", 1)
        count += remaining.count("*")
        return count

    def _validate_patterns(self):
        """Validate wildcard consistency between patterns.

        Raises:
            ValueError: If any HF pattern has a different wildcard count than
                ``megatron_param``.
        """
        megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param)
        if isinstance(self.hf_param, str):
            hf_param_wildcards = self._count_wildcard_groups(self.hf_param)
            if megatron_param_wildcards != hf_param_wildcards:
                raise ValueError(
                    f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has "
                    f"{megatron_param_wildcards} wildcards, hf_param='{self.hf_param}' has {hf_param_wildcards}"
                )
        else:
            for key, pattern in self.hf_param.items():
                hf_param_wildcards = self._count_wildcard_groups(pattern)
                if megatron_param_wildcards != hf_param_wildcards:
                    raise ValueError(
                        f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has "
                        f"{megatron_param_wildcards} wildcards, hf_param['{key}']='{pattern}' has {hf_param_wildcards}"
                    )

    def _normalize_expert_param_name(self, param_name: str) -> str:
        """Normalize expert parameter name by replacing trailing numbers with 0.

        e.g. experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0

        Args:
            param_name (str): Parameter name that may end with a number.

        Returns:
            str: Parameter name with trailing number replaced by 0.
        """
        return re.sub(r"\d+$", "0", param_name)

    def _get_config(self, module: nn.Module) -> Any:
        """Extract configuration from module hierarchy.

        Walks from ``module`` upward via ``_parent`` links and returns the first
        ``config`` attribute found.

        Raises:
            ValueError: If no module on the walk has a ``config`` attribute.
        """
        current = module
        while current is not None:
            if hasattr(current, "config"):
                return current.config

            if hasattr(current, "_parent"):
                current = current._parent
            else:
                # NOTE(review): ``module.modules()`` yields only ``module`` and
                # its descendants, so in a tree-shaped module graph this search
                # cannot find a parent of ``current`` and ends the walk.
                for parent_module in module.modules():
                    for _, child_module in parent_module.named_children():
                        if child_module is current:
                            current = parent_module
                            break
                    else:
                        continue
                    break
                else:
                    current = None

        raise ValueError(
            f"Could not find config in module hierarchy for {module.__class__.__name__}. "
            f"Ensure the module or its parent has a 'config' attribute."
        )

    def gather_from_ep_ranks(
        self,
        megatron_weights: Optional[torch.Tensor],
        megatron_module: Optional[MegatronModule],
        hf_param_name: Optional[str],
    ) -> Dict[str, torch.Tensor]:
        """Handle expert parallel weight gathering for MoE models.

        This method gathers expert weights across expert-parallel (EP) ranks and
        returns a mapping from HF parameter names to the corresponding tensors
        from each EP rank. Call this only for confirmed expert parameters
        (self.is_expert is True), typically after TP gathering/concatenation in
        the export path (Megatron → HF).

        Behavior and notation:
        - Let E be the total number of experts (config.num_moe_experts) and
          S be the expert-parallel size (ep_size). We assume E % S == 0.
        - Each EP rank owns E/S experts. The local expert index L (0 ≤ L < E/S)
          is derived from the global expert id that trails ``.weight`` or
          ``.bias`` in ``self.megatron_param``.
        - The set of global expert ids that correspond to this local index L
          across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}.

        Communication and outputs:
        - We perform an all_gather over the EP group to collect the tensor from
          every EP rank into a list ordered by EP rank id.
        - For each EP rank k, we construct the HF parameter name by replacing the
          expert id in `hf_param_name` with (L + k * (E/S)), preserving the rest
          of the path, and map that name to the gathered tensor from rank k.
        - If several ranks map to the same HF name (e.g. ``hf_param_name`` has no
          ``experts.<id>`` segment), their tensors are stacked along a new
          leading dimension in EP-rank order. Otherwise each tensor is returned
          with its original shape.

        Example:
        - E = 8, S = 2 → E/S = 4. Experts are distributed as:
          Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7].
          If the local index L = 0 (derived from the param name), this returns:
          {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1}

        Args:
            megatron_weights (Optional[torch.Tensor]): The local expert weight tensor
                (after any TP handling) on this EP rank.
            megatron_module (Optional[MegatronModule]): The Megatron module containing
                configuration (used to determine E and E/S). Can be None on non-owning PP
                ranks; values will be broadcast across PP.
            hf_param_name (Optional[str]): HF parameter name template for the current
                (local) expert on this rank. The expert id within this string is replaced
                with the appropriate global expert ids for each EP rank.

        Returns:
            Dict[str, torch.Tensor]: Mapping from HF parameter names (one per EP rank)
            to the corresponding expert tensors gathered from each EP rank.

        Raises:
            ValueError: If no expert index can be derived from ``megatron_param``,
                or if ``hf_param_name`` is not among the generated names.
        """
        # Every PP rank must take part in this broadcast, so it happens before
        # any validation that could raise.
        if megatron_module is None:
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank")
        else:
            model_config = self._get_config(megatron_module)
            num_experts = model_config.num_moe_experts
            num_experts_per_rank = num_experts // self.ep_size
            num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")

        # Global expert id is the number trailing ".weight"/".bias"
        # (e.g. "...linear_fc1.weight15"); ".bias" wins if both appear.
        local_expert_number = None
        for key in (".weight", ".bias"):
            if key in self.megatron_param:
                global_expert_number = int(self.megatron_param.split(key)[-1])
                local_expert_number = global_expert_number % num_experts_per_rank
        if local_expert_number is None:
            raise ValueError(
                f"Cannot derive an expert index from megatron_param '{self.megatron_param}': "
                f"expected it to contain '.weight<N>' or '.bias<N>'."
            )

        gathered_expert_param_names = [
            re.sub(
                r"experts\.(\d+)", f"experts.{int(local_expert_number) + num_experts_per_rank * i}", str(hf_param_name)
            )
            for i in range(self.ep_size)
        ]
        if str(hf_param_name) not in gathered_expert_param_names:
            raise ValueError(
                f"Parameter name '{hf_param_name}' not found in expert parameter names. "
                f"Available names: {gathered_expert_param_names}"
            )

        gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)]
        torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group)

        # Group tensors by target name (preserving first-seen order), then stack
        # only genuine duplicates. Squeezing just the added leading dim keeps
        # real size-1 dimensions of the expert tensors intact.
        grouped: Dict[str, List[torch.Tensor]] = {}
        for param_name, weight in zip(gathered_expert_param_names, gathered_weights):
            grouped.setdefault(param_name, []).append(weight)

        return {
            param_name: tensors[0] if len(tensors) == 1 else torch.stack(tensors, dim=0)
            for param_name, tensors in grouped.items()
        }

    def maybe_dequantize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Dequantize FP8 tensor if needed.

        The base implementation returns ``tensor`` unchanged.
        """
        return tensor
class DirectMapping(MegatronParamMapping[torch.Tensor]):
    """Identity mapping: the HF tensor is used as-is on every rank.

    No reshaping, dtype/device conversion, or tensor-parallel splitting is done.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Hand the HF tensor back unchanged; ``megatron_module`` is ignored."""
        passthrough = hf_weights
        return passthrough
class ColumnParallelMapping(MegatronParamMapping[torch.Tensor]):
    """
    Mapping for column-parallel linear and embedding weights.

    Column-parallel Megatron layers shard the *output* dimension (dim 0) across
    tensor-parallel ranks, so each rank keeps a contiguous block of output rows
    and every input column. Typical users:
    - Embedding layers (vocabulary split)
    - Linear layers that produce hidden states (e.g. QKV, MLP up projections)

    **Sharding pattern**
    - Full weight: `[output_features, input_features]`
    - Each TP rank: `[output_features/tp_size, input_features]`

    **Load path (HuggingFace → Megatron), implemented by `hf_to_megatron`**
    1. On TP rank 0, cast to the target parameter's dtype (with a warning) if needed.
    2. Check that dim 0 divides evenly by `tp_size`.
    3. Chunk along dim 0 and scatter one chunk to each TP rank.

    Example:
        .. code-block:: python

            # For a weight of shape [4096, 1024] with tp_size=4:
            # Each rank gets [1024, 1024] after column-parallel split
            mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight")
            megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module)
            # megatron_weights.shape = [1024, 1024] on each rank

    Note:
        1-D bias tensors follow the same pattern, split along their only dim.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Chunk ``hf_weights`` along dim 0 on TP rank 0 and scatter the chunks.

        Raises:
            ValueError: On TP rank 0, if ``hf_weights`` is None or its dim 0 is
                not divisible by ``tp_size``.
        """
        # Without tensor parallelism the full tensor is already the local shard.
        if self.tp_size == 1:
            return hf_weights

        lookup_name = self._normalize_expert_param_name(self.megatron_param)
        _, local_param = get_module_and_param_from_name(megatron_module, lookup_name)

        shards = None
        if self.tp_rank == 0:
            if hf_weights is None:
                raise ValueError("hf_weights should not be None on rank 0")

            wanted_dtype = local_param.dtype
            if hf_weights.dtype != wanted_dtype:
                logger.warning(
                    f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. "
                    f"HF dtype: {hf_weights.dtype}. Megatron dtype: {local_param.dtype}. "
                    f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. "
                )
                hf_weights = hf_weights.to(wanted_dtype)

            rows = hf_weights.shape[0]
            if rows % self.tp_size:
                raise ValueError(f"Cannot evenly split dimension 0 size {rows} across {self.tp_size} TP ranks")
            shards = torch.chunk(hf_weights, self.tp_size, dim=0)

        # Non-source ranks pass None and only receive their shard.
        return self.scatter_to_tp_ranks(shards, local_param.shape, local_param.dtype, local_param.device)
class RowParallelMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping for **row-parallel** linear weights.

    Megatron shards row-parallel tensors along **dimension 1** (the *input*
    dimension of a linear layer).

    **Forward path (external → Megatron)**
    1. Rank 0 validates that the *second* dimension is divisible by `tp_size`.
    2. Rank 0 splits the tensor with `torch.chunk(..., dim=1)` producing
       `tp_size` equally-sized shards.
    3. The shards are **scattered** so that every TP rank receives exactly one
       shard matching the shape of its local Megatron parameter.

    1-D tensors (e.g. biases) are not split: every TP rank receives a full copy.

    **Reverse path (Megatron → external)**
    1. The local Megatron parameter (which may live on any PP rank) is
       broadcast to all PP ranks so that the gather step can be collective.
    2. All TP ranks **gather** their shard.
    3. Rank 0 concatenates the gathered list along dim 1 to reconstruct the
       original unsharded weight and emits it under the external (HF) name.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Split weight along dim 1 and distribute to TP ranks.

        Raises:
            ValueError: On TP rank 0, if ``hf_weights`` is None, is neither 1-D
                nor 2-D, or its dim 1 is not divisible by ``tp_size``.
        """
        if self.tp_size == 1:
            return hf_weights

        normalized_param = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

        if self.tp_rank == 0:
            if hf_weights is None:
                raise ValueError("hf_weights should not be None on rank 0")

            if hf_weights.ndim == 1:
                # Row-parallel biases are replicated, not sharded.
                splits = [hf_weights] * self.tp_size
            else:
                if hf_weights.ndim != 2:
                    raise ValueError(
                        f"Expected 2-dimensional weight tensor, but got {hf_weights.ndim}-dimensional "
                        f"tensor with shape {hf_weights.shape}"
                    )
                full_size = hf_weights.shape[1]
                if full_size % self.tp_size != 0:
                    raise ValueError(
                        f"Cannot evenly split dimension 1 size {full_size} across {self.tp_size} TP ranks"
                    )
                splits = torch.chunk(hf_weights, self.tp_size, dim=1)
        else:
            splits = None

        return self.scatter_to_tp_ranks(
            splits,
            target_param.shape,
            target_param.dtype,
            target_param.device,
        )
class WeightReshapeMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping that flattens an HF tensor to 2-D on load.

    Introduced for the Qwen3-VL weight conversion. The HF tensor is moved to
    the device of the target Megatron module and reshaped to
    ``(-1, hf_weights.shape[-1])``: all leading dimensions are folded into the
    first axis and the last axis is kept. No tensor-parallel splitting or
    broadcasting is performed.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Move ``hf_weights`` to the module's device and flatten it to 2-D.

        The target device is ``megatron_module.weight.device`` if the module
        has a ``weight`` attribute, otherwise the device of its first parameter.
        """
        if hasattr(megatron_module, "weight"):
            target_device = megatron_module.weight.device
        else:
            target_device = next(megatron_module.parameters()).device
        hf_weights = hf_weights.to(device=target_device)
        # ``reshape`` returns a view for contiguous inputs (same as ``view``),
        # but also handles non-contiguous ones (e.g. permuted tensors), where
        # ``view`` would raise.
        return hf_weights.reshape(-1, hf_weights.shape[-1])
class ReplicatedMapping(MegatronParamMapping[torch.Tensor]):
    """Mapping for weights that are identical on every TP rank.

    Examples: layer-norm scales, biases, router weights in MoE, etc.

    Nothing is sharded. On load (HF → Megatron), TP rank 0's copy is broadcast
    so that all TP ranks end up holding the same values.
    """

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Place ``hf_weights`` on the module's device and replicate across TP.

        Every rank must pass a tensor. Ranks other than TP rank 0 use it only
        as a shape/dtype template for the broadcast receive buffer.
        """
        target_device = (
            megatron_module.weight.device
            if hasattr(megatron_module, "weight")
            else next(megatron_module.parameters()).device
        )
        hf_weights = hf_weights.to(device=target_device)

        if self.tp_size == 1:
            return hf_weights

        # The TP broadcast operates on the current CUDA device.
        cuda_index = torch.cuda.current_device()
        if target_device.index != cuda_index:
            hf_weights = hf_weights.to(cuda_index)

        # Non-source ranks only need a correctly shaped buffer to receive into.
        if self.tp_rank > 0:
            hf_weights = torch.empty_like(hf_weights)

        return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0)
class AutoMapping(MegatronParamMapping[torch.Tensor]):
"""
Smart mapping that automatically detects and applies the correct parallelism strategy.
This mapping eliminates the need to manually specify whether a layer is
column-parallel, row-parallel, or replicated. It examines the Megatron
module at runtime and delegates to the appropriate specialized mapping.
**Detection strategy**
1. Check module class name against a registry of known types
2. If unknown, examine module attributes (tensor_model_parallel, partition_dim)
3. Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated
This abstraction is particularly useful for model-agnostic code where you
don't know the parallelism type ahead of time, or when working with models
that mix different parallelism strategies.
**Built-in module recognition**
- Column-parallel: `ColumnParallelLinear`, `VocabParallelEmbedding`, etc.
- Row-parallel: `RowParallelLinear`, `TERowParallelLinear`
- Replicated: `LayerNorm`, `RMSNorm`, and other normalization layers
**Dimension permutation**
Supports optional tensor permutation via `permute_dims` parameter. This is useful
for weights that need to be transposed or have their dimensions reordered during
conversion. The same permutation is applied in both directions (HF→Megatron and
Megatron→HF).
Example:
.. code-block:: python
# Automatically handles any weight type
mapping = AutoMapping(
megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
hf_param="model.layers.*.mlp.gate_proj.weight"
)
# Works with column-parallel layers
megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module)
# Also works with normalization layers
norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module)
# With dimension permutation (e.g., transpose)
transpose_mapping = AutoMapping(
megatron_param="vision_projection.weight",
hf_param="multi_modal_projector.weight",
permute_dims=(1, 0) # Transpose dimensions
)
# Register custom module types
AutoMapping.register_module_type("MyCustomLinear", "column")
Note:
If the parallelism type cannot be determined, the mapping will raise
a descriptive error suggesting how to fix the issue.
"""
_MODULE_TYPE_REGISTRY: Dict[str, set] = {
"column": {
"ColumnParallelLinear",
"TEColumnParallelLinear",
"TELayerNormColumnParallelLinear",
"TEColumnParallelGroupedLinear",
"VocabParallelEmbedding",
"DotProductAttention",
"TEDotProductAttention",
},
"row": {
"RowParallelLinear",
"TERowParallelLinear",
"TERowParallelGroupedLinear",
},
"replicated": {
"FSDPCheckpointWrapper",
"FSDPHead",
"TENorm",
"FusedLayerNorm",
"WrappedTorchNorm",
"LayerNorm",
"RMSNorm",
"L2Norm",
"IdentityOp",
"TopKRouter",
"Conv3d",
"Linear",
"WanDiTBlock",
"Head",
},
}
@classmethod
def register_module_type(cls, module_name: str, parallelism_type: str):
    """Add a module class name to the auto-detection registry.

    Args:
        module_name (str): Class name of the module (e.g. 'MyColumnLinear').
        parallelism_type (str): Registry bucket: 'column', 'row', or 'replicated'.

    Raises:
        ValueError: If ``parallelism_type`` is not an existing registry bucket.
    """
    registry = cls._MODULE_TYPE_REGISTRY
    if parallelism_type in registry:
        registry[parallelism_type].add(module_name)
        return
    raise ValueError(
        f"Invalid parallelism_type '{parallelism_type}'. "
        f"Must be one of {list(registry.keys())}"
    )
def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None):
    """Create an auto-detecting mapping.

    Args:
        megatron_param (str): Megatron parameter name pattern.
        hf_param (str): HuggingFace parameter name pattern.
        permute_dims (Optional[Tuple[int, ...]]): Optional axis order passed to
            ``torch.permute`` (followed by ``.contiguous()``) on incoming HF tensors.
    """
    super().__init__(megatron_param, hf_param)
    self.permute_dims = permute_dims
    # Concrete delegate and its detected type; filled in lazily on first use.
    self._mapping: Optional[MegatronParamMapping[torch.Tensor]] = None
    self._detected_type: Optional[str] = None
def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]:
    """Build a fresh concrete mapping for ``parallelism_type``.

    A new instance is constructed on every call, reusing this mapping's
    ``megatron_param`` and ``hf_param``.

    Raises:
        ValueError: If ``parallelism_type`` is not 'column', 'row' or 'replicated'.
    """
    candidates = (
        ("column", ColumnParallelMapping),
        ("row", RowParallelMapping),
        ("replicated", ReplicatedMapping),
    )
    for type_name, mapping_cls in candidates:
        if parallelism_type == type_name:
            return mapping_cls(self.megatron_param, self.hf_param)
    raise ValueError(f"Unknown parallelism type: {parallelism_type}")
def _detect_parallelism_type(self, module: nn.Module) -> str:
    """Infer how ``module``'s parameter is split across tensor-parallel ranks.

    Rules are tried in order and the first match wins:

    1. ``TELayerNormColumnParallelLinear``: a parameter name ending in
       ``layer_norm_weight`` or ``layer_norm_bias`` is "replicated".
       Every other parameter of this module type is "column".
    2. An exact class-name match in ``_MODULE_TYPE_REGISTRY``. This includes
       types added through ``register_module_type``.
    3. Modules exposing a ``tensor_model_parallel`` attribute:
       - A falsy value means "replicated".
       - Otherwise ``partition_dim`` 0 maps to "column" and 1 maps to "row".
       - Any other ``partition_dim`` falls through to the remaining rules.
    4. Class names containing "Norm" map to "replicated".
    5. ``TELinear`` follows ``module.parallel_mode`` ("column" / "row").
       Any other mode is treated as "replicated".

    Args:
        module (nn.Module): Megatron module that owns ``self.megatron_param``.

    Returns:
        str: "column", "row" or "replicated".

    Raises:
        ValueError: If no rule matches. The message suggests an explicit mapping
            or ``AutoMapping.register_module_type`` and lists the known types.
    """
    module_type = type(module).__name__
    # This module type holds both layer-norm parameters and a column-parallel
    # weight, so the decision is made per parameter name.
    if module_type == "TELayerNormColumnParallelLinear":
        if self.megatron_param and (
            self.megatron_param.endswith("layer_norm_weight") or self.megatron_param.endswith("layer_norm_bias")
        ):
            return "replicated"
        return "column"
    # Exact class-name lookup in the registry.
    for parallelism, types in self._MODULE_TYPE_REGISTRY.items():
        if module_type in types:
            return parallelism
    # Fallback for modules that carry TP attributes directly.
    if hasattr(module, "tensor_model_parallel"):
        if not module.tensor_model_parallel:
            return "replicated"
        partition_dim = getattr(module, "partition_dim", None)
        if partition_dim == 0:
            return "column"
        elif partition_dim == 1:
            return "row"
    # Heuristic: normalization layers are replicated. "Normalization" already
    # contains "Norm", so the second entry never adds a match.
    if any(norm in module_type for norm in ["Norm", "Normalization"]):
        return "replicated"
    # TELinear is not in the registry; its mode attribute decides.
    if module_type == "TELinear":
        if module.parallel_mode == "column":
            return "column"
        elif module.parallel_mode == "row":
            return "row"
        else:
            return "replicated"
    # Nothing matched: include the registry contents to help the user register the type.
    known_types = {p: sorted(list(t)) for p, t in self._MODULE_TYPE_REGISTRY.items()}
    raise ValueError(
        f"Cannot determine parallelism type for module '{module_type}' "
        f"at weight '{self.megatron_param}'.\n"
        f"Please use an explicit mapping type (e.g., ColumnParallelMapping) "
        f"or register the module type using:\n"
        f"  AutoMapping.register_module_type('{module_type}', 'column|row|replicated')\n\n"
        f"Currently known module types:\n{json.dumps(known_types, indent=2)}"
    )
def hf_to_megatron(
    self,
    hf_weights: torch.Tensor,
    megatron_module: nn.Module,
) -> torch.Tensor:
    """Convert ``hf_weights`` by delegating to a column/row/replicated mapping.

    If ``permute_dims`` is set, TP rank 0 first permutes its input and makes it
    contiguous; other ranks forward their input unchanged.

    The concrete mapping is detected from ``megatron_module`` on the first call.
    It is then cached in ``self._mapping`` and reused on later calls.

    Args:
        hf_weights (torch.Tensor): HuggingFace tensor for this parameter.
        megatron_module (nn.Module): Megatron module owning the target parameter.

    Returns:
        torch.Tensor: Whatever the delegated mapping's ``hf_to_megatron`` returns.
    """
    needs_permute = self.permute_dims is not None and self.tp_rank == 0
    if needs_permute:
        hf_weights = torch.permute(hf_weights, self.permute_dims).contiguous()

    if self._mapping is None:
        detected = self._detect_parallelism_type(megatron_module)
        self._detected_type = detected
        self._mapping = self._get_or_create_mapping(detected)

    delegate = self._mapping
    return delegate.hf_to_megatron(hf_weights, megatron_module)
def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
    """Return a copy of this mapping with wildcards filled in from ``captures``.

    The copy has the same concrete type and carries over ``permute_dims``.
    """
    megatron_name, hf_name = self._resolve_names(captures)
    mapping_cls = type(self)
    return mapping_cls(megatron_name, hf_name, self.permute_dims)
class QKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping between separate HF Q/K/V projections and Megatron's fused QKV tensor.

    The HF side is a dict with keys ``"q"``, ``"k"`` and ``"v"``. On TP rank 0
    the three tensors are merged, using the attention config of the target
    module, with one of two helpers:

    - ``merge_qkv_biases`` for 1-D inputs.
    - ``merge_qkv_weights`` for all other inputs.

    Distributing the merged tensor across TP ranks is delegated to an
    ``AutoMapping``.
    """

    def __init__(self, megatron_param: str, q: str, k: str, v: str):
        """Initialize QKV mapping.

        Args:
            megatron_param (str): Megatron QKV parameter name pattern.
            q (str): Query weight name pattern.
            k (str): Key weight name pattern.
            v (str): Value weight name pattern.
        """
        super().__init__(megatron_param, {"q": q, "k": k, "v": v})
        # Distributes the already-merged tensor; both names are the Megatron one.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Merge Q, K and V on TP rank 0, then hand off to the TP mapping.

        Non-zero ranks pass ``None`` to the TP mapping.
        """
        merged = None
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            q, k, v = hf_weights["q"], hf_weights["k"], hf_weights["v"]
            merge_fn = merge_qkv_biases if q.ndim == 1 else merge_qkv_weights
            merged = merge_fn(config, q, k, v)
        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new QKVMapping with wildcards in all four names resolved."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        hf_names = (resolved_hf_param[key] for key in ("q", "k", "v"))
        return type(self)(resolved_megatron_param, *hf_names)
class KVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    """
    Mapping for interleaved Key/Value projection weights.

    This mapping converts between separate K and V tensors used in external
    checkpoints and Megatron's interleaved KV format following grouped-query
    attention semantics.

    External format (HF)
    - Separate tensors: k_proj, v_proj
    - Shapes mirror QKV mappings but without Q

    Megatron format
    - Single interleaved tensor with order: [k1, v1, k2, v2, ...]
      where index corresponds to query-group id

    Tensor-parallel distribution is delegated to AutoMapping.
    """

    def __init__(self, megatron_param: str, k: str, v: str):
        """Initialize KV mapping.

        Args:
            megatron_param (str): Megatron interleaved KV parameter name pattern.
            k (str): Key weight name pattern.
            v (str): Value weight name pattern.
        """
        super().__init__(megatron_param, {"k": k, "v": v})
        # Distributes the already-merged tensor; both names are the Megatron one.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Merge K and V into interleaved format and distribute across TP.

        Only TP rank 0 reads ``hf_weights``. Other ranks pass ``None`` to the
        TP mapping.
        """
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            if hf_weights["k"].ndim == 1:
                merged = merge_kv_biases(config, hf_weights["k"], hf_weights["v"])
            else:
                merged = merge_kv_weights(config, hf_weights["k"], hf_weights["v"])
        else:
            merged = None
        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* KVMapping instance.

        ``hf_param`` is stored as a dict ({"k": ..., "v": ...}), but ``__init__``
        takes the K and V patterns as separate arguments. The resolved dict is
        therefore unpacked explicitly, as QKVMapping.resolve and
        GatedMLPMapping.resolve do.
        """
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        return type(self)(
            resolved_megatron_param,
            resolved_hf_param["k"],
            resolved_hf_param["v"],
        )
class ConcatenatedQKVMapping(MegatronParamMapping[torch.Tensor]):
    """
    Mapping for a single concatenated Query/Key/Value attention projection tensor.

    This mapping handles the conversion between the concatenated ``[Q; K; V]``
    tensor used in some transformers models and Megatron's optimized interleaved
    format. The interleaving pattern groups queries with their corresponding
    key-value pairs to maximize GEMM efficiency during attention computation.

    **External format (HuggingFace)**
    - One tensor with query, key and value concatenated along dim 0: ``qkv``.
    - Weights have shape
      ``[head_dim * num_heads + 2 * head_dim * num_query_groups, hidden_size]``.
    - Biases are 1-D with the same leading size.

    **Megatron format**
    - Single interleaved tensor following grouped query attention (GQA) pattern
    - Interleaving order: ``[q1...qn, k1, v1, q1...qn, k2, v2, ...]``
    - Where ``n = num_attention_heads / num_query_groups``

    **Key features**
    1. Format conversion: Handles merging/splitting with proper interleaving
    2. Grouped Query Attention: Supports different numbers of Q and KV heads
    3. Tensor parallelism: Delegates to AutoMapping for distribution

    Example:
        .. code-block:: python

            # Create mapping for attention weights
            mapping = ConcatenatedQKVMapping(
                megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
                hf_param="model.layers.*.self_attn.qkv.weight",
            )

            # Convert from HuggingFace to Megatron
            megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module)

            # Convert from Megatron to HuggingFace
            hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module)

    Note:
        This mapping handles both regular multi-head attention (same number of
        Q, K, V heads) and grouped query attention (fewer KV heads than Q heads).
        The choice follows ``num_query_groups`` in the model configuration.
    """

    def __init__(self, megatron_param: str, hf_param: str):
        """Initialize concatenated QKV mapping.

        Args:
            megatron_param (str): Megatron interleaved QKV parameter name pattern.
            hf_param (str): HF concatenated QKV parameter name pattern.
        """
        super().__init__(megatron_param, hf_param)
        # Distributes the already-merged tensor; both names are the Megatron one.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Split the concatenated tensor into Q, K, V, interleave them and distribute.

        Only TP rank 0 reads ``hf_weights``. Other ranks pass ``None`` to the
        TP mapping.
        """
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            head_num = config.num_attention_heads
            # Same fallback as merge_qkv_weights / merge_qkv_biases when kv_channels is unset.
            head_size = config.kv_channels or (config.hidden_size // head_num)
            num_query_groups = config.num_query_groups
            # NOTE(review): the split sizes do not account for attention_output_gate
            # (a doubled query block); confirm the HF layout before using this with
            # gated attention.
            q, k, v = hf_weights.split(
                [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
            )
            if q.ndim == 1:
                merged = merge_qkv_biases(config, q, k, v)
            else:
                merged = merge_qkv_weights(config, q, k, v)
        else:
            merged = None
        return self._tp_mapping.hf_to_megatron(merged, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new *resolved* ConcatenatedQKVMapping instance."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        return type(self)(resolved_megatron_param, resolved_hf_param)
class VisionEncoderQKVMapping(MegatronParamMapping[torch.Tensor]):
    """
    Mapping for vision encoder (ViT/ViT-like) attention projection weights.

    This mapping converts HuggingFace's fused vision-encoder QKV tensor into
    Megatron's per-head QKV layout.

    **HuggingFace format (vision encoder)**
    - Weights: Q, K and V projections stacked in that order along the leading
      dimension. They are reshaped internally to
      ``[3, num_heads, -1, head_dim, hidden_size]``.
    - Biases: a 1-D tensor, reshaped internally to ``[3, num_heads, -1]``.

    **Megatron format**
    - Weights: ``[num_heads * 3 * head_dim, hidden_size]``, ordered per head as
      ``[q_h; k_h; v_h]``.
    - Biases: a flat ``[num_heads * 3 * head_dim]`` tensor in the same order.

    Note:
        This is specifically for vision encoders (like ViT). It uses different
        reshaping logic from the text transformer QKV mappings. ``head_dim`` is
        derived as ``hidden_size // num_attention_heads``.
    """

    def __init__(self, megatron_param: str, hf_param: str):
        """Initialize vision encoder QKV mapping.

        Args:
            megatron_param (str): Megatron vision encoder QKV parameter name pattern.
            hf_param (str): HF vision encoder QKV parameter name pattern.
        """
        super().__init__(megatron_param, hf_param)
        # Distributes the reordered tensor; both names are the Megatron one.
        self._tp_mapping = AutoMapping(megatron_param, megatron_param)

    def hf_to_megatron(
        self,
        hf_weights: torch.Tensor,
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Convert HF vision encoder QKV weights to Megatron format and distribute.

        Only TP rank 0 reads ``hf_weights``. Other ranks pass ``None`` to the
        TP mapping.
        """
        if self.tp_rank == 0:
            config = self._get_config(megatron_module)
            vision_num_heads = config.num_attention_heads
            vision_hidden_size = config.hidden_size
            # NOTE(review): head_dim comes from hidden_size, not config.kv_channels.
            vision_head_dim = vision_hidden_size // vision_num_heads
            is_bias = hf_weights.ndim == 1
            if is_bias:
                in_shape = (3, vision_num_heads, -1)
            else:
                in_shape = (3, vision_num_heads, -1, vision_head_dim, vision_hidden_size)
            # reshape (rather than view) also accepts non-contiguous inputs. The
            # leading axis of size 3 unpacks into Q, K and V.
            q, k, v = hf_weights.reshape(*in_shape)
            q = q.view(vision_num_heads, vision_head_dim, -1)
            k = k.view(vision_num_heads, vision_head_dim, -1)
            v = v.view(vision_num_heads, vision_head_dim, -1)
            # Concatenate along the per-head axis so each head becomes [q_h; k_h; v_h].
            if is_bias:
                qkv = torch.cat([q, k, v], dim=1).flatten()
            else:
                qkv = torch.cat([q, k, v], dim=1).reshape(-1, vision_hidden_size)
        else:
            qkv = None
        return self._tp_mapping.hf_to_megatron(qkv, megatron_module)

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new resolved VisionEncoderQKVMapping instance."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        return type(self)(resolved_megatron_param, resolved_hf_param)
class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
    r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU).

    External checkpoints store two separate matrices:

    - **G**: the gate projection.
    - **U**: the up projection.

    Megatron keeps them stacked row-wise as ``[G; U]``, so a single GEMM yields
    both activations.

    **Tensor-parallel layout**
    A plain row split of ``[G; U]`` would give some ranks only gate rows. Instead:

    - G and U are each chunked into ``tp_size`` pieces along dim 0.
    - Rank ``i`` receives ``[G_i; U_i]``, so every rank's local tensor has the
      same gate-then-up structure as the unsharded one.
    """

    def __init__(self, megatron_param: str, gate: str, up: str):
        """Initialize gated MLP mapping.

        Args:
            megatron_param (str): Megatron MLP parameter name pattern.
            gate (str): Gate projection weight name pattern.
            up (str): Up projection weight name pattern.
        """
        super().__init__(megatron_param, {"gate": gate, "up": up})

    def hf_to_megatron(
        self,
        hf_weights: Dict[str, torch.Tensor],
        megatron_module: nn.Module,
    ) -> torch.Tensor:
        """Produce this rank's ``[gate_shard; up_shard]`` tensor.

        With a single TP rank the full ``[gate; up]`` concatenation is returned
        directly. Otherwise TP rank 0 builds one shard per rank and
        ``scatter_to_tp_ranks`` delivers them. The target parameter's shape,
        dtype and device describe the receiving buffer.

        Raises:
            ValueError: On TP rank 0, if gate and up shapes differ or dim 0 is
                not divisible by the TP size.
        """
        tp_size = self.tp_size
        if tp_size == 1:
            return torch.cat([hf_weights["gate"], hf_weights["up"]], dim=0)

        normalized_param = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

        splits = None
        if self.tp_rank == 0:
            gate, up = hf_weights["gate"], hf_weights["up"]
            if gate.shape != up.shape:
                raise ValueError(
                    f"Gate and up weights must have the same shape. Gate shape: {gate.shape}, Up shape: {up.shape}")
            gate_output_size = gate.shape[0]
            if gate_output_size % tp_size != 0:
                raise ValueError(
                    f"Cannot evenly split gate dimension 0 size {gate_output_size} across {tp_size} TP ranks"
                )
            gate_chunks = torch.chunk(gate, tp_size, dim=0)
            up_chunks = torch.chunk(up, tp_size, dim=0)
            splits = [torch.cat([gate_chunks[rank], up_chunks[rank]], dim=0) for rank in range(tp_size)]

        return self.scatter_to_tp_ranks(
            splits,
            target_param.shape,
            target_param.dtype,
            target_param.device,
        )

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Return a new GatedMLPMapping with wildcards in all three names resolved."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        gate_name = resolved_hf_param["gate"]
        up_name = resolved_hf_param["up"]
        return type(self)(resolved_megatron_param, gate_name, up_name)
class RMSNorm2ZeroCenteredRMSNormMapping(AutoMapping):
    """
    Mapping from a standard RMSNorm weight (HF side) to a zero-centered RMSNorm weight (Megatron side).

    In the HF -> Megatron direction the weight is shifted by -1. The regular
    ``AutoMapping`` handling (optional permutation, TP distribution) then runs
    as usual.
    """

    def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor:
        """Shift ``hf_weights`` by -1 and delegate to ``AutoMapping.hf_to_megatron``.

        A new tensor is produced and the caller's ``hf_weights`` is not modified.
        ``None`` inputs are forwarded unchanged.
        """
        if hf_weights is not None:
            # Out-of-place: assigning to ``hf_weights.data`` would rewrite the caller's
            # tensor, and converting the same tensor again would shift it twice.
            # detach() keeps the subtraction out of autograd, like the ``.data`` access did.
            hf_weights = hf_weights.detach() - 1
        return super().hf_to_megatron(hf_weights, megatron_module)
def merge_qkv_biases(config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Interleave separate Q, K, V bias vectors into Megatron's fused QKV bias.

    For each query group ``g`` the output contains, in order:

    1. That group's query-head biases.
    2. Its gate-head biases, only when ``config.attention_output_gate`` is set.
    3. Key bias ``g``.
    4. Value bias ``g``.

    Args:
        config (TransformerConfig): Transformer configuration. ``head_size`` is
            ``config.kv_channels``, or ``hidden_size // num_attention_heads``
            when that is unset.
        q (torch.Tensor): Query biases, ``num_attention_heads * head_size``
            elements. When gated, each head holds ``2 * head_size`` elements:
            query first, then gate.
        k (torch.Tensor): Key biases, ``num_query_groups * head_size`` elements.
        v (torch.Tensor): Value biases, same size as ``k``.

    Returns:
        torch.Tensor: Interleaved QKV biases in Megatron format as a 1-D tensor.
    """
    n_heads = config.num_attention_heads
    n_groups = config.num_query_groups
    per_group = n_heads // n_groups
    head_dim = config.kv_channels or (config.hidden_size // n_heads)
    gated = getattr(config, "attention_output_gate", False)

    if gated:
        q_rows, z_rows = torch.chunk(q.view(n_heads, head_dim * 2), 2, dim=-1)
    else:
        q_rows = q.view(n_heads, head_dim)
    k_rows = k.view(n_groups, head_dim)
    v_rows = v.view(n_groups, head_dim)

    chunks = []
    for g in range(n_groups):
        heads = slice(g * per_group, (g + 1) * per_group)
        chunks.append(q_rows[heads])
        if gated:
            chunks.append(z_rows[heads])
        chunks.extend((k_rows[g:g + 1], v_rows[g:g + 1]))
    return torch.cat(chunks).flatten()
def merge_qkv_weights(provider: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Interleave separate Q, K, V projections into Megatron's fused QKV layout.

    Accepts either 2-D weights (``[rows, hidden_size]``) or 1-D biases. For each
    query group ``g`` the output contains, in order:

    1. That group's query heads.
    2. Its gate heads, only when ``provider.attention_output_gate`` is set.
    3. Key head ``g``.
    4. Value head ``g``.

    Args:
        provider (TransformerConfig): Model configuration. ``head_size`` is
            ``kv_channels``, or ``hidden_size // num_attention_heads`` when unset.
        q (torch.Tensor): Query weights with ``num_attention_heads * head_size``
            rows (twice that when gated), or the matching 1-D bias.
        k (torch.Tensor): Key weights with ``num_query_groups * head_size`` rows,
            or the matching 1-D bias.
        v (torch.Tensor): Value weights/bias, same shape as ``k``.

    Returns:
        torch.Tensor: A 1-D merged bias, or merged weights of shape ``[-1, hidden_size]``.

    Raises:
        AssertionError: If the merged element count differs from the inputs',
            e.g. when ``num_attention_heads`` is not a multiple of ``num_query_groups``.
    """
    n_heads = provider.num_attention_heads
    n_groups = provider.num_query_groups
    heads_per_group = n_heads // n_groups
    head_size = provider.kv_channels or (provider.hidden_size // n_heads)
    hidden_size = provider.hidden_size
    gated = getattr(provider, "attention_output_gate", False)
    is_bias = q.ndim == 1

    # Biases have no input axis; weights keep hidden_size as the trailing axis.
    trailing = () if is_bias else (hidden_size,)
    q_heads = q.view(n_heads, 2 * head_size if gated else head_size, *trailing)
    k_heads = k.view(n_groups, head_size, *trailing)
    v_heads = v.view(n_groups, head_size, *trailing)
    if gated:
        # Each gated query head stores [query; gate] along axis 1.
        q_heads, z_heads = torch.chunk(q_heads, 2, dim=1)

    pieces = []
    for group in range(n_groups):
        head_slice = slice(group * heads_per_group, (group + 1) * heads_per_group)
        pieces.append(q_heads[head_slice])
        if gated:
            pieces.append(z_heads[head_slice])
        pieces += [k_heads[group:group + 1], v_heads[group:group + 1]]
    qkv = torch.cat(pieces, dim=0)

    # Sanity check: every input element must land in the merged tensor exactly once.
    if q.numel() + k.numel() + v.numel() != qkv.numel():
        raise AssertionError(
            f"QKV weights are not correctly merged, {q.shape=}, {k.shape=}, {v.shape=}, {qkv.shape=}"
        )
    return qkv.reshape(-1) if is_bias else qkv.reshape([-1, hidden_size])
def merge_kv_biases(config: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Interleave K and V biases per query group into one 1-D tensor.

    The result is ordered ``[k_0, v_0, k_1, v_1, ...]``, one ``head_size`` slice
    per entry. ``head_size`` is ``config.kv_channels``, or
    ``hidden_size // num_attention_heads`` when that is unset.
    """
    groups = config.num_query_groups
    head_dim = config.kv_channels or (config.hidden_size // config.num_attention_heads)
    k_rows = k.view(groups, head_dim)
    v_rows = v.view(groups, head_dim)
    interleaved = torch.cat(
        [row for g in range(groups) for row in (k_rows[g:g + 1, :], v_rows[g:g + 1, :])],
        dim=0,
    )
    return interleaved.reshape(-1)
def merge_kv_weights(provider: TransformerConfig, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Interleave K and V weight rows per query group into one 2-D tensor.

    Each of ``k`` and ``v`` is viewed as ``[num_query_groups, head_size, hidden_size]``.
    The output stacks the groups as ``[k_0; v_0; k_1; v_1; ...]`` with shape
    ``[-1, hidden_size]``. ``head_size`` is ``provider.kv_channels``, or
    ``hidden_size // num_attention_heads`` when that is unset.
    """
    groups = provider.num_query_groups
    head_dim = provider.kv_channels or (provider.hidden_size // provider.num_attention_heads)
    hidden = provider.hidden_size
    k_heads = k.view(groups, head_dim, hidden)
    v_heads = v.view(groups, head_dim, hidden)
    interleaved = torch.cat(
        [block for g in range(groups) for block in (k_heads[g:g + 1], v_heads[g:g + 1])],
        dim=0,
    )
    return interleaved.view(-1, hidden)