import math
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn.functional as F
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer import TransformerConfig, ModuleSpec, build_module
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.training import get_args
try:
import bitsandbytes as bnb
except ImportError:
bnb = None
@dataclass
class CustomQwen3NextSelfAttentionSubmodules(SelfAttentionSubmodules):
    """Submodules for the Qwen3Next self-attention layer with NPU.

    Each field holds the module class (or a ``ModuleSpec``) that
    ``CustomQwen3NextSelfAttention.__init__`` hands to ``build_module``.
    Every field defaults to ``None``; the attention layer only tolerates
    ``None`` for ``q_layernorm`` and ``k_layernorm`` at construction time.
    """

    # Query projection. The attention layer builds it with an output width of
    # num_attention_heads * head_dim * 2 and later splits it into query + gate.
    q_proj: Union[ModuleSpec, type] = None
    # Key projection, built with output width num_query_groups * head_dim.
    k_proj: Union[ModuleSpec, type] = None
    # Value projection, built with output width num_query_groups * head_dim.
    v_proj: Union[ModuleSpec, type] = None
    # Core attention kernel. Not built in CustomQwen3NextSelfAttention itself;
    # presumably consumed by the SelfAttention base class -- TODO confirm.
    core_attention: Union[ModuleSpec, type] = None
    # Output projection back to hidden_size.
    linear_proj: Union[ModuleSpec, type] = None
    # Optional norm applied to the query over the head dimension.
    q_layernorm: Union[ModuleSpec, type] = None
    # Optional norm applied to the key over the head dimension.
    k_layernorm: Union[ModuleSpec, type] = None
class CustomQwen3NextSelfAttention(SelfAttention):
    """Gated self-attention with separate query/key/value projections.

    Differences from a fused-QKV self-attention layer, as implemented here:

    * ``q_proj`` produces ``2 * head_dim`` features per head; the first half is
      the query and the second half is a gate.
    * Query and key are optionally normalised per head (``q_layernorm`` /
      ``k_layernorm``) before rotary position embeddings are applied.
    * The core-attention output is multiplied by ``sigmoid(gate)`` before the
      output projection.
    * Key/value heads can be repeated before the core attention call when
      Ulysses-style context parallelism is enabled (see ``forward``).
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: CustomQwen3NextSelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: str = None,
    ):
        """Build the projection, normalisation and output sub-modules.

        Args:
            config: Transformer configuration; read for sizes, bias flags,
                init methods, dropout and the layernorm epsilon.
            submodules: Specs for the sub-modules built by this layer.
            layer_number: Index of this layer, forwarded to the base class.
            attn_mask_type: Mask type forwarded to the base class. Note that
                ``forward`` always passes ``AttnMaskType.causal`` to the core
                attention regardless of this value.
            cp_comm_type: Accepted for signature compatibility.
        """
        # NOTE(review): cp_comm_type is accepted but not forwarded to the base
        # class -- confirm whether the installed SelfAttention expects it.
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
        )
        self.config = config
        self.use_flash_attn = config.use_flash_attn
        self.shape_order = config.shape_order
        self.layer_number = layer_number
        self.head_dim = getattr(config, "kv_channels", config.hidden_size // config.num_attention_heads)
        query_projection_size = self.config.num_attention_heads * self.head_dim
        self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Sub-modules are built in a fixed order (q, k, v, norms, output) so
        # that parameter registration/initialisation order is unchanged.

        # Query projection is twice as wide as usual: query + gate per head.
        self.q_proj = build_module(
            submodules.q_proj,
            self.config.hidden_size,
            self.config.num_attention_heads * self.head_dim * 2,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name="q_proj",
        )
        self.k_proj = build_module(
            submodules.k_proj,
            self.config.hidden_size,
            config.num_query_groups * self.head_dim,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name="k_proj",
        )
        self.v_proj = build_module(
            submodules.v_proj,
            self.config.hidden_size,
            config.num_query_groups * self.head_dim,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name="v_proj",
        )

        # Per-head query/key norms are optional; None means "skip" in forward.
        if submodules.q_layernorm is not None:
            self.q_layernorm = build_module(
                submodules.q_layernorm,
                hidden_size=self.head_dim,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.q_layernorm = None

        if submodules.k_layernorm is not None:
            self.k_layernorm = build_module(
                submodules.k_layernorm,
                hidden_size=self.head_dim,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.k_layernorm = None

        self.linear_proj = build_module(
            submodules.linear_proj,
            query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name="proj",
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
    ):
        """Run gated self-attention over ``hidden_states``.

        Do patch for repeating KV so that GQA+Ulysses is better supported:
        when context parallelism with a Ulysses/hybrid algorithm is active and
        ``args.kv_head_repeat_before_uly_alltoall`` is set, key/value heads are
        repeated to match the query heads before the core attention call.

        Args:
            hidden_states: Input tensor; the last dimension is the hidden size.
                The KV-repeat path indexes heads at dim 2, which assumes a
                3-D input -- TODO confirm layout with the caller.
            attention_mask: Mask forwarded unchanged to the core attention.
            rotary_pos_emb: Optional rotary embedding, or a ``(q, k)`` tuple of
                embeddings. A single value is used for both query and key.
            packed_seq_params: Optional packed-sequence parameters; when given,
                its ``cu_seqlens_q`` / ``cu_seqlens_kv`` are passed to the
                rotary embedding and the object is forwarded to core attention.
            key_value_states, inference_context, rotary_pos_cos,
            rotary_pos_sin, attention_bias, sequence_len_offset,
            inference_params: Accepted for interface compatibility; not used
                by this implementation.

        Returns:
            Whatever ``self.linear_proj`` returns for the gated attention
            output (it is built with ``skip_bias_add=True``).
        """
        args = get_args()
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # A single rotary embedding is shared by query and key.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # q_proj yields [query | gate] per head; split along the feature axis.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states)[0].view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        # The gate chunk is non-contiguous, hence reshape rather than view.
        gate = gate.reshape(*input_shape, -1)

        # The norms are optional (set to None in __init__ when the spec omits
        # them), so only apply them when present instead of calling None.
        query = query_states.view(hidden_shape)
        if self.q_layernorm is not None:
            query = self.q_layernorm(query)

        key = self.k_proj(hidden_states)[0].view(hidden_shape)
        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        value = self.v_proj(hidden_states)[0].view(hidden_shape)

        if rotary_pos_emb is not None:
            rotary_q_pos_emb, rotary_k_pos_emb = rotary_pos_emb
            if packed_seq_params is not None:
                cu_seqlens_q = packed_seq_params.cu_seqlens_q
                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(query, rotary_q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)
            key = apply_rotary_pos_emb(key, rotary_k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)

        # Expand K/V heads to match the query heads before core attention when
        # the Ulysses-style context-parallel path asks for it.
        should_kv_repeat_before_uly = (
            args.context_parallel_size > 1
            and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]
            and args.kv_head_repeat_before_uly_alltoall
        )
        heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition
        if should_kv_repeat_before_uly and heads_per_gqa_group > 1:
            key = key.repeat_interleave(heads_per_gqa_group, dim=2)
            value = value.repeat_interleave(heads_per_gqa_group, dim=2)

        # The core attention is always invoked with a causal mask type.
        attn_mask_type = AttnMaskType.causal
        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
        else:
            # NOTE(review): the attention_bias argument is intentionally not
            # forwarded here (None is passed) -- confirm this is desired.
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=None,
                packed_seq_params=packed_seq_params,
            )

        # Flatten heads, apply the sigmoid gate, then project to hidden size.
        core_attn_out = core_attn_out.reshape(*input_shape, -1).contiguous()
        core_attn_out = core_attn_out * torch.sigmoid(gate)
        core_attn_out = self.linear_proj(core_attn_out)
        return core_attn_out