import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
class TestScheduleContext(TestCase):
    """Tests for ``torch_npu._afd.create_schedule_context_holder`` and the
    ``ScheduleContextHolder`` it returns."""

    # Size in bytes of the shared int8 window buffer (1 GiB). It is used both
    # to allocate the buffer and as the declared attention/ffn window sizes.
    _WINDOW_SIZE = 1 * 1024 * 1024 * 1024

    def setUp(self):
        """Allocate the shared device window and build a baseline of valid
        parameters that individual tests override."""
        npu_device = torch._C._get_privateuse1_backend_name()
        # Allocate directly on the device rather than materialising a 1 GiB
        # host tensor and copying it over with ``.to(...)``.
        self.window_tensor = torch.ones(
            [self._WINDOW_SIZE], dtype=torch.int8, device=npu_device
        )
        # The same buffer backs both the attention and the ffn window.
        window_ptr = self.window_tensor.data_ptr()
        self.default_params = {
            "schedule_mode": 0,
            "session_num": 288,
            "micro_batch_num": 3,
            "micro_batch_size": 30,
            "selected_expert_num": 9,
            "expert_num": 288,
            # Both token sizes are multiples of 512, which the alignment
            # cases in test_init_with_invalid_params expect to be required.
            "attn_to_ffn_token_size": 7168 + 512,
            "ffn_to_attn_token_size": 7168 * 2,
            "attention_window": window_ptr,
            "attention_window_size": self._WINDOW_SIZE,
            "ffn_window": window_ptr,
            "ffn_window_size": self._WINDOW_SIZE,
        }

    def _make_params(self, **overrides):
        """Return a fresh copy of ``default_params`` with ``overrides`` applied.

        The shared ``default_params`` dict is never mutated.
        """
        params = dict(self.default_params)
        params.update(overrides)
        return params

    def test_init_with_invalid_params(self):
        """Invalid parameter combinations must be rejected with RuntimeError."""
        # Each entry: (overrides applied on top of default_params, subtest label).
        invalid_params = [
            ({"session_num": 0}, "session_num=0 should fail"),
            ({"micro_batch_num": 0}, "micro_batch_num=0 should fail"),
            (
                {"session_num": 1 << 31, "micro_batch_num": 1 << 31},
                "micro_batch_num mul overflow",
            ),
            ({"micro_batch_size": 0}, "micro_batch_size=0 should fail"),
            (
                {"micro_batch_num": 1 << 31, "micro_batch_size": 1 << 31},
                "micro_batch_size mul overflow",
            ),
            (
                {
                    "schedule_mode": 1,
                    "micro_batch_num": 1 << 31,
                    "micro_batch_size": 1 << 31,
                },
                "attention micro_batch_size mul overflow",
            ),
            ({"selected_expert_num": 0}, "selected_expert_num=0 should fail"),
            (
                {"micro_batch_size": 1 << 31, "selected_expert_num": 1 << 31},
                "selected_expert_num mul overflow",
            ),
            (
                {
                    "schedule_mode": 1,
                    "micro_batch_size": 1 << 31,
                    "selected_expert_num": 1 << 31,
                },
                "attention selected_expert_num mul overflow",
            ),
            ({"ffn_window": 0}, "ffn_window=0 should fail"),
            ({"ffn_window_size": 0}, "ffn_window_size can not be 0"),
            ({"ffn_window_size": 511}, "ffn_window_size is not enough should fail"),
            # Label previously said "ffn_window is null", misreporting the
            # parameter under test.
            (
                {"schedule_mode": 1, "attention_window": 0},
                "attention_window=0 should fail",
            ),
            (
                {"schedule_mode": 1, "attention_window_size": 0},
                "attention_window_size can not be 0",
            ),
            (
                {"schedule_mode": 1, "attention_window_size": 511},
                "attention_window_size is not enough should fail",
            ),
            ({"schedule_mode": 2}, "schedule_mode 2 is not supported"),
            (
                {"attn_to_ffn_token_size": 1023},
                "attn_to_ffn_token_size must be aligned by 512",
            ),
            (
                {"ffn_to_attn_token_size": 400},
                "ffn_to_attn_token_size must be aligned by 512",
            ),
        ]
        for overrides, msg in invalid_params:
            with self.subTest(msg=msg):
                holder = None
                try:
                    with self.assertRaises(RuntimeError):
                        holder = torch_npu._afd.create_schedule_context_holder(
                            **self._make_params(**overrides)
                        )
                finally:
                    # Only non-None if validation unexpectedly passed (and the
                    # assertion above is failing). Stop the schedule anyway so
                    # it does not keep running during the remaining subtests.
                    if holder is not None:
                        holder.stop_schedule()

    def test_schedule_ffn(self):
        """The default parameters (schedule_mode=0) create a working holder
        whose context info reports ffn information."""
        holder = torch_npu._afd.create_schedule_context_holder(
            **self._make_params()
        )
        try:
            self.assertIsInstance(holder, torch_npu._afd.ScheduleContextHolder)
            tensor = holder.get_schedule_context_tensor()
            self.assertIsInstance(tensor, torch.Tensor)
            context_info = holder.get_schedule_context_info()
            self.assertIn("ffn info:", context_info)
        finally:
            # Always stop the schedule, even if an assertion above fails.
            holder.stop_schedule()

    def test_schedule_attn(self):
        """The default parameters with schedule_mode=1 (the attention-side
        mode in these tests) create a working holder."""
        holder = torch_npu._afd.create_schedule_context_holder(
            **self._make_params(schedule_mode=1)
        )
        try:
            self.assertIsInstance(holder, torch_npu._afd.ScheduleContextHolder)
            tensor = holder.get_schedule_context_tensor()
            self.assertIsInstance(tensor, torch.Tensor)
        finally:
            # Always stop the schedule, even if an assertion above fails.
            holder.stop_schedule()
if __name__ == "__main__":
    # When executed as a script, hand off to torch_npu's test runner.
    run_tests()