import sys
import math
import pytest
import torch
import torch_npu
import torch.distributed as dist
# Replace the process arguments with just the program name plus the
# flash-attention switch before the imports that follow are executed.
# NOTE(review): presumably this flag is what parse_args() later reads from
# sys.argv -- confirm against the argument parser.
sys.argv = [sys.argv[0], '--use-flash-attn']
from mindspeed_llm import megatron_adaptor
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.legacy.model.transformer import FlashSelfAttention
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.dot_product_attention import DotProductAttention
import megatron.core.parallel_state as mpu
from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed.core.parallel_state import get_context_parallel_group_for_hybrid_ulysses
from mindspeed.model.transformer import get_attention_mask
from mindspeed.model.transformer import set_attention_mask
from tests.test_tools.dist_test import DistributedTest
from tests.test_tools.utils import initialize_model_parallel, initialize_model_parallel_decorator
from mindspeed_llm.tasks.models.common.alibi import Alibi
from mindspeed_llm.training.utils import seed_all
def _select_mirrored_pair(tensor, num_parts, part_rank, dim):
    """Split ``dim`` into ``2 * num_parts`` equal chunks and keep two of them.

    The chunks kept are number ``part_rank`` and its mirror
    ``2 * num_parts - part_rank - 1``; they are merged back into a single
    axis at position ``dim``.
    """
    num_chunks = 2 * num_parts
    leading, trailing = tensor.shape[:dim], tensor.shape[dim + 1:]
    # Expose the chunk index as its own axis at position ``dim``.
    chunked = tensor.view(*leading, num_chunks, tensor.shape[dim] // num_chunks, *trailing)
    wanted = torch.tensor([part_rank, num_chunks - part_rank - 1], device=chunked.device)
    picked = chunked.index_select(dim, wanted)
    # Fold the (2, chunk_len) axes back into one axis.
    return picked.view(*leading, -1, *trailing)


def get_data_on_this_cp_rank(data, r_size, u_size, cp_rank, dim=0):
    """Return the slice of ``data`` along ``dim`` that belongs to ``cp_rank``.

    The total number of shards is ``r_size * u_size``.

    * ``r_size == 1``: ``data`` is cut into ``r_size * u_size`` contiguous
      chunks and chunk ``cp_rank`` is returned.
    * ``u_size == 1``: ``data`` is cut into twice as many chunks as there are
      shards and the rank receives chunk ``cp_rank`` together with its mirror
      chunk counted from the end (striped dispatch for load balance).
    * otherwise: the mirrored selection is done with ``r_size`` parts for
      rank ``cp_rank // u_size``, and the result is cut into ``u_size``
      contiguous chunks of which chunk ``cp_rank % u_size`` is returned.

    Args:
        data: tensor to slice.
        r_size: number of parts used for the mirrored (striped) selection.
        u_size: number of contiguous chunks used for the plain selection.
        cp_rank: index of the shard to return.
        dim: axis along which ``data`` is sliced.

    Returns:
        The tensor slice for ``cp_rank``.
    """
    if r_size == 1:
        return data.chunk(r_size * u_size, dim=dim)[cp_rank]
    if u_size == 1:
        return _select_mirrored_pair(data, r_size * u_size, cp_rank, dim)

    outer_rank, inner_rank = divmod(cp_rank, u_size)
    outer_shard = _select_mirrored_pair(data, r_size, outer_rank, dim)
    return outer_shard.chunk(u_size, dim=dim)[inner_rank]
def _build_test_args(dtype, seq_len, cp_size, u_size, num_heads, use_alibi):
    """Parse the default arguments and override the fields this test sets.

    Returns the populated namespace; the caller publishes it with ``set_args``.
    """
    args = parse_args(None, True)
    args.use_cp_send_recv_overlap = True
    args.attention_mask_type = 'causal'
    args.tp_2d = None
    args.tp_x = 1
    args.tp_y = 1
    args.use_nd_matmul = False
    args.ampipe_degree = 0
    args.hccl_group_buffer_adaptive = False
    args.context_parallel_kv_cache_policy = None
    args.context_parallel_cache_interval = 0
    args.use_ulysses_allgather_kv = False
    args.enable_high_availability = False
    if use_alibi:
        args.position_embedding_type = 'alibi'
        args.square_alibi_mask = True
        args.fill_neg_inf = True
        args.num_attention_heads = num_heads
    args.params_dtype = dtype
    args.use_flash_attn = True
    # NOTE(review): the literal 8 matches TestAttention.world_size; presumably
    # it stands for "u_size equals the full context-parallel size" -- confirm
    # before running with another cp_size.
    if u_size == 1:
        args.context_parallel_algo = 'megatron_cp_algo'
    elif u_size == 8:
        args.context_parallel_algo = 'ulysses_cp_algo'
    else:
        args.context_parallel_algo = 'hybrid_cp_algo'
    args.context_parallel_size = cp_size
    args.ulysses_degree_in_cp = u_size
    args.seq_length = seq_len
    return args


def run_attention_module(test_args, use_mcore, use_cp, cp_size, u_size, use_alibi=False):
    """Check an attention module against a fused-attention reference.

    A reference output and reference gradients are produced by calling
    ``torch_npu.npu_fusion_attention`` on the full random ``q``/``k``/``v``.
    The module under test is then run forward and backward on detached copies
    of the same inputs (sliced with ``get_data_on_this_cp_rank`` when
    ``use_cp`` is true), and its output, ``k`` gradient and ``v`` gradient are
    compared with the reference using ``torch.allclose``.

    Args:
        test_args: tuple ``(batch_size, seq_len, dtype)``.
        use_mcore: build ``DotProductAttention`` when true, otherwise
            ``FlashSelfAttention``.
        use_cp: when true, compare only this rank's slice of the reference.
        cp_size: context-parallel size used for the parallel-state setup.
        u_size: value stored in ``args.ulysses_degree_in_cp``; also selects
            ``args.context_parallel_algo``.
        use_alibi: when true, configure ALiBi arguments and pass an ALiBi
            ``pse`` tensor and an explicit upper-triangular mask to the
            reference kernel.

    Raises:
        AssertionError: if any compared tensor is outside the tolerance.
    """
    bs, seq_len, dtype = test_args
    r_size = cp_size // u_size
    n, d = 32, 128  # number of heads, per-head dimension

    args = _build_test_args(dtype, seq_len, cp_size, u_size, n, use_alibi)
    set_args(args)
    set_attention_mask(None)
    initialize_model_parallel_nest = initialize_model_parallel_decorator(initialize_model_parallel)
    initialize_model_parallel_nest(context_parallel_size=cp_size)
    seed_all(1234)

    rank = dist.get_rank()
    b, s = bs, seq_len
    scale = 1.0 / math.sqrt(d)

    # Keep this creation order: it fixes which random values each tensor gets.
    q = torch.randn(s, b, n * d, dtype=dtype, device='npu', requires_grad=True)
    k = torch.randn(s, b, n * d, dtype=dtype, device='npu', requires_grad=True)
    v = torch.randn(s, b, n * d, dtype=dtype, device='npu', requires_grad=True)
    dout = torch.randn(s, b, n * d, dtype=dtype, device='npu', requires_grad=True)

    if use_alibi:
        _alibi = Alibi()
        # Use the device q already lives on, so this does not rely on the
        # torch.cuda namespace being redirected to the NPU backend.
        _alibi.alibi = _alibi._build_alibi_tensor(seq_len, n, True, True).to(q.device, dtype=dtype)
        attn_mask = torch.triu(torch.ones(seq_len, seq_len), 1).bool().npu()
        _alibi.get_alibi_pse(attn_mask, b, q.shape[0], k.shape[0])
        pse = _alibi.alibi_pse.reshape(b, n, _alibi.alibi_pse.size(1), -1) * 1.0 / scale
        sparse_mode = 0
    else:
        attn_mask = get_attention_mask()
        pse = None
        sparse_mode = 4 if attn_mask is not None else 0

    # Reference: fused attention over the whole, unsliced sequence.
    out = torch_npu.npu_fusion_attention(
        q, k, v, n, 'SBH',
        pse=pse,
        padding_mask=None,
        atten_mask=attn_mask,
        scale=scale,
        pre_tockens=seq_len,
        next_tockens=0,
        keep_prob=1.,
        inner_precise=0,
        sparse_mode=sparse_mode,
    )[0]
    out.backward(dout)

    def _local_copy(tensor):
        """Detached copy of ``tensor``, reduced to this rank's slice under CP."""
        detached = tensor.clone().detach()
        if use_cp:
            return get_data_on_this_cp_rank(detached, r_size, u_size, rank)
        return detached

    out_ref = _local_copy(out)
    k_grad_ref = _local_copy(k.grad)
    v_grad_ref = _local_copy(v.grad)
    q_ = _local_copy(q)
    k_ = _local_copy(k)
    v_ = _local_copy(v)
    dout_ = _local_copy(dout)

    for x in [q_, k_, v_]:
        x.requires_grad = True

    if use_mcore:
        config = TransformerConfig(num_layers=2, hidden_size=n * d, num_attention_heads=n,
                                   use_cpu_initialization=True, context_parallel_size=cp_size)
        local_attn = DotProductAttention(config=config, layer_number=1,
                                         attn_mask_type=args.attention_mask_type, attention_type='self',
                                         attention_dropout=0.)
    else:
        local_attn = FlashSelfAttention(causal=True, softmax_scale=scale, attention_dropout=0.)

    attn = local_attn
    if args.context_parallel_algo != "megatron_cp_algo":
        if args.context_parallel_algo == 'hybrid_cp_algo':
            ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
        else:
            ulysses_group = mpu.get_context_parallel_group()
        attn = UlyssesContextAttention(local_attn, ulysses_group)

    q4, k4, v4 = (t.reshape(-1, b, n, d) for t in (q_, k_, v_))
    if use_mcore:
        out_ = attn(q4, k4, v4, None, None, None)
    else:
        out_ = attn(q4, k4, v4, None)
    out_.backward(dout_)

    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    assert torch.allclose(out_ref, out_, **tols)
    assert torch.allclose(k_grad_ref, k_.grad, **tols)
    assert torch.allclose(v_grad_ref, v_.grad, **tols)
class TestAttention(DistributedTest):
    """Attention-module checks without context parallelism (cp_size = u_size = 1).

    Every case calls ``run_attention_module`` with batch size 2, sequence
    length 8192 and bfloat16. Only ``use_mcore=True`` is parametrized here, so
    the megatron-core ``DotProductAttention`` path is the one exercised.
    """

    world_size = 8

    @pytest.mark.parametrize("use_mcore", [True])
    def test_no_context_parallel_seq8192_bs2_bf16(self, use_mcore):
        """Run the comparison without context parallelism and without ALiBi."""
        run_attention_module(
            test_args=(2, 8192, torch.bfloat16),
            use_mcore=use_mcore,
            use_cp=False,
            cp_size=1,
            u_size=1,
        )

    @pytest.mark.parametrize("use_mcore, use_alibi", [(True, True), (True, False)])
    def test_alibi_seq8192_bs2_bf16(self, use_mcore, use_alibi):
        """Run the comparison without context parallelism, with ALiBi on and off."""
        run_attention_module(
            test_args=(2, 8192, torch.bfloat16),
            use_mcore=use_mcore,
            use_cp=False,
            cp_size=1,
            u_size=1,
            use_alibi=use_alibi,
        )