import math
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn.functional as F
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed.utils import set_actual_seq_len, set_position_ids, get_actual_seq_len, get_position_ids
try:
from mindspeed.core.pipeline_parallel.fb_overlap.modules.attention import launch_async_all2all_hook, launch_async_all2all
from .mla_up_proj_overlap_tp_comm import mla_up_projection_overlap_tp_comm
except ImportError:
pass
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import TransformerConfig, ModuleSpec, build_module
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core import mpu, parallel_state, tensor_parallel
from megatron.training import get_args
try:
import bitsandbytes as bnb
except ImportError:
bnb = None
@dataclass
class MLASelfAttentionSubmodules(SelfAttentionSubmodules):
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None
linear_qb: Union[ModuleSpec, type] = None
linear_kvb: Union[ModuleSpec, type] = None
@dataclass
class MLASelfAttentionWithMMSplitSubmodules(SelfAttentionSubmodules):
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None
linear_qk_nope: Union[ModuleSpec, type] = None
linear_kv_nope: Union[ModuleSpec, type] = None
linear_qk_rope: Union[ModuleSpec, type] = None
linear_v: Union[ModuleSpec, type] = None
class MultiHeadLatentAttention(SelfAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: TransformerConfig,
submodules: MLASelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: str = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
args = get_args()
self.use_flash_attn = args.use_flash_attn
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_nope_head_dim = args.qk_nope_head_dim
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.v_head_dim = args.v_head_dim
self.mla_mm_split = self.config.mla_mm_split
self.mla_fa_without_pad = self.config.mla_fa_without_pad
self.padded_base_length = self.config.padded_base_length
self.shape_order = self.config.shape_order
query_projection_size = self.config.num_attention_heads * self.v_head_dim
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
max_dim = max(self.v_head_dim, self.q_head_dim)
self.fa_padding_length = math.ceil(max_dim / self.padded_base_length) * self.padded_base_length
if self.q_lora_rank is None:
self.q_rank = self.config.num_attention_heads * self.q_head_dim
self.q_layernorm = None
else:
self.q_rank = self.q_lora_rank
if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.q_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None
if not self.mla_mm_split:
self.linear_qb = build_module(
submodules.linear_qb,
self.q_lora_rank,
self.config.num_attention_heads * self.q_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="qb",
)
else:
self.linear_qk_nope = build_module(
submodules.linear_qk_nope,
self.q_lora_rank,
self.config.num_attention_heads * self.qk_nope_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="qk_nope",
)
self.linear_qk_rope = build_module(
submodules.linear_qk_rope,
self.q_lora_rank,
self.config.num_attention_heads * self.qk_rope_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="qk_rope",
)
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.q_rank + self.kv_lora_rank + self.qk_rope_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="qkv",
)
if submodules.k_layernorm is not None:
self.k_layernorm = build_module(
submodules.k_layernorm,
hidden_size=self.kv_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None
if not self.mla_mm_split:
self.linear_kvb = build_module(
submodules.linear_kvb,
self.kv_lora_rank,
self.config.num_attention_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="kvb",
)
else:
self.linear_kv_nope = build_module(
submodules.linear_kv_nope,
self.kv_lora_rank,
self.config.num_attention_heads * self.qk_nope_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="kv_nope",
)
self.linear_v = build_module(
submodules.linear_v,
self.kv_lora_rank,
self.config.num_attention_heads * self.v_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="v",
)
self.linear_proj = build_module(
submodules.linear_proj,
query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name="proj",
)
if args.moe_fb_overlap and parallel_state.get_tensor_model_parallel_world_size() > 1:
self.a2a_hooked_on_attention = True
else:
self.a2a_hooked_on_attention = False
self.mla_up_proj_tp_overlap = self.config.mla_up_proj_tp_overlap
self.recompute_mla_up_proj = self.config.recompute_mla_up_proj
self.mla_zero_memory = self.config.mla_zero_memory
def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_context=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
inference_params=None,
):
"""
Do patch for repeating KV so that GQA+Ulysses is better supported.
"""
args = get_args()
def mla_attention(hidden_states):
args = get_args()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
nonlocal rotary_pos_emb
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
q_len, bsz, _ = hidden_states.shape
q_len = q_len * tp_size if self.config.sequence_parallel else q_len
qkv_combo = self.linear_qkv(hidden_states)
q_a, compressed_kv, k_pe = torch.split(
qkv_combo,
[
self.q_rank,
self.kv_lora_rank,
self.qk_rope_head_dim,
],
dim=-1,
)
if self.mla_up_proj_tp_overlap:
query, key, value = mla_up_projection_overlap_tp_comm(q_a, compressed_kv, k_pe, rotary_pos_emb,
packed_seq_params, self)
else:
if self.q_layernorm is not None:
q_a = self.q_layernorm(q_a)
if not self.mla_mm_split:
q, _ = self.linear_qb(q_a)
q = q.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
q_nope, _ = self.linear_qk_nope(q_a)
q_pe, _ = self.linear_qk_rope(q_a)
q_nope = q_nope.view(
q_len, bsz, self.num_attention_heads_per_partition, -1
)
q_pe = q_pe.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
else:
if self.config.sequence_parallel:
q_a = gather_from_sequence_parallel_region(q_a)
q_a = tensor_parallel.scatter_to_tensor_model_parallel_region(q_a)
q = q_a.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
if self.config.sequence_parallel:
k_pe = gather_from_sequence_parallel_region(k_pe)
else:
k_pe = tensor_parallel.copy_to_tensor_model_parallel_region(k_pe)
k_pe = k_pe.view(q_len, bsz, 1, self.qk_rope_head_dim)
compressed_kv_norm = self.k_layernorm(compressed_kv)
if not self.mla_mm_split:
kv, _ = self.linear_kvb(compressed_kv_norm)
kv = kv.view(
q_len,
bsz,
self.num_attention_heads_per_partition,
self.qk_nope_head_dim + self.v_head_dim,
)
k_nope, value = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
else:
k_nope, _ = self.linear_kv_nope(compressed_kv_norm)
value, _ = self.linear_v(compressed_kv_norm)
k_nope = k_nope.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
value = value.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
if self.a2a_hooked_on_attention:
launch_async_all2all()
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
b, h, s, d = q_pe.shape
q_pe = q_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k_pe.shape
k_pe = k_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
tmp_multi_latent_attention = self.config.multi_latent_attention
self.config.multi_latent_attention = False
q_pe = apply_rotary_pos_emb(q_pe, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)
k_pe = apply_rotary_pos_emb(k_pe, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
self.config.multi_latent_attention = tmp_multi_latent_attention
del tmp_multi_latent_attention
query = torch.cat([q_nope, q_pe], dim=-1)
k_pe = k_pe.expand(k_pe.shape[0], k_pe.shape[1], query.shape[2], k_pe.shape[3])
key = torch.cat([k_nope, k_pe], dim=-1)
if (
self.use_flash_attn
and self.q_head_dim != self.v_head_dim
and not self.mla_fa_without_pad
):
if self.shape_order == "BNSD":
value = F.pad(value, [0, self.q_head_dim - self.v_head_dim])
else:
query = F.pad(query, [0, self.fa_padding_length - self.q_head_dim])
key = F.pad(key, [0, self.fa_padding_length - self.q_head_dim])
value = F.pad(value, [0, self.fa_padding_length - self.v_head_dim])
args = get_args()
should_kv_repeat_before_uly = (
args.context_parallel_size > 1
and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]
and args.kv_head_repeat_before_uly_alltoall
)
heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition
if should_kv_repeat_before_uly and heads_per_gqa_group > 1:
key = key.repeat_interleave(heads_per_gqa_group, dim=2)
value = value.repeat_interleave(heads_per_gqa_group, dim=2)
attn_mask_type = AttnMaskType.causal
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
if self.recompute_mla_up_proj and core_attn_out.requires_grad:
self.recompute_mla_up_proj_ckpt.discard_output()
core_attn_out.register_hook(self.recompute_mla_up_proj_ckpt.recompute)
if packed_seq_params is not None:
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
if self.use_flash_attn and not self.mla_fa_without_pad:
core_attn_out = core_attn_out.view(q_len, bsz, self.num_attention_heads_per_partition, -1)
core_attn_out = core_attn_out[:, :, :, : self.v_head_dim]
core_attn_out = core_attn_out.reshape(q_len, bsz, self.num_attention_heads_per_partition * self.v_head_dim)
return core_attn_out
if self.mla_zero_memory:
self.mla_checkpoint_manager = CheckpointWithoutOutput()
core_attn_out = self.mla_checkpoint_manager.checkpoint(mla_attention,
False,
hidden_states)
if args.reset_position_ids:
self.mla_checkpoint_manager.ctx.actual_len = get_actual_seq_len()
self.mla_checkpoint_manager.ctx.position_id = get_position_ids()
else:
core_attn_out = mla_attention(hidden_states)
if self.a2a_hooked_on_attention and core_attn_out.requires_grad:
core_attn_out.register_hook(launch_async_all2all_hook)
output, bias = self.linear_proj(core_attn_out)
if self.mla_zero_memory:
self.mla_checkpoint_manager.discard_output()
if output.requires_grad:
if args.reset_position_ids:
output.register_hook(recompute_mla(self.mla_checkpoint_manager))
else:
output.register_hook(self.mla_checkpoint_manager.recompute)
return output, bias
def recompute_mla(mla_checkpoint_manager):
"""
recompute_mla when reset_position_ids is enabled.
"""
def hook_fn(grad):
actual_seq_len = getattr(mla_checkpoint_manager.ctx, "actual_len", None)
position_ids = getattr(mla_checkpoint_manager.ctx, "position_id", None)
change_pos_id = False
if position_ids is not None:
change_pos_id = True
old_position_id = get_position_ids()
set_position_ids(position_ids)
change_seq_len = False
if actual_seq_len is not None:
change_seq_len = True
old_actual_seq_len = get_actual_seq_len()
set_actual_seq_len(actual_seq_len)
mla_checkpoint_manager.recompute(grad)
if change_pos_id:
set_position_ids(old_position_id)
if change_seq_len:
set_actual_seq_len(old_actual_seq_len)
return hook_fn
class LinearNoTP(torch.nn.Linear):
def __init__(
self,
input_size,
output_size,
config,
**kwargs,
):
super().__init__(
input_size,
output_size,
bias=kwargs.get('bias', True),
dtype=config.params_dtype,
)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
self.config = config
current_seed = torch.random.initial_seed()
torch.manual_seed(123)
torch.nn.init.xavier_uniform_(self.weight)
torch.random.manual_seed(current_seed)
def forward(self, input_):
if hasattr(self.weight, "quant_state"):
output = bnb.matmul_4bit(input_, self.weight.t(), self.weight.quant_state, bias=self.bias)
else:
output = torch.matmul(input_, self.weight.t())
return output
def _save_to_state_dict(self, destination, prefix, keep_vars):
args = get_args()
if hasattr(args, "qlora_save_dequantize") and args.qlora_save_dequantize and getattr(self.weight, "quant_state", None) is not None:
self.weight = torch.nn.Parameter(bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state))
super()._save_to_state_dict(destination, prefix, keep_vars)
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if any(['bitsandbytes' in i for i in state_dict.keys()]):
qs_dict = {key: v for k, v in state_dict.items() if (key := k.replace(prefix, "")) != '_extra_state'}
self.weight = bnb.nn.Params4bit.from_prequantized(
data=qs_dict.get('weight'),
quantized_stats={key.replace('weight.', ''): qs_dict[key] for key in qs_dict if key != 'weight'},
requires_grad=False,
device='npu')
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)