import os
import unittest
import torch
from mindiesd.compilation import MindieSDBackend
from tests.compilation.test_bench_utils import benchmark
class RopePatternModel(torch.nn.Module):
    """Eager reference for a rotary-embedding computation on interleaved pairs.

    The last axis of ``x`` is read as consecutive two-element pairs. Each pair
    ``(a, b)`` contributes ``(-b, a)`` to a rotated tensor, and the output is
    ``x * cos + rotated * sin`` cast to the dtype of ``x``.
    """

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the rotation to ``x``.

        Args:
            x: Input tensor whose last axis has even length. ``flatten(3)``
                merges axis 3 with the trailing pair axis, which restores the
                shape of ``x`` when ``x`` is 4-D.
            cos: Factor multiplied elementwise with ``x``; must broadcast
                against it.
            sin: Factor multiplied elementwise with the rotated tensor; must
                broadcast against it.

        Returns:
            The combined tensor, cast to ``x.dtype``.
        """
        # View the last axis as (..., n_pairs, 2) and split the two pair slots.
        paired = x.reshape(*x.shape[:-1], -1, 2)
        first, second = paired.unbind(-1)
        # Rebuild every pair as (-second, first) and merge the pair axis back.
        rotated = torch.stack([-second, first], dim=-1).flatten(3)
        return (x * cos + rotated * sin).to(x.dtype)
class RopePatternModelDiffusersFlux(torch.nn.Module):
    """Eager reference for the rotary-embedding variant that computes in float32.

    Same interleaved-pair rotation as the plain variant, except that ``x`` and
    the rotated tensor are upcast with ``.float()`` before being multiplied by
    ``cos`` and ``sin``; the sum is cast back to the dtype of ``x``.
    """

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the rotation to ``x`` using float32 intermediates.

        Args:
            x: Input tensor whose last axis has even length. ``flatten(3)``
                merges axis 3 with the trailing pair axis, which restores the
                shape of ``x`` when ``x`` is 4-D.
            cos: Factor multiplied elementwise with the upcast ``x``; must
                broadcast against it.
            sin: Factor multiplied elementwise with the upcast rotated tensor;
                must broadcast against it.

        Returns:
            The combined tensor, cast to ``x.dtype``.
        """
        # View the last axis as (..., n_pairs, 2) and split the two pair slots.
        paired = x.reshape(*x.shape[:-1], -1, 2)
        first, second = paired.unbind(-1)
        # Rebuild every pair as (-second, first) and merge the pair axis back.
        rotated = torch.stack([-second, first], dim=-1).flatten(3)
        # Upcast both operands before the multiply-add, then restore the dtype.
        combined = x.float() * cos + rotated.float() * sin
        return combined.to(x.dtype)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
    "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
class TestRopeCompilationCase(unittest.TestCase):
    """Compare RoPE pattern models compiled with ``MindieSDBackend`` to eager runs.

    Every test builds random inputs on the ``npu`` device, compiles the model,
    benchmarks both the compiled and the eager model, and checks that their
    outputs agree. The whole class is skipped when the ``MINDIE_TEST_MODE``
    environment variable equals ``"CPU"``.
    """

    # Maximum allowed shortfall of the cosine similarity from a perfect 1.0.
    COSINE_TOLERANCE = 2**-7

    @staticmethod
    def _make_inputs(freq_dtype):
        """Return random ``(x, cos, sin)`` NPU tensors; ``cos``/``sin`` use ``freq_dtype``."""
        x = torch.randn(1, 4608, 24, 128, dtype=torch.bfloat16, device="npu")
        cos = torch.randn(1, 4608, 1, 128, dtype=freq_dtype, device="npu")
        sin = torch.randn(1, 4608, 1, 128, dtype=freq_dtype, device="npu")
        return x, cos, sin

    def _run_test_and_measure_time(self, model, x, cos, sin):
        """Compile ``model``, benchmark it against eager mode and compare outputs.

        Args:
            model: Module called as ``model(x, cos, sin)``.
            x: Input tensor passed to the model.
            cos: Cosine factor passed to the model.
            sin: Sine factor passed to the model.

        Raises:
            AssertionError: If the cosine similarity between the flattened
                compiled and eager outputs is not above
                ``1 - COSINE_TOLERANCE``, or if either value returned by
                ``benchmark`` is not positive.
        """
        compiled_model = torch.compile(model, backend=MindieSDBackend())
        # Warm-up call: torch.compile compiles lazily on the first invocation.
        # Synchronize so pending device work is finished before timing starts.
        compiled_model(x, cos, sin)
        torch.npu.synchronize()

        compiled_args = (x, cos, sin)
        compiled_time = benchmark(compiled_model, compiled_args)
        original_time = benchmark(model, compiled_args)

        # Flatten to a single float32 row so one similarity value covers the
        # whole output.
        output_compiled = compiled_model(x, cos, sin).reshape(1, -1).to(torch.float32)
        output_original = model(x, cos, sin).reshape(1, -1).to(torch.float32)
        similarity = torch.cosine_similarity(output_compiled, output_original)[0].item()

        # Cosine similarity lies in [-1, 1] and equals 1 for identical
        # directions, so the tolerance has to be measured from 1. Comparing
        # against the bare tolerance would accept almost any output.
        # The message means "outputs differ after pattern replacement".
        self.assertGreater(
            similarity, 1 - self.COSINE_TOLERANCE, msg="模式替换后输出不一致!"
        )
        self.assertGreater(compiled_time, 0)
        self.assertGreater(original_time, 0)

    def test_rope_pattern_base(self):
        """Base pattern: ``x``, ``cos`` and ``sin`` are all bfloat16."""
        model = RopePatternModel()
        x, cos, sin = self._make_inputs(torch.bfloat16)
        self._run_test_and_measure_time(model, x, cos, sin)

    def test_rope_pattern_diffusers_flux(self):
        """Flux pattern: bfloat16 ``x`` with float32 ``cos`` and ``sin``."""
        model = RopePatternModelDiffusersFlux()
        x, cos, sin = self._make_inputs(torch.float32)
        self._run_test_and_measure_time(model, x, cos, sin)
# Run the test cases when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()