"""
Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
"""
import logging
import sys
import traceback
import torch
import torch_npu
import torch_sip
# Send INFO-and-above records to stdout using a timestamped line format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
# Named logger used by the tester class and by main().
logger = logging.getLogger(name="torch_sip_cal_test")
class BlasCalTester:
    """Functional and boundary tests for the BLAS scal operator ``torch_sip.asd_blas_cal``.

    Every executed scenario increments ``total_tests``; scenarios that succeed
    also increment ``passed_tests``.
    """

    def __init__(self, device: str = "npu:0"):
        """Create a tester that places its input tensors on ``device``."""
        self.device = device
        # Scenarios executed so far (passed or failed).
        self.total_tests = 0
        # Scenarios that passed.
        self.passed_tests = 0

    @staticmethod
    def log_header(title: str):
        """Log a banner line introducing a test section."""
        logger.info("\n%s %s %s", "=" * 20, title, "=" * 20)

    def validate(self, scenario: str, npu_out: torch.Tensor, ref_out: torch.Tensor,
                 atol: float = 1e-5, rtol: float = 1e-5) -> bool:
        """Compare an operator result with a reference tensor and record the outcome.

        Args:
            scenario: Human-readable label used in log messages.
            npu_out: Tensor returned by the operator; ``None`` counts as a failure.
            ref_out: Reference tensor to compare against.
            atol: Absolute tolerance forwarded to ``torch.allclose``.
            rtol: Relative tolerance forwarded to ``torch.allclose``
                (the default equals torch's own default, 1e-5).

        Returns:
            True when shape and dtype match and the values are close, otherwise False.
            ``total_tests`` is always incremented; ``passed_tests`` only on success.
            Exceptions raised during the comparison are logged and reported as False.
        """
        self.total_tests += 1
        if npu_out is None:
            logger.error(" ✗ %s: FAILED | 算子返回值为 None", scenario)
            return False
        try:
            npu_cpu = npu_out.cpu()
            ref_cpu = ref_out.cpu()
            if npu_cpu.shape != ref_cpu.shape:
                logger.error(" ✗ %s: FAILED | Shape mismatch: NPU %s vs REF %s",
                             scenario, npu_cpu.shape, ref_cpu.shape)
                return False
            # torch.allclose rejects mismatched dtypes with a RuntimeError; report it
            # explicitly so the log says what actually went wrong.
            if npu_cpu.dtype != ref_cpu.dtype:
                logger.error(" ✗ %s: FAILED | Dtype mismatch: NPU %s vs REF %s",
                             scenario, npu_cpu.dtype, ref_cpu.dtype)
                return False
            if torch.allclose(npu_cpu, ref_cpu, rtol=rtol, atol=atol):
                logger.info(" ✓ %s: PASSED", scenario)
                self.passed_tests += 1
                return True
            # abs() is real-valued for complex inputs too, so max() is well defined.
            max_diff = (npu_cpu - ref_cpu).abs().max()
            logger.error(" ✗ %s: FAILED | Max Diff: %.6e", scenario, max_diff.item())
            return False
        except RuntimeError as e:
            logger.error(" ⚠ %s: EXCEPTION | Runtime error during validation: %s", scenario, str(e))
            return False
        except Exception as e:
            logger.error(" ⚠ %s: UNKNOWN ERROR | %s", scenario, str(e))
            logger.debug(traceback.format_exc())
            return False

    def _random_complex(self, shape):
        """Return a complex64 tensor with standard-normal real/imag parts on ``self.device``."""
        real = torch.randn(shape, dtype=torch.float32)
        imag = torch.randn(shape, dtype=torch.float32)
        return torch.complex(real, imag).to(self.device)

    def _run_scal_case(self, label, scenario, make_input, alpha):
        """Run one scal scenario: build the input, compute the reference, call the operator, validate.

        Any exception raised while preparing or running the scenario is logged and
        counted as a failed test, so one broken case does not abort the suite.

        Args:
            label: Short name used in the unexpected-exception log line.
            scenario: Label passed to ``validate``.
            make_input: Zero-argument callable returning the input tensor.
            alpha: Scalar passed to the operator and used for the reference.
        """
        try:
            x = make_input()
            # Take the CPU reference before the call, in case the operator scales in place.
            ref = x.cpu() * alpha
            out = torch_sip.asd_blas_cal(x, alpha)
            self.validate(scenario, out, ref)
        except Exception as e:
            self.total_tests += 1
            logger.error(" ✗ %s: FAILED | Unexpected Exception: %s", label, str(e))
            logger.debug(traceback.format_exc())

    def test_sscal(self):
        """SSCAL: a float32 tensor scaled by a real (float) scalar."""
        BlasCalTester.log_header("Test SSCAL (Float32 * Float)")
        shape = (1024, 1024)
        self._run_scal_case(
            "SSCAL",
            "SSCAL (1M elements, alpha=2.5)",
            lambda: torch.randn(shape, dtype=torch.float32, device=self.device),
            2.5,
        )

    def test_csscal(self):
        """CSSCAL: a complex64 tensor scaled by a real (float) scalar."""
        BlasCalTester.log_header("Test CSSCAL (Complex64 * Float)")
        shape = (512, 512)
        self._run_scal_case(
            "CSSCAL",
            "CSSCAL (Complex * Real Alpha)",
            lambda: self._random_complex(shape),
            1.2,
        )

    def test_cscal(self):
        """CSCAL: a complex64 tensor scaled by a complex scalar."""
        BlasCalTester.log_header("Test CSCAL (Complex64 * Complex)")
        shape = (1000,)
        self._run_scal_case(
            "CSCAL",
            "CSCAL (Complex * Complex Alpha)",
            lambda: self._random_complex(shape),
            complex(0.5, 1.5),
        )

    def test_exceptions(self):
        """Error guard: scaling a real (float32) vector by a complex scalar must raise.

        Passes only when the operator call itself raises RuntimeError, TypeError or
        ValueError. Any other exception type, a missing exception, or a failure to
        build the input tensor counts as a failure.
        """
        BlasCalTester.log_header("Test Error Guard")
        self.total_tests += 1
        # Build the input outside the guarded call, so a setup RuntimeError cannot be
        # mistaken for the operator rejecting the dtype combination.
        try:
            x = torch.randn((10,), dtype=torch.float32, device=self.device)
        except Exception as e:
            logger.error(" ✗ Error Guard FAILED: Could not prepare input: %s", str(e))
            logger.debug(traceback.format_exc())
            return
        alpha = complex(1.0, 1.0)
        logger.info(" [Info] Attempting Float32 * Complex (Should Fail)...")
        try:
            torch_sip.asd_blas_cal(x, alpha)
        except (RuntimeError, TypeError, ValueError) as exc:
            logger.info(" ✓ Error Guard PASSED: Catch expected exception: %s", str(exc)[:60])
            self.passed_tests += 1
        except Exception as e:
            logger.error(" ✗ Error Guard FAILED: Caught unexpected exception type: %s", type(e))
        else:
            logger.error(" ✗ Error Guard FAILED: No exception raised for invalid dtypes")
def main():
    """Run every BLAS scal scenario and print a summary report.

    Returns:
        Process exit code: 0 when every executed scenario passed, 1 otherwise
        (including when no NPU device is available).
    """
    if not torch.npu.is_available():
        logger.critical("NPU device not available. Exiting.")
        return 1

    tester = BlasCalTester()
    scenarios = (
        tester.test_sscal,
        tester.test_csscal,
        tester.test_cscal,
        tester.test_exceptions,
    )
    for run_scenario in scenarios:
        run_scenario()

    separator = "=" * 60
    logger.info("\n%s", separator)
    logger.info("Final Report: %d/%d Tests Passed", tester.passed_tests, tester.total_tests)
    logger.info(separator)

    all_passed = tester.passed_tests == tester.total_tests
    return 0 if all_passed else 1
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    exit_code = main()
    sys.exit(exit_code)