"""
Copyright (c) 2026 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
"""
import logging
import sys
import torch
import torch_npu
# Send INFO-and-above records to stdout with a timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
# Named logger shared by every helper in this script.
logger = logging.getLogger(name="torch_sip_test")
def get_torch_sip():
    """Import and return the ``torch_sip`` extension module.

    Returns:
        The imported ``torch_sip`` module.

    Raises:
        ImportError: if the extension cannot be imported. The message carries
            build instructions, and the original error is chained as the cause.
    """
    try:
        import torch_sip
    except ImportError as exc:
        hint = (
            f"Failed to import torch_sip: {exc}\n"
            "Please build extension first:\n"
            " cd torch_sip && python setup.py build_ext --inplace"
        )
        raise ImportError(hint) from exc
    logger.info("✓ torch_sip module loaded successfully")
    return torch_sip
def initialize_device():
    """Make ``npu:0`` the current device and return it.

    Returns:
        torch.device: the ``npu:0`` device.

    Raises:
        RuntimeError: if ``torch.npu.is_available()`` reports no NPU.
    """
    npu_ready = torch.npu.is_available()
    if not npu_ready:
        raise RuntimeError("NPU is not available. This test requires NPU hardware.")
    target = torch.device("npu:0")
    torch.npu.set_device(target)
    logger.info("✓ NPU available. Using device: %s", target)
    return target
def verify_precision(result, expected, rtol=1e-3, atol=1e-3):
    """Compare ``result`` with ``expected`` on the CPU and log the outcome.

    Both tensors are copied to the CPU before comparison. Shape and dtype are
    checked explicitly first, for two reasons. ``torch.allclose`` broadcasts
    its operands, so a wrong but broadcastable result shape would otherwise
    pass. It also raises on a dtype mismatch instead of returning False.

    Args:
        result: tensor produced by the operator under test.
        expected: reference tensor.
        rtol: relative tolerance forwarded to ``torch.allclose``.
        atol: absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        bool: True if shapes and dtypes match and all elements are close.
    """
    result_cpu = result.cpu()
    expected_cpu = expected.cpu()

    if result_cpu.shape != expected_cpu.shape:
        logger.error(
            "✗ Shape mismatch: result %s vs expected %s",
            tuple(result_cpu.shape),
            tuple(expected_cpu.shape),
        )
        return False
    if result_cpu.dtype != expected_cpu.dtype:
        logger.error(
            "✗ Dtype mismatch: result %s vs expected %s",
            result_cpu.dtype,
            expected_cpu.dtype,
        )
        return False

    if torch.allclose(result_cpu, expected_cpu, rtol=rtol, atol=atol):
        logger.info("✓ Result matches expected torch.conj(x) on CPU")
        return True

    logger.error("✗ Result does not match expected")
    # Subtraction is only defined for numeric (floating/complex) tensors.
    # abs() of a complex difference is its real magnitude.
    if result_cpu.is_floating_point() or result_cpu.is_complex():
        max_diff = (result_cpu - expected_cpu).abs().max().item()
        logger.error("Max abs diff: %s", max_diff)
    logger.error("Result sample: %s", result_cpu.flatten()[:2])
    logger.error("Expect sample: %s", expected_cpu.flatten()[:2])
    return False
def test_conj(torch_sip, device, shape=(4, 4), seed=0):
    """Run ``torch_sip.conj`` on a random complex tensor and verify it.

    The input is drawn on the CPU from a seeded local generator, so runs are
    reproducible without touching the global RNG, and is then moved to
    ``device``. The reference is ``torch.conj`` evaluated on the CPU copy, so
    the check does not rely on the device's own conj implementation.

    Args:
        torch_sip: the imported extension module.
        device: device the operator is executed on.
        shape: shape of the random complex input.
        seed: seed for the input generator.

    Returns:
        bool: True if the op ran and its output matches the CPU reference.
            False on any execution or verification failure.
    """
    logger.info("-" * 40)
    logger.info("Testing asd_conj operation")

    gen = torch.Generator().manual_seed(seed)
    x_cpu = torch.complex(
        torch.randn(shape, generator=gen),
        torch.randn(shape, generator=gen),
    )
    x = x_cpu.to(device)
    logger.info("Input tensor: shape=%s, dtype=%s", x.shape, x.dtype)

    # torch.conj returns a lazy conjugate view. resolve_conj() materialises it
    # so the reference is an ordinary tensor.
    expected = torch.conj(x_cpu).resolve_conj()

    # Keep the try narrow so only operator failures are reported as such.
    try:
        result = torch_sip.conj(x)
    except Exception:
        logger.exception("✗ asd_conj execution failed due to internal error")
        return False
    logger.info("✓ asd_conj executed successfully")

    try:
        return verify_precision(result, expected)
    except Exception:
        # Preserve the "any failure -> False" contract for verification errors.
        logger.exception("✗ asd_conj result verification failed")
        return False
def main():
    """Script entry point: log environment info, run the test, return an exit code.

    Returns:
        int: 0 if all tests passed. 1 on setup failure (ImportError or
        RuntimeError), on test failure, or on any other unexpected exception.
    """
    banner = "=" * 40
    logger.info("torch_sip Extension Test")
    logger.info(banner)
    logger.info("PyTorch version: %s", torch.__version__)
    python_version, *_ = sys.version.split()
    logger.info("Python version: %s", python_version)
    try:
        # Arguments are evaluated left to right: import first, then the device.
        passed = test_conj(get_torch_sip(), initialize_device())
        logger.info(banner)
        if not passed:
            logger.error("Some tests failed. ✗")
            return 1
        logger.info("All tests passed! ✓")
        return 0
    except (ImportError, RuntimeError) as err:
        logger.error("Setup failed: %s", err)
        return 1
    except Exception:
        logger.exception("An unhandled exception occurred")
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (same as sys.exit).
    raise SystemExit(main())