"""
add_scalar_loop_multi_jit Example for PyPTO
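This example demonstrates how to implement an add_scalar_loop_multi_jit operation using PyPTO, including:
- a JIT kernel (add_kernel) that adds two tensors and, optionally, a scalar constant
- calling the same JIT kernel twice with different add1_flag values
- comparing the kernel output against a torch reference result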
"""
import os
import sys
import argparse
import pypto
import torch
import numpy as np
from numpy.testing import assert_allclose
def _peek_run_mode_from_argv(default: str = "npu") -> str:
    """Scan ``sys.argv`` for ``--run_mode`` before argparse gets to run.

    Module-level decorators need the run mode at import time, so this looks
    at the raw argument list. Both ``--run_mode VALUE`` and
    ``--run_mode=VALUE`` are recognised. The first occurrence carrying a
    supported value ("npu" or "sim") is returned; if there is none,
    *default* is returned.
    """
    supported = ("npu", "sim")
    argv = sys.argv
    for position, token in enumerate(argv):
        if token == "--run_mode":
            # Space-separated form: the value, if any, is the next argument.
            has_next = position + 1 < len(argv)
            candidate = argv[position + 1] if has_next else None
        elif token.startswith("--run_mode="):
            # Inline form: everything after the first "=" is the value.
            _, _, candidate = token.partition("=")
        else:
            continue
        if candidate in supported:
            return candidate
    return default
# Run mode handed to the module-level @pypto.frontend.jit decorator. It has to
# be resolved at import time, i.e. before main() parses the command line.
global_run_mode = (
    pypto.RunMode.SIM
    if _peek_run_mode_from_argv("npu") == "sim"
    else pypto.RunMode.NPU
)
def get_device_id():
    """
    Read the device ID from the TILE_FWK_DEVICE_ID environment variable.

    A hint is printed when the variable is unset, and an error message is
    printed when its value cannot be converted with ``int()``.

    Returns:
        int: The parsed device ID, or None when the variable is missing or
        is not an integer.
    """
    raw_value = os.environ.get('TILE_FWK_DEVICE_ID')
    if raw_value is None:
        print("Please set the environment variable TILE_FWK_DEVICE_ID before running:")
        print(" export TILE_FWK_DEVICE_ID=0")
        return None

    try:
        parsed_id = int(raw_value)
    except ValueError:
        print(f"ERROR: TILE_FWK_DEVICE_ID must be an integer, got: {raw_value}")
        return None
    return parsed_id
# Tensor shape shared by the add_kernel signature, the output tensor created
# in add_core, and the input/output tensors built by the test function.
SHAPE = (32, 32, 1, 256)
# Scalar added on top of input0 + input1 when add1_flag is true.
VAL = 1
# Helper called from add_kernel: computes input0 + input1 and, when add1_flag
# is true, additionally adds the module-level scalar VAL.
#
# Args:
#   input0, input1: the two operands to add.
#   add1_flag: selects the "+ VAL" variant; defaults to False.
# Returns:
#   A tensor created with pypto.tensor(SHAPE, pypto.DT_FP32) holding the result.
def add_core(input0: pypto.Tensor, input1: pypto.Tensor, add1_flag: bool = False):
    # NOTE(review): presumably sets the vector tile shape used for the ops
    # below (one entry per dimension of SHAPE) - confirm against PyPTO docs.
    pypto.set_vec_tile_shapes(1, 4, 1, 64)
    # The result tensor always uses the module-level SHAPE, not the input shapes.
    out = pypto.tensor(SHAPE, pypto.DT_FP32)
    if add1_flag:
        t3 = input0 + input1
        out[:] = t3 + VAL
    else:
        out[:] = input0 + input1
    return out
# JIT entry point: writes the result of add_core(input0, input1, add1_flag)
# into the caller-provided `out` tensor.
#
# The run mode is fixed when the decorator is evaluated (module import), from
# global_run_mode. All three tensor parameters are annotated with SHAPE and
# pypto.DT_FP32.
# Note that add1_flag defaults to True here, while add_core defaults it to False.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def add_kernel(
    input0: pypto.Tensor(SHAPE, pypto.DT_FP32),
    input1: pypto.Tensor(SHAPE, pypto.DT_FP32),
    out: pypto.Tensor(SHAPE, pypto.DT_FP32),
    add1_flag: bool = True):
    out[:] = add_core(input0, input1, add1_flag)
def test_add_scalar_loop_multi_jit(device_id=None) -> None:
    """Run add_kernel with add1_flag=False, then add1_flag=True, and check it.

    Random float inputs of SHAPE are created on ``npu:<device_id>`` when the
    module run mode is NPU and a device ID was given, otherwise on the CPU.
    For each flag value the torch reference (input0 + input1, plus VAL when
    the flag is set) is compared with the kernel output: the maximum absolute
    difference is printed, and in NPU mode the result is also asserted with
    rtol=atol=3e-3.
    """
    on_npu = global_run_mode == pypto.RunMode.NPU
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'
    else:
        device = 'cpu'

    input_data0 = torch.rand(SHAPE, dtype=torch.float, device=device)
    input_data1 = torch.rand(SHAPE, dtype=torch.float, device=device)
    print(f"Input0 shape: {input_data0.shape}")
    print(f"Input1 shape: {input_data1.shape}")

    # First pass: plain add. Second pass: add plus the scalar VAL.
    for with_scalar in (False, True):
        expected = torch.add(input_data0, input_data1)
        if with_scalar:
            expected = expected + VAL

        actual = torch.empty(SHAPE, dtype=torch.float, device=device)
        add_kernel(input_data0, input_data1, actual, with_scalar)

        abs_error = np.abs(actual.cpu().numpy() - expected.cpu().numpy())
        max_diff = abs_error.max()
        print(f"Output shape: {actual.shape}")
        print(f"Max difference: {max_diff:.6f}")
        # Numerical validation is only enforced in NPU mode.
        if on_npu:
            assert_allclose(
                np.array(actual.cpu()),
                np.array(expected.cpu()),
                rtol=3e-3,
                atol=3e-3,
            )

    print("✓ add_scalar_loop_multi_jit test passed")
    print()
def main():
    """Command-line entry point for the add_scalar_loop_multi_jit example.

    Usage:
        python add_scalar_loop_multi_jit.py                 # Run every example
        python add_scalar_loop_multi_jit.py EXAMPLE_ID      # Run one example
        python add_scalar_loop_multi_jit.py --list          # List available examples
        python add_scalar_loop_multi_jit.py --run_mode sim  # Select the run mode

    Behavior:
        * Exits with status 1 when EXAMPLE_ID is not a known example.
        * In NPU mode the device comes from get_device_id(); when that returns
          None the function returns without running anything.
        * An exception raised by an example is printed and re-raised.
    """
    parser = argparse.ArgumentParser(
        description="PyPTO add_scalar_loop_multi_jit Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s add_scalar_loop_multi_jit::test_add_scalar_loop_multi_jit
Run the add_scalar_loop_multi_jit::test_add_scalar_loop_multi_jit example
%(prog)s --list List all available examples
"""
    )
    parser.add_argument(
        'example_id',
        type=str,
        nargs='?',
        # IDs are the string keys of `examples` below, not numbers.
        help='Example ID to run (see --list). If not specified, all examples run.'
    )
    parser.add_argument(
        '--list',
        action='store_true',
        help='List all available examples and exit'
    )
    # Parsed for validation and --help only: the effective run mode was already
    # read from sys.argv at import time, because the JIT decorator needs it then.
    parser.add_argument(
        '--run_mode',
        type=str,
        nargs='?',
        default='npu',
        choices=["npu", "sim"],
        help='Run mode, supports npu and sim.'
    )
    args = parser.parse_args()

    examples = {
        "add_scalar_loop_multi_jit::test_add_scalar_loop_multi_jit": {
            'name': 'add_scalar_loop_multi_jit',
            'description': 'add_scalar_loop_multi_jit implementation',
            'function': test_add_scalar_loop_multi_jit
        }
    }

    if args.list:
        print("\n" + "=" * 60)
        print("Available Examples")
        print("=" * 60 + "\n")
        for ex_id, ex_info in sorted(examples.items()):
            print(f" ID: {ex_id}")
            print(f" name: {ex_info['name']}")
            print(f" description: {ex_info['description']}\n")
        return

    # Validate the requested ID and build the run list in a single step.
    if args.example_id is None:
        examples_to_run = list(examples.items())
    elif args.example_id in examples:
        examples_to_run = [(args.example_id, examples[args.example_id])]
    else:
        print(f"ERROR: Invalid example ID: {args.example_id}")
        print(f"Valid example IDs are: {', '.join(map(str, sorted(examples.keys())))}")
        print("\nUse --list to see all available examples.")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("PyPTO add_scalar_loop_multi_jit Example")
    print("=" * 60 + "\n")

    # Key the hardware setup off global_run_mode, the value the JIT decorator
    # and the test function actually use. argparse's view can differ from it:
    # a bare "--run_mode" parses to None and abbreviated flags such as
    # "--run_m sim" are accepted by argparse but not by the import-time scan.
    run_on_npu = global_run_mode == pypto.RunMode.NPU
    effective_mode = "npu" if run_on_npu else "sim"
    if args.run_mode != effective_mode:
        print(f"NOTE: --run_mode parsed as {args.run_mode!r}, but the kernel was "
              f"compiled for {effective_mode!r}; continuing in {effective_mode!r} mode.")
        print("Pass the option as '--run_mode <mode>' or '--run_mode=<mode>'.\n")

    device_id = None
    if run_on_npu:
        device_id = get_device_id()
        if device_id is None:
            return
        # Imported only here so that sim mode does not need torch_npu installed.
        # NOTE(review): presumably this import is what makes torch.npu available - confirm.
        import torch_npu
        torch.npu.set_device(device_id)
        print("Running examples that require NPU hardware...")
        print("(Make sure CANN environment is configured and NPU is available)\n")

    try:
        for ex_id, ex_info in examples_to_run:
            print(f"Running Example {ex_id}: {ex_info['name']}")
            ex_info['function'](device_id)
        if len(examples_to_run) > 1:
            print("=" * 60)
            print("All add_scalar_loop_multi_jit tests passed!")
            print("=" * 60)
    except Exception as e:
        # Top-level boundary: report the failure, then let it propagate.
        print(f"\nError: {e}")
        raise
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()