from dataclasses import dataclass
from typing import Tuple, Union
import torch
from torch import Tensor
import torch.nn.functional as F
from einops import rearrange
from functools import wraps
import torch_npu
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.legacy.model import RMSNorm
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer import TransformerConfig, ModuleSpec, build_module, MegatronModule
from megatron.core import mpu
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from mindspeed_llm.core.tensor_parallel.layers import LinearNoTP
from mindspeed_llm.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_bshd_in_complex
from mindspeed_llm.core.context_parallel.kvallgather_context_parallel import gather_from_sp_cp, permute_cp_shard
from mindspeed_llm.tasks.models.transformer.deepseek4.compressor import get_compressor_spec
from mindspeed_llm.tasks.models.transformer.deepseek4.deepseek_utils import rotate_activation, apply_rotary_emb
from mindspeed.te.pytorch.attention.dot_product_attention.kvallgather_context_parallel import (
get_distributed_world_size,
get_seq_chunk_ids_for_reordering_before_attn,
)
try:
import mindspeed.ops.npu_lightning_indexer as mindspeed_li
except ImportError:
mindspeed_li = None
@dataclass
class DSAIndexerSubmodules:
wq_b: Union[ModuleSpec, type] = None
wk: Union[ModuleSpec, type] = None
weights_proj: Union[ModuleSpec, type] = None
k_norm: Union[ModuleSpec, type] = None
compressor: Union[ModuleSpec, type] = None
def get_dsa_indexer_spec(enable_dsa_indexer, compressor=None):
"""Helper function to get module spec for dsa_indexer"""
if enable_dsa_indexer:
return ModuleSpec(
module=DSAIndexer,
submodules=DSAIndexerSubmodules(
wq_b=LinearNoTP,
wk=LinearNoTP,
weights_proj=LinearNoTP,
compressor=get_compressor_spec() if compressor else IdentityOp,
),
)
else:
return IdentityOp
def norm2fp32_fp16module_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
for _, param in self.module.named_modules():
if isinstance(param, (RMSNorm, LayerNorm)):
param.weight.data = param.weight.data.to(torch.float32)
if hasattr(param, 'bias') and param.bias is not None:
param.bias.data = param.bias.data.to(torch.float32)
return wrapper
def bf16_index(q: torch.Tensor, weights: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
"""
Perform index score using BF16 precision.
Args:
q(torch.Tensor): query tensor of shape [S, B, N, D]
weights(torch.Tensor): weights tensor of shape [S, B, Di, 1]
k(torch.Tensor): key tensor of shape [S, B, N, D]
bf16 q bf16 k -> fp32 q fp32 k
q @ k -> fp32 logits
relu(fp32 logits) * weights -> fp32 logits
sum(fp32 logits) -> fp32 index_score
"""
query = rearrange(q, 's b h d -> b h s d').to(torch.float32)
key = rearrange(k, 's b h d -> b h d s').to(torch.float32)
p = torch.matmul(query, key)
relu_out = torch.nn.functional.relu(p)
weight_out = relu_out * weights.permute(1, 2, 0, 3)
reduce_out = torch.sum(weight_out, dim=1)
return reduce_out
class LayerNorm(torch.nn.Module):
"""
Layer Normalization in DSAIndexer.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
class DSAIndexer(MegatronModule):
"""
An indexing module that computes sparse attention scores using learned queries and keys,
with optional rotary positional embeddings and structured projection (e.g., via Hadamard rotation).
This module is designed for efficient long-sequence attention by selecting top-k relevant tokens
based on a learned similarity score, enabling sparse attention patterns.
"""
def __init__(self, config: TransformerConfig, submodules: DSAIndexerSubmodules, layer_number: int):
super().__init__(config=config)
args = get_args()
self.dim: int = args.hidden_size
self.n_heads: int = args.index_n_heads
self.n_local_heads = args.index_n_heads
self.head_dim: int = args.index_head_dim
self.rope_head_dim: int = args.qk_pos_emb_head_dim
self.index_topk: int = args.index_topk
self.q_lora_rank: int = args.q_lora_rank
self.kv_compress: bool = args.kv_compress
self.use_fused_lightning_indexer: bool = args.use_fused_lightning_indexer
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = args.scale_fmt
self.wq_b = build_module(
submodules.wq_b,
self.q_lora_rank,
self.n_heads * self.head_dim,
config=self.config,
init_method=self.config.init_method,
bias=False,
)
if not self.kv_compress:
self.wk = build_module(
submodules.wk,
self.dim,
self.head_dim,
config=self.config,
init_method=self.config.init_method,
bias=False,
)
self.k_norm = LayerNorm(self.head_dim)
else:
self.compress_ratio = args.compress_ratios[layer_number - 1]
self.kv_compressor = build_module(
submodules.compressor,
config=self.config,
compress_ratio=self.compress_ratio,
head_dim=self.head_dim,
rotate=True,
)
self.weights_proj = build_module(
submodules.weights_proj,
self.dim,
self.n_heads,
config=self.config,
init_method=self.config.init_method,
bias=False,
)
def forward(
self,
x: torch.Tensor,
qr: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask=None,
packed_seq_params=None,
offset=None,
):
"""
Forward pass of the dsa_indexer module.
Args:
x (torch.Tensor): Input activations of shape [seq_len, batch_size, hidden_size].
qr (torch.Tensor): Low-rank query input of shape [seq_len, batch_size, q_lora_rank].
start_pos (int): Starting position in the sequence.
freqs_cis (tuple): Rotary positional embedding frequencies for queries and keys,
shape:[seq_len, batch_size, 1, qk_pos_emb_head_dim].
mask (torch.Tensor, optional): Attention mask.
packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence processing.
"""
if not self.kv_compress:
q, k, weights, x = self.forward_with_index(x, qr, freqs_cis)
topk_indices, topk_score = self.forward_with_scores(
x, q, k, weights, mask, packed_seq_params, start_pos, self.index_topk
)
s1, b, _ = x.size()
s2 = s1
attention_mask = self.generate_sparse_mask(topk_indices, mask, (b, s1, s2), x.dtype, x.device)
else:
q, k, weights, x = self.forward_with_index_compress(x, qr, start_pos, freqs_cis)
q, k, weights = self.all_gather_qk_weight(q, k, weights)
topk_indices, topk_score = self.forward_with_scores_compress(
x, q, k, weights, mask, packed_seq_params, start_pos, self.index_topk, offset, self.compress_ratio
)
topk_indices, topk_score = self.post_process_index(topk_indices, topk_score)
b, s1, _ = topk_indices.size()
s2 = k.size(0)
attention_mask = self.generate_sparse_mask_compress(
topk_indices, mask, (b, s1, s2), x.dtype, x.device, offset
)
return topk_score, topk_indices, attention_mask
def forward_with_index(self, x: Tensor, qr: Tensor, freqs_cis: Tensor):
args = get_args()
rotary_q_pos_emb, rotary_k_pos_emb = freqs_cis
q = self.wq_b(qr)
q = rearrange(q, 's b (h d) -> s b h d', d=self.head_dim)
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
q_pe = apply_rotary_pos_emb_bshd_in_complex(q_pe, rotary_q_pos_emb, rotary_interleaved=True)
q = torch.cat([q_pe, q_nope], dim=-1)
k = self.wk(x)
k = self.k_norm(k)
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
k_pe = k_pe.unsqueeze(2)
s, b, n, d = k_pe.shape
k_pe = apply_rotary_pos_emb_bshd_in_complex(k_pe, rotary_k_pos_emb, rotary_interleaved=True).view(s, b, d)
k = torch.cat([k_pe, k_nope], dim=-1).unsqueeze(2)
if args.context_parallel_size > 1 and args.context_parallel_algo == 'ulysses_cp_algo':
k = gather_from_sequence_parallel_region(k, group=mpu.get_context_parallel_group())
q = gather_from_sequence_parallel_region(q, group=mpu.get_context_parallel_group())
x = gather_from_sequence_parallel_region(x, group=mpu.get_context_parallel_group())
q = rotate_activation(q)
k = rotate_activation(k)
weights = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
weights = weights * self.softmax_scale
return q, k, weights, x
@staticmethod
def forward_with_scores(x, q, k, weights, mask, packed_seq_params, start_pos, index_topk):
args = get_args()
if args.use_fused_lightning_indexer:
if args.context_parallel_size > 1 and args.context_parallel_algo == 'kvallgather_cp_algo':
topk_indices, topk_score = fused_lightning_indexer_kvallgather(
q,
k,
weights,
index_topk,
actual_seq_qlen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_q,
actual_seq_klen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_kv,
layout_query='BSND',
layout_key='BSND',
)
else:
topk_indices, topk_score = fused_lightning_indexer(
q,
k,
weights,
index_topk,
actual_seq_qlen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_q,
actual_seq_klen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_kv,
layout_query='BSND',
layout_key='BSND',
)
else:
s1, b, _ = x.size()
s2 = k.size(0)
end_pos = start_pos + s2
index_score = bf16_index(q.contiguous(), weights.unsqueeze(-1), k.contiguous())
if mask is None:
mask = torch.where(
torch.triu(torch.ones((b, s1, s2), dtype=x.dtype, device=x.device), diagonal=1) == 1,
float('-inf'),
0.0,
)
index_score += mask
topk_score, topk_indices = index_score.topk(min(index_topk, end_pos), dim=-1)
query_positions = torch.arange(s1, device=topk_indices.device).unsqueeze(0).unsqueeze(-1)
valid_positions = topk_indices <= query_positions
topk_indices = torch.where(valid_positions, topk_indices, torch.full_like(topk_indices, -1))
return topk_indices, topk_score
@staticmethod
def get_compress_idxs_on_this_rank(s_total, device, tp_shard=True):
cp_size = parallel_state.get_context_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_rank = parallel_state.get_tensor_model_parallel_rank()
compress_idxs = torch.arange(1, s_total + 1, device=device)
if cp_size > 1:
compress_idxs = permute_cp_shard(compress_idxs, reorder=False)
if tp_size > 1 and tp_shard:
s = s_total // cp_size // tp_size
compress_idxs = compress_idxs[s * tp_rank : s * (tp_rank + 1)]
return compress_idxs.unsqueeze(1)
@staticmethod
def forward_with_scores_compress(
x, q, k, weights, mask, packed_seq_params, start_pos, index_topk, offset, compress_ratio=4
):
args = get_args()
if args.use_fused_lightning_indexer:
if args.context_parallel_size > 1 and args.context_parallel_algo == 'kvallgather_cp_algo':
topk_idxs, topk_score = fused_lightning_indexer_with_compress_kvallgather(
q,
k,
weights,
index_topk,
actual_seq_qlen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_q,
actual_seq_klen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_kv,
layout_query='BSND',
layout_key='BSND',
compress_ratio=compress_ratio,
)
else:
topk_idxs, topk_score = fused_lightning_indexer_with_compress(
q,
k,
weights,
index_topk,
actual_seq_qlen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_q,
actual_seq_klen=None if packed_seq_params is None else packed_seq_params.cu_seqlens_kv,
layout_query='BSND',
layout_key='BSND',
compress_ratio=compress_ratio,
)
topk_idxs = torch.where(topk_idxs == -1, topk_idxs, topk_idxs + offset) if offset != 0 else topk_idxs
else:
s1, s2 = x.size(0), k.size(0)
s_total = s2 * compress_ratio
end_pos = start_pos + s_total
device = x.device
compress_idxs = DSAIndexer.get_compress_idxs_on_this_rank(s_total, device) // compress_ratio
index_score = bf16_index(q.contiguous(), weights.unsqueeze(-1), k.contiguous())
mask = torch.arange(s2, device=device).repeat(s1, 1) >= compress_idxs
index_score = index_score + torch.where(mask, torch.finfo(q.dtype).min, 0)
topk_score, topk_idxs = index_score.topk(min(index_topk, end_pos // compress_ratio), dim=-1)
topk_idxs = topk_idxs.int()
mask = topk_idxs >= compress_idxs
topk_idxs = torch.where(mask, -1, topk_idxs + offset) if offset != 0 else torch.where(mask, -1, topk_idxs)
return topk_idxs, topk_score
def forward_with_index_compress(self, x: Tensor, qr: Tensor, start_pos: int, freqs_cis: Tensor):
q = self.wq_b(qr)
q = rearrange(q, 's b (h d) -> s b h d', d=self.head_dim)
q = q.transpose(0, 1)
q[..., -self.rope_head_dim :] = apply_rotary_emb(q[..., -self.rope_head_dim :], freqs_cis)
q = q.transpose(0, 1)
q = rotate_activation(q)
k = self.kv_compressor(x, start_pos, freqs_cis).unsqueeze(2)
weights = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
weights = weights * self.softmax_scale
return q, k, weights, x
def all_gather_qk_weight(self, q, k, weights):
k = gather_from_sp_cp(k)
if self.use_fused_lightning_indexer:
q = gather_from_sp_cp(q)
weights = gather_from_sp_cp(weights)
return q, k, weights
def all_gather_qk_weight_kvallgather(self, q, k, weights):
k = gather_from_sp_cp(k)
group = parallel_state.get_tensor_model_parallel_group()
q = gather_from_sequence_parallel_region(q, group=group)
weights = gather_from_sequence_parallel_region(weights, group=group)
return q, k, weights
def post_process_index(self, topk_indices, topk_score):
if not self.use_fused_lightning_indexer:
topk_indices, topk_score = topk_indices.transpose(0, 1), topk_score.transpose(0, 1)
topk_indices, topk_score = self.all_gather_score(topk_indices, topk_score)
topk_indices, topk_score = topk_indices.transpose(0, 1), topk_score.transpose(0, 1)
return topk_indices, topk_score
def all_gather_score(self, topk_indices, topk_score):
topk_indices = gather_from_sequence_parallel_region(topk_indices, group=mpu.get_tensor_model_parallel_group())
topk_score = gather_from_sequence_parallel_region(topk_score, group=mpu.get_tensor_model_parallel_group())
return topk_indices, topk_score
@staticmethod
def generate_sparse_mask(topk_indices, mask, shape, dtype, device):
args = get_args()
if not args.use_sparse_flash_attn:
attention_mask = torch.full(shape, float('-inf'), dtype=dtype, device=device).scatter_(-1, topk_indices, 0)
if mask is None:
mask = torch.where(
torch.triu(torch.ones(shape, dtype=dtype, device=device), diagonal=1) == 1, float('-inf'), 0.0
)
attention_mask += mask
if getattr(args, 'use_flash_attn', False):
attention_mask = (torch.isinf(attention_mask) & (attention_mask < 0)).unsqueeze(1)
args.sparse_mode = 0
else:
attention_mask = None
return attention_mask
@staticmethod
def generate_sparse_mask_compress(topk_indices, mask, shape, dtype, device, offset, compress_ratio=4):
args = get_args()
b, s1, s2 = shape
s_total = s2 * compress_ratio
if offset != 0:
topk_indices = torch.where(topk_indices == -1, topk_indices, topk_indices - offset)
max_valid_idx = shape[-1] - 1
topk_indices_clean = torch.where(
topk_indices == -1,
torch.tensor(max_valid_idx, device=topk_indices.device, dtype=topk_indices.dtype),
topk_indices,
)
topk_indices_clean = torch.clamp(topk_indices_clean, 0, max_valid_idx)
attention_mask = torch.full(shape, float('-inf'), dtype=dtype, device=device).scatter_(
-1, topk_indices_clean, 0
)
compress_idxs = DSAIndexer.get_compress_idxs_on_this_rank(s_total, device, False) // compress_ratio
mask = torch.arange(s2, device=device).repeat(s1, 1) >= compress_idxs
mask = torch.where(mask, float('-inf'), 0)
attention_mask += mask
if getattr(args, 'use_flash_attn', False):
attention_mask = torch.isinf(attention_mask) & (attention_mask < 0).unsqueeze(1)
args.sparse_mode = 0
return attention_mask
class DSAIndexerLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that triggers the backward pass and scales the grad for DSA indexer loss."""
main_loss_backward_scale: torch.Tensor = None
@staticmethod
def forward(ctx, output: torch.Tensor, loss: torch.Tensor):
"""Preserve the indexer_loss by storing it in the context to avoid garbage collection.
Args:
ctx: Context object used to save tensors for backward pass.
output (torch.Tensor): The output tensor.
loss (torch.Tensor): The indexer loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for indexer loss.
Args:
ctx: Context object used to save tensors for backward pass.
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled indexer loss
gradient.
"""
(loss,) = ctx.saved_tensors
if DSAIndexerLossAutoScaler.main_loss_backward_scale is None:
DSAIndexerLossAutoScaler.main_loss_backward_scale = torch.tensor(1.0, device=loss.device)
dsa_indexer_loss_backward_scale = DSAIndexerLossAutoScaler.main_loss_backward_scale
scaled_dsa_indexer_loss_grad = torch.ones_like(loss) * dsa_indexer_loss_backward_scale
return grad_output, scaled_dsa_indexer_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the indexer loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
matches the scale of the main_loss.
"""
if DSAIndexerLossAutoScaler.main_loss_backward_scale is None:
DSAIndexerLossAutoScaler.main_loss_backward_scale = scale
else:
DSAIndexerLossAutoScaler.main_loss_backward_scale.copy_(scale)
def forward_step_dsa_wrapper(fn):
"""Forward step for passed-in model. Patch for DSA indexer loss."""
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
encoder_decoder_xattn=False,
extra_block_kwargs=None,
):
global_args = get_args()
common_kwargs = {
'forward_step_func': forward_step_func,
'data_iterator': data_iterator,
'model': model,
'num_microbatches': num_microbatches,
'input_tensor': input_tensor,
'forward_data_store': forward_data_store,
'config': config,
'collect_non_loss_data': collect_non_loss_data,
'checkpoint_activations_microbatch': checkpoint_activations_microbatch,
'is_first_microbatch': is_first_microbatch,
'current_microbatch': current_microbatch,
'encoder_decoder_xattn': encoder_decoder_xattn,
}
if global_args.moe_fb_overlap:
common_kwargs['extra_block_kwargs'] = extra_block_kwargs
output_tensor, num_tokens = fn(**common_kwargs)
if not isinstance(output_tensor, (list, tuple)):
output_tensor_device = output_tensor.device
else:
output_tensor_device = output_tensor[0].device
if global_args.enable_dsa_indexer:
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor_device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor_device)
)
if config.calculate_per_token_loss:
DSAIndexerLossAutoScaler.set_loss_scale(loss_scale)
else:
DSAIndexerLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output_tensor, num_tokens
return wrapper
class DSAIndexerLossLoggingHelper:
"""Helper class for logging DSAIndexer losses."""
tracker = {}
@staticmethod
def save_loss_to_tracker(
loss: torch.Tensor,
layer_number: int,
num_layers: int,
reduce_group: torch.distributed.ProcessGroup = None,
avg_group: torch.distributed.ProcessGroup = None,
):
"""Save the DSA indexer loss for logging.
Args:
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
"""
if layer_number is None:
return
tracker = DSAIndexerLossLoggingHelper.tracker
if "values" not in tracker:
tracker["values"] = torch.zeros(num_layers, device=loss.device)
tracker["values"][layer_number - 1] += loss.detach()
tracker["reduce_group"] = reduce_group
tracker["avg_group"] = avg_group
@staticmethod
def clean_loss_in_tracker():
"""Clear the DSA indexer losses."""
tracker = DSAIndexerLossLoggingHelper.tracker
tracker["values"].zero_()
tracker["reduce_group"] = None
tracker["avg_group"] = None
@staticmethod
def reduce_loss_in_tracker():
"""Collect and reduce the DSA indexer losses across ranks."""
tracker = DSAIndexerLossLoggingHelper.tracker
if "values" not in tracker:
return
values = tracker["values"]
torch.distributed.all_reduce(values, group=parallel_state.get_pipeline_model_parallel_group())
if tracker.get('reduce_group') is not None:
torch.distributed.all_reduce(values, group=tracker.get('reduce_group'))
if tracker.get('avg_group') is not None:
torch.distributed.all_reduce(values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.AVG)
torch.distributed.all_reduce(
values,
group=parallel_state.get_data_parallel_group(with_context_parallel=False),
op=torch.distributed.ReduceOp.AVG,
)
@staticmethod
def track_das_indexer_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None):
"""Track the DSA Indexer metrics for logging."""
DSAIndexerLossLoggingHelper.reduce_loss_in_tracker()
tracker = DSAIndexerLossLoggingHelper.tracker
if "values" not in tracker:
return
das_indexer_losses = tracker["values"] * loss_scale
das_indexer_num_layers = das_indexer_losses.shape[0]
loss = das_indexer_losses.sum() / das_indexer_num_layers
name = "dsa_indexer_loss"
if total_loss_dict is not None:
total_loss_dict[name] = loss
if writer is not None:
writer.add_scalar(name, loss, iteration)
if wandb_writer is not None:
wandb_writer.log({f"{name}": loss}, iteration)
DSAIndexerLossLoggingHelper.clean_loss_in_tracker()
def compute_dsa_indexer_loss(main_attn_dist, index_score, topk_indices, loss_scale, eps=1e-8, cmp_ratio=1):
"""Compute dsa indexer loss at sparse training stage
Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf
Args:
main_attn_dist: Q dist
index_score: P dist
topk_indices: Selected top-K indices for sparse phase
loss_scale: Dsa indexer loss scale
"""
args = get_args()
if args.use_fused_lightning_indexer and cmp_ratio > 1:
index_score_up = index_score[:, : cmp_ratio - 1, :]
index_score_down = index_score[:, cmp_ratio - 1 :, :]
zeros_up = torch.zeros_like(index_score_up, dtype=torch.float32)
index_score = torch.cat([zeros_up, index_score_down], dim=1)
index_score = F.softmax(index_score, dim=-1, dtype=torch.float32)
max_valid_idx = main_attn_dist.size(-1) - 1
topk_indices_clean = torch.where(
topk_indices == -1,
torch.tensor(max_valid_idx, device=topk_indices.device, dtype=topk_indices.dtype),
topk_indices,
)
topk_indices_clean = torch.clamp(topk_indices_clean, 0, max_valid_idx)
selected_main_attn_dist = torch.gather(main_attn_dist, dim=-1, index=topk_indices_clean)
selected_main_attn_dist = F.normalize(selected_main_attn_dist, p=1, dim=-1)
loss = (
F.kl_div(
(index_score + eps).log(),
selected_main_attn_dist + eps,
reduction='none',
)
.sum(dim=-1)
.mean()
)
loss *= loss_scale
return loss
def get_attn_scores(
query,
key,
attention_mask,
num_attn_head_per_group,
attn_scale,
allgather_q=True,
):
"""aggregate the main attention scores"""
if num_attn_head_per_group > 1:
key = key.repeat_interleave(num_attn_head_per_group, dim=2)
output_size = (query.size(1), query.size(2), query.size(0), key.size(0))
query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu"
)
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1),
key.transpose(0, 1).transpose(1, 2),
beta=0.0,
alpha=attn_scale,
)
attention_scores = matmul_result.view(*output_size)
if attention_mask is not None:
attention_scores.masked_fill_(attention_mask, torch.finfo(query.dtype).min)
attention_scores = F.softmax(attention_scores, dim=-1, dtype=torch.float32)
attention_scores = attention_scores.sum(dim=1)
if parallel_state.get_tensor_model_parallel_world_size() > 1 and not allgather_q:
torch.distributed.all_reduce(
attention_scores.contiguous(), group=parallel_state.get_tensor_model_parallel_group()
)
return attention_scores
def fused_lightning_indexer(
q: Tensor,
k: Tensor,
weights: Tensor,
index_topk,
actual_seq_qlen=None,
actual_seq_klen=None,
layout_query='BSND',
layout_key='BSND',
):
q = rearrange(q, 's b h d -> b s h d').to(torch.bfloat16)
k = rearrange(k, 's b h d -> b s h d').to(torch.bfloat16)
weights = rearrange(weights, 's b d -> b s d').to(torch.bfloat16)
topk_indices, topk_score = torch_npu.npu_lightning_indexer(
q,
k,
weights,
actual_seq_lengths_query=actual_seq_qlen,
actual_seq_lengths_key=actual_seq_klen,
layout_query=layout_query,
layout_key=layout_key,
sparse_count=index_topk,
sparse_mode=3,
return_value=True,
)
topk_indices = topk_indices.squeeze(2)
topk_score = topk_score.squeeze(2)
return topk_indices, topk_score
def fused_lightning_indexer_with_compress(
q: Tensor,
k: Tensor,
weights: Tensor,
index_topk,
actual_seq_qlen=None,
actual_seq_klen=None,
layout_query='BSND',
layout_key='BSND',
compress_ratio=4,
):
q = rearrange(q, 's b h d -> b s h d').to(torch.bfloat16)
k = rearrange(k, 's b h d -> b s h d').to(torch.bfloat16)
weights = rearrange(weights, 's b d -> b s d').to(torch.bfloat16)
topk_indices, topk_score = mindspeed_li.npu_lightning_indexer(
q, k, weights, sparse_count=index_topk, sparse_mode=3, cmp_ratio=compress_ratio
)
topk_indices = topk_indices.squeeze(2)
topk_score = topk_score.squeeze(2)
return topk_indices, topk_score
def fused_sparse_lightning_indexer_kl_loss(
query,
key,
query_index,
key_index,
weights,
topk_indices,
softmax_max,
softmax_sum,
scale_value=1,
*,
query_rope=None,
key_rope=None,
actual_seq_qlen=None,
actual_seq_klen=None,
layout='BSND',
sparse_mode=3,
pre_tokens=65536,
next_tokens=65536,
):
"""NPU Sparse Lightning Indexer KL Divergence Loss Function"""
query, key, query_index, key_index, weights = [
x.transpose(0, 1) for x in [query, key, query_index, key_index, weights]
]
topk_indices = topk_indices.unsqueeze(2)
if query_rope is not None:
query_rope, key_rope = [x.transpose(0, 1) for x in [query_rope, key_rope]]
bsz = query.shape[0]
sq = query.shape[1]
loss = LILossTrain.apply(
query,
key,
query_index,
key_index,
weights,
topk_indices,
softmax_max,
softmax_sum,
scale_value,
query_rope,
key_rope,
actual_seq_qlen,
actual_seq_klen,
layout,
sparse_mode,
pre_tokens,
next_tokens,
)
return loss / (sq * bsz)
class LILossTrain(torch.autograd.Function):
"""
A custom autograd function that computes kl loss in sparse lightning indexer.
This interface implements the backward functionality of npu_lightning_indexer and integrates the loss computation.
The npu_lightning_indexer selects the top-k pairs between queries and keys in attention that exhibit the strongest
intrinsic correlations, storing them in sparse_indices. This reduces the computational cost of attention in
long-sequence scenarios and improves training performance.
"""
@staticmethod
def forward(
ctx,
query,
key,
query_index,
key_index,
weights,
sparse_indices,
softmax_max,
softmax_sum,
scale_value=1,
query_rope=None,
key_rope=None,
actual_seq_qlen=None,
actual_seq_klen=None,
layout='BSND',
sparse_mode=3,
pre_tokens=65536,
next_tokens=65536,
):
"""
Forward pass: compute the total loss by processing hidden states in chunks.
Args:
ctx: Context object used to save tensors for backward pass.
query (Tensor): Required. Represents the Attention query. Shapes: (B, S1, N1, D), (T1, N1, D)
key (Tensor): Required. Represents the Attention key. Shapes: (B, S2, N2, D), (T2, N2, D)
query_index (Tensor): Required. Input query for the lightning_indexer forward pass.
key_index (Tensor): Required. Input key for the lightning_indexer forward pass.
weights (Tensor): Required. Weight coefficients of lightning_indexer.
sparse_indices (Tensor): Required. Token indices of sorted key and key_index.
softmax_max (Tensor): Required. Maximum values from Attention softmax results.
softmax_sum (Tensor): Required. Sum values from Attention softmax results.
scale_value (float): Required scaling coefficient.
query_rope (Tensor, optional): RoPE information for query in MLA architecture.
key_rope (Tensor, optional): RoPE information for key in MLA architecture.
actual_seq_qlen (list[int], optional): Required in TND layout. Cumulative sequence lengths for query.
actual_seq_klen (list[int], optional): Required in TND layout. Cumulative sequence lengths for key.
layout (str, optional): Input data layout format. Supported: "BSND", "TND". Default: "BSND".
sparse_mode (int, optional): Sparse computation mode. Default: 3.
pre_tokens (int, optional): Number of preceding tokens for sparse Attention. Default: 65536.
next_tokens (int, optional): Number of succeeding tokens for sparse Attention. Default: 65536.
Returns:
d_query_index (Tensor): Gradient of query_index.
d_key_index (Tensor): Gradient of key_index.
d_weights (Tensor): Gradient of weights.
loss (Tensor): Difference between network forward output and golden value.
"""
d_query_index, d_key_index, d_weights, loss = torch_npu.npu_sparse_lightning_indexer_grad_kl_loss(
query,
key,
query_index,
key_index,
weights,
sparse_indices,
softmax_max,
softmax_sum,
scale_value=scale_value,
query_rope=query_rope,
key_rope=key_rope,
actual_seq_qlen=actual_seq_qlen,
actual_seq_klen=actual_seq_klen,
layout=layout,
sparse_mode=sparse_mode,
pre_tokens=pre_tokens,
next_tokens=next_tokens,
)
ctx.save_for_backward(d_query_index, d_key_index, d_weights)
return loss[0]
@staticmethod
def backward(ctx, *grad_output) -> Tuple:
"""
Backward pass: propagate upstream gradients through the precomputed gradients.
Args:
ctx: Context object with saved tensors from forward pass.
grad_output: Gradient output.
Returns:
tuple: Gradients.
"""
d_query_index, d_key_index, d_weights = ctx.saved_tensors
grad_scale = grad_output[0]
if torch.ne(grad_scale, torch.tensor(1.0, device=grad_scale.device)):
d_query_index = d_query_index * grad_scale
d_key_index = d_key_index * grad_scale
d_weights = d_weights * grad_scale
res_list = [None] * 12
return None, None, d_query_index, d_key_index, d_weights, *res_list
def gather_and_permute_cp_shard(t: torch.Tensor, cp_group: torch.distributed.ProcessGroup) -> torch.Tensor:
cp_size = get_distributed_world_size(cp_group)
t_ag = gather_from_sequence_parallel_region(t, group=cp_group)
t_ag = t_ag.view(2 * cp_size, -1, *t.shape[1:])
chunk_ids = get_seq_chunk_ids_for_reordering_before_attn(cp_size, t.device)
t_ag = torch.index_select(t_ag, dim=0, index=chunk_ids).contiguous()
return t_ag.view(-1, *t.shape[1:])
def fused_lightning_indexer_kvallgather(
q: Tensor,
k: Tensor,
weights: Tensor,
index_topk,
actual_seq_qlen=None,
actual_seq_klen=None,
layout_query='BSND',
layout_key='BSND',
):
cp_group = parallel_state.get_context_parallel_group()
q, weights = [t.view(2, t.shape[0] // 2, *t.shape[1:]).transpose(1, 2) for t in [q, weights]]
k_ag = gather_and_permute_cp_shard(k, cp_group).transpose(0, 1)
topk_indices = torch.empty((*q.shape[:3], 1, index_topk), device=k.device, dtype=torch.int32)
topk_scores = torch.empty((*q.shape[:3], 1, index_topk), device=k.device, dtype=torch.bfloat16)
indices = [None, None]
scores = [None, None]
cp_size = parallel_state.get_context_parallel_world_size()
rank = parallel_state.get_context_parallel_rank()
local_seq_chunk_ids = [rank + 1, 2 * cp_size - rank]
chunk = k_ag.shape[1] // cp_size // 2
for i, chunk_id in enumerate(local_seq_chunk_ids):
indices[i], scores[i] = torch_npu.npu_lightning_indexer(
q[i],
k_ag[:, 0 : chunk_id * chunk, ...],
weights[i],
layout_query="BSND",
layout_key="BSND",
sparse_count=index_topk,
sparse_mode=3,
return_value=True,
)
topk_indices = torch.cat(indices, dim=1).squeeze(2)
topk_scores = torch.cat(scores, dim=1).squeeze(2)
return topk_indices, topk_scores
def fused_sparse_lightning_indexer_kl_loss_kvallgather(
query,
key,
query_index,
key_index,
weights,
topk_indices,
softmax_max,
softmax_sum,
scale_value=1,
*,
query_rope=None,
key_rope=None,
actual_seq_qlen=None,
actual_seq_klen=None,
layout='BSND',
sparse_mode=3,
pre_tokens=65536,
next_tokens=65536,
):
cp_group = parallel_state.get_context_parallel_group()
sq = query.shape[0]
topk_indices = topk_indices.unsqueeze(2).transpose(0, 1)
query, query_rope, topk_indices, query_index, weights = [
t.view(2, t.shape[0] // 2, *t.shape[1:]).transpose(1, 2)
for t in [query, query_rope, topk_indices, query_index, weights]
]
softmax_max, softmax_sum = [rearrange(t, 'b n2 (c s) n1 -> c b n2 s n1', c=2) for t in [softmax_max, softmax_sum]]
key_ag, key_index_ag, key_rope_ag = [
gather_and_permute_cp_shard(t, cp_group).transpose(0, 1) for t in [key, key_index, key_rope]
]
loss = [None, None]
cp_size = parallel_state.get_context_parallel_world_size()
rank = parallel_state.get_context_parallel_rank()
local_seq_chunk_ids = [rank + 1, 2 * cp_size - rank]
chunk = key_ag.shape[1] // cp_size // 2
for i, chunk_id in enumerate(local_seq_chunk_ids):
loss[i] = LILossTrain.apply(
query[i],
key_ag[:, 0 : chunk_id * chunk, ...],
query_index[i],
key_index_ag[:, 0 : chunk_id * chunk, ...],
weights[i],
topk_indices[i],
softmax_max[i],
softmax_sum[i],
scale_value,
query_rope[i],
key_rope_ag[:, 0 : chunk_id * chunk, ...],
None,
None,
layout,
sparse_mode,
pre_tokens,
next_tokens,
)
return (loss[0] + loss[1]) / sq
def fused_sparse_flash_attention_kvallgather(q, k, v, topk_indices, q_rope, k_rope, scale, cp_group):
"""
q: [s, b, n, d]
k: [s, b, n, d]
v: [s, b, n, d]
topk_indices: [b, s, sparse_size]
q_rope: [s, b, n, d]
k_rope: [s, b, n, d]
scale: float
cp_group: ProcessGroup
cp_stream: Stream
"""
if scale is None:
scale = q.shape[-1] ** (-0.5)
if not (q.shape[0] % 2 == 0 and k.shape[0] % 2 == 0):
raise AssertionError("Sequence length per GPU needs to be divisible by 2!")
q, q_rope = [t.view(2, t.shape[0] // 2, *t.shape[1:]).transpose(1, 2) for t in [q, q_rope]]
k_ag, v_ag, k_rope_ag = [gather_and_permute_cp_shard(t, cp_group).transpose(0, 1) for t in [k, v, k_rope]]
b, s, sparse_size = topk_indices.shape
topk_indices = topk_indices.view(b, 2, s // 2, sparse_size).transpose(0, 1).unsqueeze(3)
out_per_step = [None, None]
softmax_max = [None, None]
softmax_sum = [None, None]
out = torch.empty_like(q)
num_steps = 2
for i in range(num_steps):
attn_outs = torch_npu.npu_sparse_flash_attention(
q[i],
k_ag,
v_ag,
sparse_indices=topk_indices[i].to(torch.int32),
block_table=None,
actual_seq_lengths_query=None,
actual_seq_lengths_kv=None,
query_rope=q_rope[i],
key_rope=k_rope_ag,
scale_value=scale,
sparse_block_size=1,
layout_query='BSND',
layout_kv='BSND',
sparse_mode=3,
attention_mode=2,
return_softmax_lse=True,
)
out_per_step[i] = attn_outs[0]
softmax_max[i] = attn_outs[1]
softmax_sum[i] = attn_outs[2]
out[i].copy_(out_per_step[i])
softmax_max_out = torch.cat(softmax_max, dim=2)
softmax_sum_out = torch.cat(softmax_sum, dim=2)
out = out.transpose(0, 1).contiguous()
out = out.view(out.shape[0], -1, *out.shape[-2:])
out = rearrange(out, 'b s h d -> s b h d')
return out, softmax_max_out, softmax_sum_out
def fused_sparse_attn_shared_kv_kvallgather(
query,
ori_kv,
cmp_kv,
cmp_sparse_indices,
sinks,
softmax_scale,
cmp_ratio,
):
from mindspeed.ops.npu_sparse_attn_shared_kv import npu_sparse_attn_shared_kv
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
chunk_size = query.shape[0] // 2
q_chunks = query.chunk(2, dim=0)
if cmp_sparse_indices is not None:
indices_chunks = cmp_sparse_indices.chunk(2, dim=1)
else:
indices_chunks = [None, None]
end_idx_0 = (cp_rank + 1) * chunk_size
end_idx_1 = (2 * cp_size - cp_rank) * chunk_size
k_chunks = [ori_kv[:end_idx_0, ...], ori_kv[:end_idx_1, ...]]
if cmp_kv is not None:
chunk_size_cmp = chunk_size // cmp_ratio
end_idx_0_cmp = (cp_rank + 1) * chunk_size_cmp
end_idx_1_cmp = (2 * cp_size - cp_rank) * chunk_size_cmp
cmp_k_chunks = [cmp_kv[:end_idx_0_cmp, ...], cmp_kv[:end_idx_1_cmp, ...]]
else:
cmp_k_chunks = [None, None]
outputs = []
for i in range(2):
out = npu_sparse_attn_shared_kv(
q_chunks[i], k_chunks[i], cmp_k_chunks[i], indices_chunks[i], sinks.float(), softmax_scale, cmp_ratio
)
outputs.append(out)
output = torch.cat(outputs, dim=0)
return output
def fused_lightning_indexer_with_compress_kvallgather(
q,
k,
weights,
index_topk,
actual_seq_qlen=None,
actual_seq_klen=None,
layout_query='BSND',
layout_key='BSND',
compress_ratio=4,
):
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
chunk_size_q = q.shape[0] // 2
q_chunks = q.chunk(2, dim=0)
if weights is not None:
weights_chunks = weights.chunk(2, dim=0)
else:
weights_chunks = [None, None]
chunk_size_k = chunk_size_q // compress_ratio
end_idx_0_k = (cp_rank + 1) * chunk_size_k
end_idx_1_k = (2 * cp_size - cp_rank) * chunk_size_k
k_chunks = [k[:end_idx_0_k, ...], k[:end_idx_1_k, ...]]
topk_idxs_list = []
topk_score_list = []
for i in range(2):
q_i = q_chunks[i].clone()
k_i = k_chunks[i].clone()
w_i = weights_chunks[i].clone()
q_bshd = q_i.transpose(0, 1).contiguous().to(torch.bfloat16)
k_bshd = k_i.transpose(0, 1).contiguous().to(torch.bfloat16)
if w_i is not None:
w_bsd = w_i.transpose(0, 1).contiguous().to(torch.bfloat16)
else:
w_bsd = None
idx, score = mindspeed_li.npu_lightning_indexer(
q_bshd, k_bshd, w_bsd, sparse_count=index_topk, sparse_mode=3, cmp_ratio=compress_ratio
)
idx = idx.squeeze(2)
score = score.squeeze(2)
topk_idxs_list.append(idx)
topk_score_list.append(score)
topk_idxs_out = torch.cat(topk_idxs_list, dim=1)
topk_score_out = torch.cat(topk_score_list, dim=1)
return topk_idxs_out, topk_score_out
def fused_ms_sparse_lightning_indexer_kl_loss_kvallgather(
total_query,
kv_compress,
query_index,
key_index,
weights,
compress_topk_idxs,
padding_mask=None,
prefix=None,
scale_value=1.0,
query_rope=None,
key_rope=None,
actual_seq_qlen=None,
actual_seq_klen=None,
layout='BSND',
cmp_ratio=4,
):
import mindspeed.ops.npu_sparse_lightning_indexer_grad_kl_loss as ms_slig
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
chunk_size_q = total_query.shape[0] // 2
q_chunks = total_query.chunk(2, dim=0)
q_idx_chunks = query_index.chunk(2, dim=0)
if weights is not None:
w_chunks = weights.chunk(2, dim=0)
else:
w_chunks = [None, None]
if compress_topk_idxs is not None:
topk_idx_chunks = compress_topk_idxs.chunk(2, dim=1)
else:
topk_idx_chunks = [None, None]
chunk_size_k = chunk_size_q // cmp_ratio
end_idx_0_k = (cp_rank + 1) * chunk_size_k
end_idx_1_k = (2 * cp_size - cp_rank) * chunk_size_k
k_chunks = [kv_compress[:end_idx_0_k, ...], kv_compress[:end_idx_1_k, ...]]
k_idx_chunks = [key_index[:end_idx_0_k, ...], key_index[:end_idx_1_k, ...]]
total_loss = 0.0
for i in range(2):
loss_chunk = ms_slig.npu_sparse_lightning_indexer_grad_kl_loss(
q_chunks[i].contiguous(),
k_chunks[i].contiguous(),
q_idx_chunks[i].contiguous(),
k_idx_chunks[i].contiguous(),
w_chunks[i].contiguous() if w_chunks[i] is not None else None,
topk_idx_chunks[i].contiguous() if topk_idx_chunks[i] is not None else None,
padding_mask,
prefix,
scale_value=scale_value,
query_rope=query_rope,
key_rope=key_rope,
actual_seq_qlen=actual_seq_qlen,
actual_seq_klen=actual_seq_klen,
layout=layout,
cmp_ratio=cmp_ratio,
)
total_loss = total_loss + loss_chunk
final_loss = total_loss / 2.0
return final_loss