import os
import torch
import mindspeed.megatron_adaptor
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from tests_extend.unit_tests.common import DistributedTest
from tests_extend.commons import set_random_seed, initialize_model_parallel
from tests.ut.utils import judge_expression
from mindspeed_mm.models.common.attention import ParallelAttention
TP_SIZE = 2
CP_SIZE = 4
HIDDEN_SIZE = 2304
NUM_ATTENTION_HEADS = 24
SEQUENCE_LENGTH = 128
MICRO_BATCH_SIZE = 2
DTYPE = torch.bfloat16
OUTPUT_SUM = -161.0
ATTENTION_DATA_PATH = "/home/ci_resource/data/attention"
class TestParallelAttention(DistributedTest):
world_size = 8
def test_forward(self):
args = parse_args(None, True)
args.tensor_model_parallel_size = TP_SIZE
args.context_parallel_size = CP_SIZE
args.hidden_size = HIDDEN_SIZE
args.num_attention_heads = NUM_ATTENTION_HEADS
args.sequence_parallel = True
args.context_parallel_algo = "ulysses_cp_algo"
args.sparse_mode = 0
args.params_dtype = DTYPE
args.num_layers = 1
args.seq_length = SEQUENCE_LENGTH
set_args(args)
initialize_model_parallel(
tensor_model_parallel_size=args.tensor_model_parallel_size,
context_parallel_size=args.context_parallel_size
)
set_random_seed(1234)
model_parallel_cuda_manual_seed(1234)
parallelattention = ParallelAttention(
query_dim=args.hidden_size,
key_dim=args.hidden_size,
num_attention_heads=args.num_attention_heads,
hidden_size=args.hidden_size,
proj_q_bias=True,
proj_k_bias=True,
proj_v_bias=True,
proj_out_bias=True,
dropout=0.0,
use_qk_norm=True,
norm_type="layernorm",
fa_layout="sbh"
)
attention_state_dict = torch.load(f"{ATTENTION_DATA_PATH}/attention_state_dict_rank_{torch.distributed.get_rank()}.pt", map_location="cpu")
parallelattention.load_state_dict(attention_state_dict)
parallelattention = parallelattention.to("npu")
hidden_states = torch.load(f"{ATTENTION_DATA_PATH}/attention_input.pt").to(DTYPE).to("npu")
attention_mask = None
out = parallelattention(query=hidden_states, mask=attention_mask)
judge_expression(out.sum().item() == OUTPUT_SUM)