"""Regression tests for attn_mask_type='padding_causal' on NPU.
Megatron-core passes attn_mask_type='padding_causal' for packed (qkv_format='thd')
sequences. It must apply causal masking within each packed segment, identical to
'causal'. These tests guard against the regression where get_fa_config() let
'padding_causal' fall through to sparse_mode=0 and DotProductAttention skipped the
compressed causal mask, making packed/thd attention non-causal (full bidirectional).
"""
import pytest
import torch
from transformer_engine.pytorch.attention.dot_product_attention.backends import get_fa_config
from transformer_engine.pytorch.attention.dot_product_attention.dot_product_attention import (
DotProductAttention,
)
# Skip the whole module unless an NPU is usable. `torch.npu` is provided by the
# torch_npu plugin rather than core PyTorch, so check for the attribute first:
# without the guard, a missing plugin turns into an AttributeError at collection
# time instead of a clean skip.
pytestmark = pytest.mark.skipif(
    not (hasattr(torch, "npu") and torch.npu.is_available()),
    reason="NPU required",
)
def test_get_fa_config_padding_causal_maps_to_causal():
    """Every causal spelling of the mask type must select sparse_mode 2.

    'padding_causal' and its comma-joined variants are listed next to plain
    'causal' so that any regression routing them to sparse_mode 0 is caught.
    'no_mask' (0) and 'general' (1) pin the neighbouring modes.
    """
    # Insertion order is preserved, so assertions run in the same order as listed.
    expected_sparse_mode = {
        "causal": 2,
        "padding_causal": 2,
        "padding,causal": 2,
        "causal,padding": 2,
        "no_mask": 0,
        "general": 1,
    }
    for mask_type, sparse_mode in expected_sparse_mode.items():
        assert get_fa_config(mask_type)["sparse_mode"] == sparse_mode
def _ref_per_segment_causal(q, k, v, hq, hkv, d, scale, segs):
"""Pure-fp32 per-segment causal attention ground truth."""
rep = hq // hkv
outs = []
for a, b in segs:
s = b - a
qf, kf, vf = q[a:b].float(), k[a:b].float(), v[a:b].float()
o = torch.empty(s, hq, d, device=q.device, dtype=torch.float32)
cm = torch.triu(torch.ones(s, s, device=q.device, dtype=torch.bool), diagonal=1)
for h in range(hq):
sc = (qf[:, h] @ kf[:, h // rep].transpose(0, 1)) * scale
sc = sc.masked_fill(cm, float("-inf"))
o[:, h] = torch.softmax(sc, dim=-1) @ vf[:, h // rep]
outs.append(o.reshape(s, hq * d))
return torch.cat(outs, dim=0)
def test_thd_padding_causal_matches_per_segment_causal():
    """Packed (thd) 'padding_causal' forward must equal per-segment causal attention.

    Two segments of unequal length are packed back to back and described to
    DotProductAttention through cumulative sequence lengths. Before the fix,
    attention was bidirectional: the first token of each segment attended to later
    tokens, and the output differed from the fp32 reference by O(1).
    """
    torch.manual_seed(0)
    device = "npu"
    num_q_heads, num_kv_heads, head_dim = 4, 2, 64
    softmax_scale = 1.0 / (head_dim**0.5)

    # Segment lengths -> cumulative boundaries [0, 3, 8] -> (start, end) pairs.
    seq_lens = (3, 5)
    boundaries = [0]
    for seg_len in seq_lens:
        boundaries.append(boundaries[-1] + seg_len)
    total_tokens = boundaries[-1]
    segments = list(zip(boundaries[:-1], boundaries[1:]))
    longest = max(seq_lens)
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=device)

    def packed(num_heads):
        return torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=torch.bfloat16)

    # Drawn left to right, so the RNG stream is consumed as q, then k, then v.
    q, k, v = packed(num_q_heads), packed(num_kv_heads), packed(num_kv_heads)

    attention = DotProductAttention(
        num_attention_heads=num_q_heads,
        kv_channels=head_dim,
        num_gqa_groups=num_kv_heads,
        qkv_format="thd",
        attn_mask_type="padding_causal",
        softmax_scale=softmax_scale,
    )
    packed_out = attention(
        q,
        k,
        v,
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=longest,
        max_seqlen_kv=longest,
        attn_mask_type="padding_causal",
    )
    actual = packed_out.reshape(total_tokens, num_q_heads * head_dim).float()
    expected = _ref_per_segment_causal(
        q, k, v, num_q_heads, num_kv_heads, head_dim, softmax_scale, segments
    )
    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)