import unittest
import torch
import torch.nn as nn
try:
import torchair
from torchair.configs.compiler_config import CompilerConfig
HAS_TORCHAIR = True
except ImportError:
HAS_TORCHAIR = False
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
# Size of the shared FFN window buffer: 200 MiB of int8 elements.
window_size = 200 * 1024 * 1024
# Allocated once at import time directly on the NPU. Creating it on the host and
# copying with .npu() would first materialise the whole 200 MiB buffer in host memory.
ffn_window_tensor = torch.zeros([window_size], dtype=torch.int8, device="npu")
attn_workers = 2
micro_batch_number = 3
batch_size = 6
top_k = 8
hidden_size = 7168
expert_num = 288
# Per-token payload sizes are derived from hidden_size so they cannot drift apart.
# attn->ffn: hidden_size + 4, rounded up to the next multiple of 512
# (units presumably bytes -- TODO confirm against the op's contract).
attn_to_ffn_token_size = (hidden_size + 4 + 511) // 512 * 512
# ffn->attn: hidden_size * 2 (presumably 2 bytes per element, e.g. fp16/bf16 -- TODO confirm).
ffn_to_attn_token_size = hidden_size * 2
# Raw device address of the window; passed as ffn_window to create_schedule_context_holder.
ffn_window = ffn_window_tensor.data_ptr()
def _set_all_flags():
    """Write int32 value 1 into every word of the leading flag region of the FFN window.

    The region length in bytes is
    attn_workers * micro_batch_number * (8 + batch_size * top_k * 4).
    It presumably holds one flag area per (attention worker, micro-batch) pair;
    TODO confirm against the scheduler kernel's window layout.
    The byte count must be a multiple of 4 for the int32 reinterpretation
    (1200 with the current constants).
    """
    flag_region_bytes = attn_workers * micro_batch_number * (8 + batch_size * top_k * 4)
    # narrow() returns a view, so fill_() writes straight into ffn_window_tensor.
    flag_words = ffn_window_tensor.narrow(0, 0, flag_region_bytes).view(torch.int32)
    flag_words.fill_(1)
class TestModelInplace(nn.Module):
    """Minimal module wrapping the in-place ffn_worker_scheduler_ op for torch.compile.

    forward mutates ``schedule_context`` in place and returns None.
    """

    def forward(self, schedule_context):
        torch_npu._afd.ffn_worker_scheduler_(
            schedule_context,
            sync_group_size=1,
            execute_mode=0,
        )
class TestModel(nn.Module):
    """Minimal module wrapping the out-of-place ffn_worker_scheduler op for torch.compile.

    forward returns whatever the op returns; ``schedule_context`` is passed through as input.
    """

    def forward(self, schedule_context):
        scheduled = torch_npu._afd.ffn_worker_scheduler(schedule_context, sync_group_size=1)
        return scheduled
@unittest.skipUnless(HAS_TORCHAIR, "torchair is not available")
class TestFfnWorkerScheduler(TestCase):
    """Eager and torchair-graph tests for torch_npu._afd.ffn_worker_scheduler(_)."""

    def setUp(self):
        # The holder is kept on self because the context tensor is obtained from it.
        # Presumably the holder owns that memory and must outlive the tensor (TODO confirm).
        self.context_holder = torch_npu._afd.create_schedule_context_holder(
            schedule_mode=0,
            session_num=attn_workers,
            micro_batch_num=micro_batch_number,
            micro_batch_size=batch_size,
            selected_expert_num=top_k + 1,
            expert_num=expert_num,
            attn_to_ffn_token_size=attn_to_ffn_token_size,
            ffn_to_attn_token_size=ffn_to_attn_token_size,
            ffn_window=ffn_window,
            ffn_window_size=window_size,
        )
        self.schedule_context = self.context_holder.get_schedule_context_tensor()
        _set_all_flags()

    def _compile(self, model_cls):
        """Build ``model_cls`` on the NPU and compile it with the torchair backend.

        Registers torch._dynamo.reset as a test cleanup. Compiled-graph state is
        then discarded even if an assertion in the calling test fails; a trailing
        reset() call would be skipped in that case.
        """
        npu_backend = torchair.get_npu_backend(compiler_config=CompilerConfig())
        compiled = torch.compile(model_cls().npu(), backend=npu_backend)
        self.addCleanup(torch._dynamo.reset)
        return compiled

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_ffn_worker_scheduler_(self):
        _set_all_flags()
        schedule_context1 = self.schedule_context.clone()
        torch_npu._afd.ffn_worker_scheduler_(self.schedule_context, sync_group_size=2)
        # The in-place variant must have modified the context.
        self.assertNotEqual(schedule_context1, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_ffn_worker_scheduler(self):
        _set_all_flags()
        schedule_context1 = self.schedule_context.clone()
        schedule_context2 = torch_npu._afd.ffn_worker_scheduler(self.schedule_context, sync_group_size=2)
        # The out-of-place variant leaves its input untouched and returns a different context.
        self.assertEqual(schedule_context1, self.schedule_context)
        self.assertNotEqual(schedule_context2, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_ffn_worker_scheduler__graph(self):
        _set_all_flags()
        model = self._compile(TestModelInplace)
        schedule_context1 = self.schedule_context.clone()
        model(self.schedule_context)
        self.assertNotEqual(schedule_context1, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_ffn_worker_scheduler_graph(self):
        _set_all_flags()
        model = self._compile(TestModel)
        schedule_context1 = self.schedule_context.clone()
        schedule_context2 = model(self.schedule_context)
        self.assertEqual(schedule_context1, self.schedule_context)
        self.assertNotEqual(schedule_context2, self.schedule_context)
if __name__ == "__main__":
    # Executed as a script: hand control to torch_npu's test runner.
    run_tests()