import unittest
import torch
import torch.nn as nn
try:
import torchair
from torchair.configs.compiler_config import CompilerConfig
HAS_TORCHAIR = True
except ImportError:
HAS_TORCHAIR = False
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
window_size = 209715200
attn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()
attn_workers = 144
micro_batch_number = 3
batch_size = 30
top_k = 8
hidden_size = 7168
expert_num = 288
attn_to_ffn_token_size = (7168 + 4 + 511) // 512 * 512
ffn_to_attn_token_size = 7168 * 2
attn_window = attn_window_tensor.data_ptr()
def _set_all_flags():
num_int8 = batch_size * (top_k + 1) * 4 * micro_batch_number
int32_view = attn_window_tensor[:num_int8].view(torch.int32)
int32_view[:] = 1
class TestModelInplace(nn.Module):
def __init__(self):
super().__init__()
def forward(self, schedule_context):
torch_npu._afd.attention_worker_scheduler_(schedule_context)
class TestModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, schedule_context):
return torch_npu._afd.attention_worker_scheduler(schedule_context)
@unittest.skipUnless(HAS_TORCHAIR, "torchair is not available")
class TestAttentionWorkerScheduler(TestCase):
def setUp(self):
self.context_holder = torch_npu._afd.create_schedule_context_holder(schedule_mode=1, session_num=attn_workers,
micro_batch_num=micro_batch_number,
micro_batch_size=batch_size,
selected_expert_num=top_k + 1,
expert_num=expert_num,
attn_to_ffn_token_size=attn_to_ffn_token_size,
ffn_to_attn_token_size=ffn_to_attn_token_size,
attention_window=attn_window,
attention_window_size=window_size)
self.schedule_context = self.context_holder.get_schedule_context_tensor()
_set_all_flags()
@unittest.skip("skip case until cann supported")
@SupportedDevices(['Ascend910B'])
def test_attention_worker_scheduler_(self):
schedule_context1 = self.schedule_context.clone()
torch_npu._afd.attention_worker_scheduler_(self.schedule_context)
self.assertNotEqual(schedule_context1, self.schedule_context)
@unittest.skip("skip case until cann supported")
@SupportedDevices(['Ascend910B'])
def test_attention_worker_scheduler(self):
_set_all_flags()
schedule_context1 = self.schedule_context.clone()
schedule_context2 = torch_npu._afd.attention_worker_scheduler(self.schedule_context)
self.assertEqual(schedule_context1, self.schedule_context)
self.assertNotEqual(schedule_context2, self.schedule_context)
@unittest.skip("skip case until cann supported")
@SupportedDevices(['Ascend910B'])
def test_attention_worker_scheduler__graph(self):
_set_all_flags()
config = CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
model = TestModelInplace().npu()
model = torch.compile(model, backend=npu_backend)
schedule_context1 = self.schedule_context.clone()
model(self.schedule_context)
self.assertNotEqual(schedule_context1, self.schedule_context)
torch._dynamo.reset()
@unittest.skip("skip case until cann supported")
@SupportedDevices(['Ascend910B'])
def test_attention_worker_scheduler_graph(self):
_set_all_flags()
config = CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
model = TestModel().npu()
model = torch.compile(model, backend=npu_backend)
schedule_context1 = self.schedule_context.clone()
schedule_context2 = model(self.schedule_context)
self.assertEqual(schedule_context1, self.schedule_context)
self.assertNotEqual(schedule_context2, self.schedule_context)
torch._dynamo.reset()
if __name__ == '__main__':
run_tests()