from typing import Optional, Union
import warnings
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import megatron
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.attention import SelfAttention, Attention, AttnMaskType, SelfAttentionSubmodules
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from mindspeed_mm.models.vision.vision_encoders.qwen2vl_vit_model import Qwen2vlSelfAttention, Qwen2vlVitSelfAttention
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
@dataclass
class SplitQKVSelfAttentionSubmodules(SelfAttentionSubmodules):
    """Self-attention submodule specs with separate Q, K and V projections.

    Adds three fields on top of ``SelfAttentionSubmodules``. Each one is
    handed to ``build_module`` by ``PatchSplitQKVSelfAttention`` to create
    the corresponding projection layer.
    """

    # Spec (or class) used to build the query projection.
    q_proj: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) used to build the key projection.
    k_proj: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) used to build the value projection.
    v_proj: Optional[Union[ModuleSpec, type]] = None
@dataclass
class SplitUpGateMLPSubmodules:
    """Submodule specs for an MLP with independent gate and up projections.

    Each field is handed to ``build_module`` by ``PatchSplitGateUpMLP`` to
    create the layer of the same name.
    """

    # Spec (or class) used to build the gate projection.
    gate_proj: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) used to build the up projection.
    up_proj: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) used to build the output (down) projection.
    linear_fc2: Optional[Union[ModuleSpec, type]] = None
class PatchSplitQKVSelfAttention(Qwen2vlSelfAttention):
    """Self-attention layer that uses three separate Q/K/V projection modules.

    Only the QKV-projection logic of the parent class is replaced: the
    constructor swaps the parent's ``linear_qkv`` (if present) for
    ``q_proj`` / ``k_proj`` / ``v_proj``, and
    ``get_query_key_value_tensors`` is rewritten to use them.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SplitQKVSelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        """Build the parent layer, then replace its QKV projection.

        Args:
            config: Transformer configuration shared by all submodules.
            submodules: Specs for the projections and optional Q/K layernorms.
            layer_number: Forwarded unchanged to the parent constructor.
            attn_mask_type: Forwarded unchanged to the parent constructor.
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
        )

        # The projection created by the parent is not used by this class.
        if hasattr(self, 'linear_qkv'):
            del self.linear_qkv

        use_bias = self.config.add_bias_linear or self.config.add_qkv_bias

        # (attribute name, module spec, output width, TP comm buffer name);
        # built in q, k, v order.
        projection_plan = (
            ('q_proj', submodules.q_proj, self.query_projection_size, 'q'),
            ('k_proj', submodules.k_proj, self.kv_projection_size, 'k'),
            ('v_proj', submodules.v_proj, self.kv_projection_size, 'v'),
        )
        for attr_name, proj_spec, out_features, buffer_name in projection_plan:
            projection = build_module(
                proj_spec,
                self.config.hidden_size,
                out_features,
                config=self.config,
                init_method=self.config.init_method,
                gather_output=False,
                bias=use_bias,
                skip_bias_add=False,
                is_expert=False,
                tp_comm_buffer_name=buffer_name,
            )
            setattr(self, attr_name, projection)

        # Optional per-head layernorms for query and key; None when no spec is given.
        for norm_name in ('q_layernorm', 'k_layernorm'):
            norm_spec = getattr(submodules, norm_name)
            if norm_spec is None:
                setattr(self, norm_name, None)
                continue
            norm_layer = build_module(
                norm_spec,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
            setattr(self, norm_name, norm_layer)

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """Compute query, key and value from ``hidden_states``.

        Each projection output is viewed so that its last dimension is split
        into (heads or query groups, ``hidden_size_per_attention_head``).

        Args:
            hidden_states: Input passed to each of the three projections.
            key_value_states: Must be ``None``.

        Returns:
            Tuple ``(query, key, value)``.

        Raises:
            ValueError: If ``key_value_states`` is provided.
        """
        if key_value_states is not None:
            raise ValueError("Self-Attention does not support key_value_states")

        # Each projection returns (output, bias); the bias slot is not used here.
        query, _ = self.q_proj(hidden_states)
        key, _ = self.k_proj(hidden_states)
        value, _ = self.v_proj(hidden_states)

        head_dim = self.hidden_size_per_attention_head
        num_groups = self.num_query_groups_per_partition
        heads_per_group = self.num_attention_heads_per_partition // num_groups

        query_shape = tuple(query.size()[:-1]) + (num_groups * heads_per_group, head_dim)
        query = query.view(*query_shape)

        # Key and value share one target shape, derived from the key tensor.
        kv_shape = tuple(key.size()[:-1]) + (num_groups, head_dim)
        key = key.view(*kv_shape)
        value = value.view(*kv_shape)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)
        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value
class PatchSplitGateUpMLP(MegatronModule):
    """MLP whose gate and up projections are two independent layers.

    ``gate_proj`` and ``up_proj`` each map ``input_size`` to the FFN width and
    ``linear_fc2`` maps the FFN width back to ``config.hidden_size``. In the
    gated case the activation is applied to the gate branch and multiplied
    with the up branch; otherwise the activation is applied to the up branch
    only.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SplitUpGateMLPSubmodules,
        is_expert: bool = False,
        input_size: Optional[int] = None,
    ):
        """Build the gate, up and output projections.

        Args:
            config: Transformer configuration.
            submodules: Specs for ``gate_proj``, ``up_proj`` and ``linear_fc2``.
            is_expert: Forwarded to every projection; also selects
                ``config.moe_ffn_hidden_size`` as the FFN width when it is set.
            input_size: Input width of gate/up; falls back to
                ``config.hidden_size`` when falsy.
        """
        super().__init__(config=config)
        self.config: TransformerConfig = config
        self.input_size = input_size if input_size else self.config.hidden_size

        if is_expert and self.config.moe_ffn_hidden_size:
            ffn_hidden_size = self.config.moe_ffn_hidden_size
        else:
            ffn_hidden_size = self.config.ffn_hidden_size

        # Each of gate/up is one half of what a fused fc1 would produce in the
        # gated case, i.e. exactly `ffn_hidden_size`. The same width is used
        # in the non-gated case so that `up_proj` matches the input width of
        # `linear_fc2` (the previous `ffn_hidden_size // 2` did not).
        projection_width = int(ffn_hidden_size)

        self.activation_func = self.config.activation_func

        self.gate_proj = build_module(
            submodules.gate_proj,
            self.input_size,
            projection_width,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='gate',
        )
        self.up_proj = build_module(
            submodules.up_proj,
            self.input_size,
            projection_width,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='up',
        )
        # The input width must be the FFN width selected above (which may be
        # `moe_ffn_hidden_size` for experts), not unconditionally
        # `config.ffn_hidden_size`.
        self.linear_fc2 = build_module(
            submodules.linear_fc2,
            ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=is_expert,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states, per_token_scale=None):
        """Run the MLP using the independent gate and up projections.

        Args:
            hidden_states: Input passed to ``gate_proj`` and ``up_proj``.
            per_token_scale: Optional scale; it is unsqueezed on the last
                dimension and multiplied into the activation output.

        Returns:
            Tuple ``(output, output_bias)`` as returned by ``linear_fc2``.

        Raises:
            ValueError: If the requested fused activation combination is not
                supported, or if ``linear_fc2`` returns a bias while
                ``per_token_scale`` is given.
        """
        gate_parallel, gate_bias = self.gate_proj(hidden_states)
        up_parallel, up_bias = self.up_proj(hidden_states)

        # The fused kernels take a single tensor/bias laid out as [gate, up].
        if self.config.add_bias_linear and gate_bias is not None:
            bias_parallel = torch.cat([gate_bias, up_bias], dim=-1)
        else:
            bias_parallel = None

        if self.config.bias_activation_fusion:
            is_swiglu = self.activation_func == F.silu and self.config.gated_linear_unit
            if per_token_scale is not None:
                if not is_swiglu:
                    raise ValueError("Only support fusion of swiglu with per_token_scale in MLP.")
                intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                intermediate_parallel = weighted_bias_swiglu_impl(
                    intermediate_combined,
                    bias_parallel,
                    per_token_scale.unsqueeze(-1),
                    self.config.activation_func_fp8_input_store,
                )
            elif self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                    intermediate_parallel = bias_geglu_impl(intermediate_combined, bias_parallel)
                elif self.config.add_bias_linear:
                    intermediate_parallel = bias_gelu_impl(up_parallel, up_bias)
                else:
                    raise ValueError("Only support gated_linear_unit or add_bias_linear in gelu.")
            elif is_swiglu:
                intermediate_combined = torch.cat([gate_parallel, up_parallel], dim=-1)
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_combined,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only support fusion of gelu and swiglu")
        else:
            # Unfused path: the projections skip the bias add, so apply it here.
            if self.config.add_bias_linear and gate_bias is not None:
                gate_parallel = gate_parallel + gate_bias
                up_parallel = up_parallel + up_bias

            if self.config.gated_linear_unit:
                intermediate_parallel = self.activation_func(gate_parallel) * up_parallel
            else:
                intermediate_parallel = self.activation_func(up_parallel)

            if per_token_scale is not None:
                # Scale, then cast back so the multiplication does not change the dtype.
                original_dtype = intermediate_parallel.dtype
                intermediate_parallel = intermediate_parallel * per_token_scale.unsqueeze(-1)
                intermediate_parallel = intermediate_parallel.to(original_dtype)

        output, output_bias = self.linear_fc2(intermediate_parallel)

        if per_token_scale is not None and output_bias is not None:
            raise ValueError("Bias is not supported with per_token_scale")

        return output, output_bias

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """Merge the sharded state dicts of all direct submodules.

        Every submodule (``gate_proj``, ``up_proj``, ``linear_fc2``) is handled
        the same way: its own ``sharded_state_dict`` is called with the prefix
        ``f'{prefix}{name}.'`` and the result is merged into one dictionary.

        Args:
            prefix: Key prefix prepended to each submodule name.
            sharded_offsets: Forwarded unchanged to each submodule.
            metadata: Forwarded unchanged to each submodule.

        Returns:
            The combined sharded state dict.
        """
        sharded_state_dict = {}
        for name, module in self._modules.items():
            sharded_state_dict.update(
                module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            )
        return sharded_state_dict
class PatchViTSelfAttention(Qwen2vlVitSelfAttention):
    """ViT self-attention layer that reads a non-interleaved fused QKV output.

    Only ``get_query_key_value_tensors`` differs from the parent class; the
    constructor simply forwards its arguments.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding
    ):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """Split the fused ``linear_qkv`` output into query, key and value.

        The projection output is reshaped to
        ``(sq, b, 3, num_attention_heads_per_partition, -1)`` and the axis of
        size 3 is unbound into query, key and value.

        Args:
            hidden_states: Three-dimensional input; its first two sizes are
                used as ``sq`` and ``b`` for the reshape.
            key_value_states: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(query, key, value)``.
        """
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        seq_len, batch_size, _ = hidden_states.shape
        num_heads = self.num_attention_heads_per_partition

        # Move the q/k/v axis to the front, then split along it.
        stacked = mixed_qkv.reshape(seq_len, batch_size, 3, num_heads, -1)
        query, key, value = stacked.permute(2, 0, 1, 3, 4).unbind(0)

        query = query.reshape(
            query.size(0), query.size(1), -1, self.hidden_size_per_attention_head
        )

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)
        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value
def _patch_get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
    """Deprecated alias of ``get_patch_mlp_module_spec``.

    Emits a warning and then returns whatever ``get_patch_mlp_module_spec``
    returns for the same arguments.
    """
    deprecation_message = (
        "This private function is on a deprecation track. Please switch to `get_mlp_module_spec`\n"
        "since it will be removed in a future release."
    )
    warnings.warn(deprecation_message)

    forwarded_kwargs = dict(
        use_te=use_te,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        fp8=fp8,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )
    return get_patch_mlp_module_spec(**forwarded_kwargs)
def get_patch_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Return the module spec for the split gate/up MLP, or for MoE.

    Args:
        use_te: When truthy, the dense spec uses the Transformer Engine
            linear layers; otherwise the local column/row parallel layers.
        num_experts: ``None`` selects the dense ``PatchSplitGateUpMLP`` spec;
            any other value delegates to ``get_moe_module_spec``.
        moe_grouped_gemm: Forwarded to ``get_moe_module_spec``.
        fp8: Deprecated; only triggers a warning when not ``None``.
        moe_use_legacy_grouped_gemm: Forwarded to ``get_moe_module_spec``.

    Returns:
        A ``ModuleSpec`` describing the MLP (or MoE) layer.

    Raises:
        ImportError: If the dense spec is requested with ``use_te`` truthy but
            the Transformer Engine extensions could not be imported.
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is not None:
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )

    if use_te:
        # The TE class names only exist when the guarded import at the top of
        # the file succeeded; fail with a clear message instead of a NameError.
        if not HAVE_TE:
            raise ImportError(
                "use_te=True requires megatron.core.extensions.transformer_engine, "
                "which could not be imported. Pass use_te=False to use the local "
                "ColumnParallelLinear/RowParallelLinear layers."
            )
        column_linear = TELayerNormColumnParallelLinear
        row_linear = TERowParallelLinear
    else:
        column_linear = ColumnParallelLinear
        row_linear = RowParallelLinear

    return ModuleSpec(
        module=PatchSplitGateUpMLP,
        submodules=SplitUpGateMLPSubmodules(
            gate_proj=column_linear,
            up_proj=column_linear,
            linear_fc2=row_linear,
        ),
    )