"""
Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
"""
import logging
import os
import sys
from dataclasses import dataclass
from typing import Tuple
import torch
import torch_npu
import torch_sip
# Root logging setup: INFO level, timestamped records, all output to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name="torch_sip_interp_coeff")
@dataclass(frozen=True)
class InterpConfig:
    """Parameter object describing one interpolation test case.

    Bundles the tensor dimensions, element dtype and display name of a case.
    The tester builds an input of shape ``(batch, n_rs, total_subcarrier)``
    and a coefficient tensor of shape ``(batch, n_signal - n_rs, n_rs)``.

    Attributes:
        batch: Leading (batch) dimension shared by both operator inputs.
        n_rs: Row count of the input tensor (inner dimension of the product).
        total_subcarrier: Trailing dimension of the input and of the output.
        n_signal: Total row count; the output has ``n_signal - n_rs`` rows.
        dtype: Complex element type; only ``torch.complex64`` and
            ``torch.complex32`` are supported by the tester.
        name: Display name used in log lines.

    Raises:
        ValueError: If a dimension is negative, ``n_rs > n_signal`` (the
            output row count would be negative), or ``dtype`` is unsupported.
    """

    batch: int
    n_rs: int
    total_subcarrier: int
    n_signal: int = 14
    dtype: torch.dtype = torch.complex64
    name: str = "Case"

    def __post_init__(self) -> None:
        # Validate eagerly so a bad case fails with a clear message here rather
        # than deep inside tensor creation in the tester.
        for field_name in ("batch", "n_rs", "total_subcarrier", "n_signal"):
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(
                    f"{self.name}: {field_name} must be non-negative, got {value}"
                )
        if self.n_rs > self.n_signal:
            raise ValueError(
                f"{self.name}: n_rs ({self.n_rs}) must not exceed "
                f"n_signal ({self.n_signal})"
            )
        # The tester only knows how to generate and label these two dtypes.
        if self.dtype not in (torch.complex64, torch.complex32):
            raise ValueError(
                f"{self.name}: unsupported dtype {self.dtype}; "
                "expected torch.complex64 or torch.complex32"
            )
class InterpWithCoeffTester:
    """Test harness for the ``torch_sip.asd_interp_with_coeff`` operator.

    Each case builds random complex inputs on ``device``, runs the operator,
    and compares its output with a CPU ``torch.bmm(coeff, x)`` reference.
    """

    # Short labels for the dtypes this harness generates; others fall back to str(dtype).
    _DTYPE_LABELS = {torch.complex64: "C64", torch.complex32: "C32"}

    def __init__(self, device: str = "npu:0"):
        """Store the target device and log the current NPU launch mode.

        Args:
            device: Device string the generated test tensors are moved to.
        """
        self.device = device
        blocking_env = os.getenv("ASCEND_LAUNCH_BLOCKING", "0")
        # Only the value "0" (also the default when unset) is reported as async.
        mode_str = "异步 (Async)" if blocking_env == "0" else "同步 (Blocking)"
        logger.info("当前 NPU 模式: %s (BLOCKING=%s)", mode_str, blocking_env)

    @staticmethod
    def log_section(title: str):
        """Log a banner line separating test sections."""
        bar = "=" * 20
        logger.info("\n%s %s %s", bar, title, bar)

    def run_case(self, cfg: InterpConfig) -> bool:
        """Run a single interpolation test case.

        Per guideline G.ERR.01 the ``try`` block wraps only the operator call.

        Args:
            cfg: Dimensions, dtype and display name of the case.

        Returns:
            True if the operator output matches the CPU reference within
            tolerance; False if the operator raised, returned a tensor of the
            wrong shape, or returned values outside tolerance.
        """
        out_m = cfg.n_signal - cfg.n_rs
        tensor_x = self._get_complex_tensor(
            (cfg.batch, cfg.n_rs, cfg.total_subcarrier), cfg.dtype)
        tensor_coeff = self._get_complex_tensor((cfg.batch, out_m, cfg.n_rs), cfg.dtype)
        # The reference is always computed in complex64 on the CPU.
        x_ref = tensor_x.cpu().to(torch.complex64)
        coeff_ref = tensor_coeff.cpu().to(torch.complex64)
        ref_out = torch.bmm(coeff_ref, x_ref)
        try:
            npu_out = torch_sip.asd_interp_with_coeff(tensor_x, tensor_coeff)
        except Exception as exc:  # any operator failure is reported as a FAIL
            logger.exception("[%s] 算子执行崩溃: %s", cfg.name, exc)
            return False
        npu_res_c64 = npu_out.cpu().to(torch.complex64)
        dtype_str = self._DTYPE_LABELS.get(cfg.dtype, str(cfg.dtype))
        if npu_res_c64.shape != ref_out.shape:
            # allclose would raise (aborting the suite) or silently broadcast.
            self._log_status("FAIL", cfg, dtype_str)
            logger.error(" -> Shape Mismatch: NPU=%s, Expected=%s",
                         tuple(npu_res_c64.shape), tuple(ref_out.shape))
            return False
        # Looser tolerance for the non-complex64 (half-precision) path.
        rtol, atol = (1e-3, 1e-3) if cfg.dtype == torch.complex64 else (1e-2, 1e-2)
        is_close = torch.allclose(npu_res_c64, ref_out, rtol=rtol, atol=atol)
        self._log_status("PASS" if is_close else "FAIL", cfg, dtype_str)
        if not is_close:
            max_err = (npu_res_c64 - ref_out).abs().max().item()
            logger.error(" -> Max Absolute Error: %.5f", max_err)
        return is_close

    @staticmethod
    def _log_status(status: str, cfg: InterpConfig, dtype_str: str) -> None:
        """Log the one-line PASS/FAIL summary for a case."""
        logger.info("[%s] %-20s | %s | B=%d, nRs=%d, SubC=%d",
                    status, cfg.name, dtype_str, cfg.batch, cfg.n_rs, cfg.total_subcarrier)

    def _get_complex_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype) -> torch.Tensor:
        """Build a random complex tensor of ``shape`` on ``self.device``.

        Real and imaginary parts come independently from ``torch.randn``:
        float16 when ``dtype`` is ``torch.complex32``, float32 otherwise.
        ``torch.complex`` then yields the matching complex dtype. Generating
        the parts separately presumably avoids relying on direct complex32
        support in ``torch.randn``; NOTE(review): confirm.
        """
        f_dtype = torch.float16 if dtype == torch.complex32 else torch.float32
        real = torch.randn(shape, dtype=f_dtype)
        imag = torch.randn(shape, dtype=f_dtype)
        return torch.complex(real, imag).to(self.device)
def main() -> int:
    """Entry point for the interpolation operator test suite.

    Runs every configured case and logs an overall verdict. A case that
    reports failure does not stop the remaining cases.

    Returns:
        0 if every case passed, 1 if at least one case failed.
    """
    tester = InterpWithCoeffTester()
    cases = (
        InterpConfig(1, 2, 32, name="SingleBatch_C64"),
        InterpConfig(4, 4, 64, name="MultiBatch_C64"),
        InterpConfig(8, 2, 128, name="LargeSubC_C64"),
        InterpConfig(1, 2, 32, dtype=torch.complex32, name="SingleBatch_C32"),
        InterpConfig(4, 4, 64, dtype=torch.complex32, name="MultiBatch_C32"),
    )
    InterpWithCoeffTester.log_section("信道估计插值算子连通性测试")
    failures = 0
    for case in cases:
        if not tester.run_case(case):
            failures += 1
        logger.info("-" * 60)
    if failures:
        logger.error("测试结论: ❌ 存在失败项")
        return 1
    logger.info("测试结论: ✅ 全部通过")
    return 0
# Script entry: the suite's return code becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())