import unittest
from pathlib import Path
import yaml
# Location of the CANN 8.5 op_mapping.yaml exercised by this module, resolved
# from the fifth ancestor directory of this file (presumably the repository
# root -- verify if this file is ever moved).
CANN85_OP_MAPPING = Path(__file__).resolve().parents[4].joinpath(
    "tensor_cast/performance_model/profiling_database/data",
    "ATLAS_800_A3_752T_128G_DIE/vllm_ascend/vllm0.15.0_torch2.9.0_cann8.5",
    "op_mapping.yaml",
)
@unittest.skipIf(
    not CANN85_OP_MAPPING.exists(),
    "CANN 8.5 op_mapping.yaml not found",
)
class CompilePassOpMappingTest(unittest.TestCase):
    """Contract checks for compile-pass entries in the CANN 8.5 op_mapping.yaml.

    The whole class is skipped when the mapping file does not exist on disk.
    """

    @classmethod
    def setUpClass(cls):
        """Load the mapping file once and cache the two sections the tests read."""
        with open(CANN85_OP_MAPPING, encoding="utf-8") as f:
            # safe_load yields None for an empty document; fall back to an
            # empty dict so the tests fail on their own "Missing ..."
            # assertions instead of raising AttributeError here.
            full = yaml.safe_load(f) or {}
        # A section key written without a value also loads as None.
        cls.mapping = full.get("operator_mappings") or {}
        cls.torch_npu_ref = full.get("torch_npu_reference") or {}

    def test_all_reduce_compile_pass_families_keep_expected_contract(self):
        """BF16/FP16 and quantized MC2 families should keep their distinct
        lookup contract.
        """
        matmul_entry = self.mapping.get("tensor_cast.matmul_all_reduce.default")
        self.assertIsNotNone(matmul_entry, "Missing matmul_all_reduce in op_mapping")
        self.assertTrue(matmul_entry.get("composite", False))
        # Treat a missing or null sub_kernels list as empty so the membership
        # checks below report a normal assertion failure.
        matmul_kernels = matmul_entry.get("sub_kernels") or []
        self.assertIn("hcom_allReduce_", matmul_kernels)
        self.assertTrue(
            any(kernel.startswith("MatMul") for kernel in matmul_kernels),
            "matmul_all_reduce should keep at least one matmul compute kernel",
        )
        self.assertNotIn(
            "tc_input_count",
            matmul_entry,
            "matmul_all_reduce should not use quant-style tc_input_count truncation",
        )

        quant_mc2_ops = [
            "tensor_cast.static_quant_linear_all_reduce.default",
            "tensor_cast.static_quant_linear_int4_all_reduce.default",
            "tensor_cast.fp8_linear_all_reduce.default",
            "tensor_cast.mxfp4_linear_all_reduce.default",
        ]
        for op in quant_mc2_ops:
            # subTest keeps the remaining ops running when one entry fails.
            with self.subTest(op=op):
                entry = self.mapping.get(op)
                self.assertIsNotNone(entry, f"Missing op '{op}' in op_mapping")
                self.assertTrue(entry.get("composite", False))
                sub_kernels = entry.get("sub_kernels") or []
                self.assertIn("QuantBatchMatmulV3", sub_kernels)
                self.assertIn("hcom_allReduce_", sub_kernels)
                self.assertEqual(
                    entry.get("tc_input_count"),
                    2,
                    f"{op} should keep tc_input_count=2 for quant lookup truncation",
                )

    def test_mla_compile_pass_reuses_kv_rmsnorm_kernel_reference(self):
        """MLAPO should stay wired to the shipped KvRmsNormRopeCache kernel
        reference.
        """
        kv_entry = self.mapping.get("tensor_cast.kv_rmsnorm_rope_cache.default")
        self.assertIsNotNone(kv_entry, "Missing kv_rmsnorm_rope_cache in op_mapping")
        kernel_type = kv_entry.get("kernel_type")
        self.assertEqual(kernel_type, "KvRmsNormRopeCache")

        mlapo_entry = self.mapping.get("tensor_cast.mlapo.default")
        self.assertIsNotNone(mlapo_entry, "Missing mlapo in op_mapping")
        self.assertTrue(mlapo_entry.get("composite", False))
        # The composite op must list the same kernel the standalone op maps to.
        self.assertIn(kernel_type, mlapo_entry.get("sub_kernels") or [])

        torch_npu_entry = self.torch_npu_ref.get(kernel_type)
        self.assertIsNotNone(
            torch_npu_entry,
            "KvRmsNormRopeCache should have a torch_npu_reference entry",
        )
        self.assertEqual(
            torch_npu_entry.get("microbench_api"),
            "torch_npu.npu_kv_rmsnorm_rope_cache",
        )
        aclnn_apis = torch_npu_entry.get("aclnn") or []
        self.assertIn("aclnnKvRmsNormRopeCache", aclnn_apis)
        self.assertIn("aclnnKvRmsNormRopeCacheV2", aclnn_apis)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()