import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F

import megatron
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.attention import SelfAttention, Attention, AttnMaskType, SelfAttentionSubmodules
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

from mindspeed_mm.models.vision.vision_encoders.qwen2vl_vit_model import Qwen2vlSelfAttention, Qwen2vlVitSelfAttention
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
@dataclass
class SplitQKVSelfAttentionSubmodules(SelfAttentionSubmodules):
    """Self-attention submodule specs extended with separate Q/K/V projections.

    Adds three projection fields on top of ``SelfAttentionSubmodules``.
    ``PatchSplitQKVSelfAttention`` passes each of them to ``build_module``
    to create its ``q_proj``, ``k_proj`` and ``v_proj`` layers.
    """

    # Spec (or class) used to build the query projection.
    q_proj: Union[ModuleSpec, type] = None
    # Spec (or class) used to build the key projection.
    k_proj: Union[ModuleSpec, type] = None
    # Spec (or class) used to build the value projection.
    v_proj: Union[ModuleSpec, type] = None
@dataclass
class SplitUpGateMLPSubmodules:
    """Submodule specs for an MLP with independent gate and up projections.

    ``PatchSplitGateUpMLP`` passes each field to ``build_module`` to create
    the layer of the same name.
    """

    # Spec (or class) used to build the gate projection.
    gate_proj: Union[ModuleSpec, type] = None
    # Spec (or class) used to build the up projection.
    up_proj: Union[ModuleSpec, type] = None
    # Spec (or class) used to build the output (second) linear layer.
    linear_fc2: Union[ModuleSpec, type] = None
class PatchSplitQKVSelfAttention(Qwen2vlSelfAttention):
    """Self-attention layer that projects Q, K and V with three separate linears.

    Only the QKV projection logic of the parent class is replaced: the fused
    ``linear_qkv`` created by the parent (if any) is removed and ``q_proj``,
    ``k_proj`` and ``v_proj`` are built from the given submodule specs.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SplitQKVSelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        """Build the parent layer, then swap the fused projection for split ones.

        Args:
            config: Transformer configuration shared by all sub-layers.
            submodules: Specs for the split projections and optional Q/K layernorms.
            layer_number: Index of this layer, forwarded to the parent class.
            attn_mask_type: Attention mask type, forwarded to the parent class.
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
        )

        # The fused projection is replaced by the three projections below.
        if hasattr(self, 'linear_qkv'):
            del self.linear_qkv

        use_bias = self.config.add_bias_linear or self.config.add_qkv_bias

        # (attribute name, spec, output width, TP communication buffer name);
        # built in q, k, v order.
        projection_layout = (
            ('q_proj', submodules.q_proj, self.query_projection_size, 'q'),
            ('k_proj', submodules.k_proj, self.kv_projection_size, 'k'),
            ('v_proj', submodules.v_proj, self.kv_projection_size, 'v'),
        )
        for attr_name, proj_spec, output_size, buffer_name in projection_layout:
            projection = build_module(
                proj_spec,
                self.config.hidden_size,
                output_size,
                config=self.config,
                init_method=self.config.init_method,
                gather_output=False,
                bias=use_bias,
                skip_bias_add=False,
                is_expert=False,
                tp_comm_buffer_name=buffer_name,
            )
            setattr(self, attr_name, projection)

        # Optional per-head layernorm applied to the query.
        self.q_layernorm = (
            build_module(
                submodules.q_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
            if submodules.q_layernorm is not None
            else None
        )

        # Optional per-head layernorm applied to the key.
        self.k_layernorm = (
            build_module(
                submodules.k_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
            if submodules.k_layernorm is not None
            else None
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """Compute query, key and value from the three independent projections.

        Args:
            hidden_states: Input passed to ``q_proj``, ``k_proj`` and ``v_proj``.
            key_value_states: Must be ``None``; cross-attention inputs are rejected.

        Returns:
            Tuple ``(query, key, value)`` whose last dimension has been split
            into ``(heads_or_groups, hidden_size_per_attention_head)``.

        Raises:
            ValueError: If ``key_value_states`` is provided.
        """
        if key_value_states is not None:
            raise ValueError("Self-Attention does not support key_value_states")

        # Each projection returns (output, bias); the bias slot is not used here.
        query, _ = self.q_proj(hidden_states)
        key, _ = self.k_proj(hidden_states)
        value, _ = self.v_proj(hidden_states)

        head_dim = self.hidden_size_per_attention_head
        num_groups = self.num_query_groups_per_partition
        heads_per_group = self.num_attention_heads_per_partition // num_groups

        query_shape = tuple(query.size()[:-1]) + (num_groups * heads_per_group, head_dim)
        # Key and value share one target shape, derived from the key projection.
        kv_shape = tuple(key.size()[:-1]) + (num_groups, head_dim)

        query = query.view(*query_shape)
        key = key.view(*kv_shape)
        value = value.view(*kv_shape)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)
        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value
class PatchSplitGateUpMLP(MegatronModule):
    """MLP whose first linear layer is split into ``gate_proj`` and ``up_proj``.

    The two projections are applied to the same input; their outputs are
    combined by the configured activation (gated or not) and then fed to
    ``linear_fc2``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SplitUpGateMLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
    ):
        """Build ``gate_proj``, ``up_proj`` and ``linear_fc2`` from their specs.

        Args:
            config: Transformer configuration.
            submodules: Specs for the three linear layers.
            is_expert: Whether this MLP is an MoE expert; experts use
                ``config.moe_ffn_hidden_size`` when it is set.
            input_size: Input width of the projections; falls back to
                ``config.hidden_size`` when not given.
        """
        super().__init__(config=config)
        self.config: TransformerConfig = config
        self.input_size = input_size if input_size else self.config.hidden_size

        if is_expert and self.config.moe_ffn_hidden_size:
            ffn_hidden_size = self.config.moe_ffn_hidden_size
        else:
            ffn_hidden_size = self.config.ffn_hidden_size

        # linear_fc2 consumes the (un-doubled) FFN width selected above, so an
        # expert with its own moe_ffn_hidden_size gets a matching input size.
        fc2_input_size = ffn_hidden_size

        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2
        # NOTE(review): without gated_linear_unit each projection is only
        # ffn_hidden_size // 2 wide while linear_fc2 expects the full width and
        # forward() feeds it up_proj alone - confirm the intended non-gated layout.
        split_hidden_size = int(ffn_hidden_size // 2)

        self.activation_func = self.config.activation_func

        self.gate_proj = build_module(
            submodules.gate_proj,
            self.input_size,
            split_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='gate',
        )
        self.up_proj = build_module(
            submodules.up_proj,
            self.input_size,
            split_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='up',
        )
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            fc2_input_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states, per_token_scale=None):
        """Run the MLP using the independent gate and up projections.

        Args:
            hidden_states: Input passed to both ``gate_proj`` and ``up_proj``.
            per_token_scale: Optional scale; ``unsqueeze(-1)`` of it is
                multiplied into the activation output.

        Returns:
            Tuple ``(output, output_bias)`` as returned by ``linear_fc2``.

        Raises:
            ValueError: If the requested fused activation is not supported, or
                if ``linear_fc2`` returns a bias while ``per_token_scale`` is set.
        """
        gate_parallel, gate_bias = self.gate_proj(hidden_states)
        up_parallel, up_bias = self.up_proj(hidden_states)

        # Bias laid out like the concatenated [gate, up] activations below.
        if self.config.add_bias_linear and gate_bias is not None:
            bias_parallel = torch.cat([gate_bias, up_bias], dim=-1)
        else:
            bias_parallel = None

        if self.config.bias_activation_fusion:
            if per_token_scale is not None:
                if self.activation_func == F.silu and self.config.gated_linear_unit:
                    intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                    intermediate_parallel = weighted_bias_swiglu_impl(
                        intermediate_combined,
                        bias_parallel,
                        per_token_scale.unsqueeze(-1),
                        self.config.activation_func_fp8_input_store,
                    )
                else:
                    raise ValueError("Only support fusion of swiglu with per_token_scale in MLP.")
            else:
                if self.activation_func == F.gelu:
                    if self.config.gated_linear_unit:
                        intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                        intermediate_parallel = bias_geglu_impl(intermediate_combined, bias_parallel)
                    elif self.config.add_bias_linear:
                        intermediate_parallel = bias_gelu_impl(up_parallel, up_bias)
                    else:
                        raise ValueError("Only support gated_linear_unit or add_bias_linear in gelu.")
                elif self.activation_func == F.silu and self.config.gated_linear_unit:
                    intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                    intermediate_parallel = bias_swiglu_impl(
                        intermediate_combined,
                        bias_parallel,
                        self.config.activation_func_fp8_input_store,
                    )
                else:
                    raise ValueError("Only support fusion of gelu and swiglu")
        else:
            # Unfused path: the projections skipped the bias add, so apply it here.
            if self.config.add_bias_linear and gate_bias is not None:
                gate_parallel = gate_parallel + gate_bias
                up_parallel = up_parallel + up_bias

            if self.config.gated_linear_unit:
                intermediate_parallel = self.activation_func(gate_parallel) * up_parallel
            else:
                intermediate_parallel = self.activation_func(up_parallel)

            if per_token_scale is not None:
                # Scaling may promote the dtype; cast back to the activation dtype.
                original_dtype = intermediate_parallel.dtype
                intermediate_parallel = intermediate_parallel * per_token_scale.unsqueeze(-1)
                intermediate_parallel = intermediate_parallel.to(original_dtype)

        output, output_bias = self.linear_fc2(intermediate_parallel)

        if per_token_scale is not None and output_bias is not None:
            raise ValueError("Bias is not supported with per_token_scale")

        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """Collect the sharded state dicts of all direct submodules.

        Every submodule (``gate_proj``, ``up_proj``, ``linear_fc2``) is handled
        the same way: its own ``sharded_state_dict`` is called with the prefix
        extended by the submodule name.

        Args:
            prefix: Key prefix for this module.
            sharded_offsets: Forwarded unchanged to each submodule.
            metadata: Forwarded unchanged to each submodule.

        Returns:
            The merged sharded state dict of all submodules.
        """
        sharded_state_dict = {}
        for name, module in self._modules.items():
            sharded_state_dict.update(
                module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            )
        return sharded_state_dict
class PatchViTSelfAttention(Qwen2vlVitSelfAttention):
    """ViT self-attention that reads a non-interleaved fused QKV projection.

    Only the QKV projection logic of the parent class is replaced: the output
    of ``linear_qkv`` is interpreted as ``[Q | K | V]`` blocks rather than an
    interleaved per-head layout.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding
    ):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """Derive `query`, `key` and `value` from `hidden_states` using non-interleaved weight.

        Args:
            hidden_states: Input passed to ``linear_qkv``.
            key_value_states: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(query, key, value)``, each shaped
            ``(*leading_dims, heads_per_partition, head_dim)``.
        """
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # Take the leading dims from the projection output rather than from
        # hidden_states: they can differ if the linear layer changes the
        # sequence length (e.g. gathers a sequence-parallel input).
        leading_dims = tuple(mixed_qkv.shape[:-1])

        # Last dim is laid out as [3, heads, head_dim]; split along the "3" axis.
        query, key, value = mixed_qkv.reshape(
            *leading_dims, 3, self.num_attention_heads_per_partition, -1
        ).unbind(dim=-3)

        query = query.reshape(*leading_dims, -1, self.hidden_size_per_attention_head)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)
        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value
def _patch_get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
    """Deprecated alias of ``get_patch_mlp_module_spec``.

    Emits a warning and forwards every argument unchanged to
    ``get_patch_mlp_module_spec``, returning its result.
    """
    # The message names the function this alias actually delegates to;
    # stacklevel=2 attributes the warning to the caller of this alias.
    warnings.warn(
        "This private function is on a deprecation track. Please switch to "
        "`get_patch_mlp_module_spec` since it will be removed in a future release.",
        stacklevel=2,
    )
    return get_patch_mlp_module_spec(
        use_te=use_te,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        fp8=fp8,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )
def get_patch_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Return the module spec for the split gate/up MLP, or for MoE.

    Args:
        use_te: Use Transformer Engine linear layers for the dense MLP.
        num_experts: When not ``None``, the spec is delegated to
            ``get_moe_module_spec`` instead of building a dense MLP spec.
        moe_grouped_gemm: Forwarded to ``get_moe_module_spec``.
        fp8: Deprecated; only triggers a warning when set.
        moe_use_legacy_grouped_gemm: Forwarded to ``get_moe_module_spec``.

    Returns:
        A ``ModuleSpec`` for ``PatchSplitGateUpMLP`` (dense) or the MoE spec.

    Raises:
        ImportError: If a dense spec with ``use_te`` is requested but
            Transformer Engine could not be imported.
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_patch_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is not None:
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )

    if use_te:
        # The TE classes only exist when the guarded import succeeded; fail
        # with a clear message instead of a NameError on the names below.
        if not HAVE_TE:
            raise ImportError(
                "Transformer Engine is not installed; call get_patch_mlp_module_spec "
                "with use_te=False or install transformer_engine."
            )
        column_linear = TELayerNormColumnParallelLinear
        row_linear = TERowParallelLinear
    else:
        column_linear = ColumnParallelLinear
        row_linear = RowParallelLinear

    return ModuleSpec(
        module=PatchSplitGateUpMLP,
        submodules=SplitUpGateMLPSubmodules(
            gate_proj=column_linear,
            up_proj=column_linear,
            linear_fc2=row_linear,
        ),
    )