"""
Tensor Creation Operation Examples for PyPTO
This file contains all tensor creation examples merged into a single file.
You can run all examples or select specific ones using command-line arguments.
Usage:
python creation_ops.py # Run all examples
python creation_ops.py --list # List all available examples
python creation_ops.py arange::test_arange_basic # Run a specific case
"""
import argparse
import os
import sys
import pypto
import torch
import numpy as np
from numpy.testing import assert_allclose
def _peek_run_mode_from_argv(default: str = "npu") -> str:
    """Scan ``sys.argv`` for a ``--run_mode`` value before argparse runs.

    The JIT decorators applied at module level need the run mode at import
    time, which is earlier than ``main()`` parses the command line.

    Both ``--run_mode VALUE`` and ``--run_mode=VALUE`` spellings are
    recognized. Occurrences carrying anything other than "npu" or "sim"
    are skipped and the scan continues.

    Args:
        default: Value returned when no valid run mode is found.

    Returns:
        The first valid run mode found on the command line, else ``default``.
    """
    valid_modes = ("npu", "sim")
    argv = sys.argv
    for position, token in enumerate(argv):
        if token == "--run_mode":
            # Space-separated form: the value, if any, is the next argument.
            has_next = position + 1 < len(argv)
            candidate = argv[position + 1] if has_next else None
        elif token.startswith("--run_mode="):
            # Inline form: everything after the first '=' is the value.
            candidate = token.split("=", 1)[1]
        else:
            continue
        if candidate in valid_modes:
            return candidate
    return default
# Run mode shared by every @pypto.frontend.jit decorator below. It has to be
# known at import time, so it is taken from an early look at sys.argv rather
# than from argparse; NPU is used unless "sim" was requested.
global_run_mode = (
    pypto.RunMode.SIM
    if _peek_run_mode_from_argv("npu") == "sim"
    else pypto.RunMode.NPU
)
def get_device_id():
    """Read the device ID from the ``TILE_FWK_DEVICE_ID`` environment variable.

    A hint (variable unset) or an error message (value is not an integer)
    is printed to stdout when the ID cannot be obtained.

    Returns:
        int: The parsed device ID, or None when the variable is unset or
        cannot be converted with ``int()``.
    """
    raw_value = os.environ.get('TILE_FWK_DEVICE_ID')
    if raw_value is None:
        print("Please set the environment variable TILE_FWK_DEVICE_ID before running:")
        print(" export TILE_FWK_DEVICE_ID=0")
        return None

    try:
        parsed = int(raw_value)
    except ValueError:
        print(f"ERROR: TILE_FWK_DEVICE_ID must be an integer, got: {raw_value}")
        return None
    return parsed
# Kernel wrapped by pypto.frontend.jit: writes pypto.arange(end) into `out`.
# `out` is annotated as a 4-element DT_INT32 tensor (the caller in this file
# passes a torch int32 tensor of length 4); `end` is forwarded unchanged to
# pypto.arange.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def arange_end_kernel(out: pypto.Tensor((4,), pypto.DT_INT32),
                      end,
                      ):
    # Tile shape is configured before the op is issued.
    # NOTE(review): presumably 8 is the vector tile length along the single
    # output dimension — confirm against the PyPTO docs.
    pypto.set_vec_tile_shapes(8)
    out.move(pypto.arange(end))
# Kernel wrapped by pypto.frontend.jit: writes pypto.arange(start, end) into
# `out`. `out` is annotated as a 3-element DT_FP32 tensor; `start` and `end`
# are forwarded unchanged to pypto.arange.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def arange_start_end_kernel(out: pypto.Tensor((3,), pypto.DT_FP32),
                            start,
                            end):
    # Tile shape is configured before the op is issued.
    # NOTE(review): presumably 8 is the vector tile length — confirm.
    pypto.set_vec_tile_shapes(8)
    out.move(pypto.arange(start, end))
# Kernel wrapped by pypto.frontend.jit: writes pypto.arange(start, end, step)
# into `out`. `out` is annotated as a 6-element DT_FP32 tensor; `start`,
# `end` and `step` are forwarded unchanged to pypto.arange.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def arange_start_end_step_kernel(out: pypto.Tensor((6,), pypto.DT_FP32),
                                 start,
                                 end,
                                 step):
    # Tile shape is configured before the op is issued.
    # NOTE(review): presumably 8 is the vector tile length — confirm.
    pypto.set_vec_tile_shapes(8)
    out.move(pypto.arange(start, end, step))
def test_arange_basic(device_id=None):
    """Run the three arange kernels and compare them with fixed references.

    Covers the ``arange(end)``, ``arange(start, end)`` and
    ``arange(start, end, step)`` forms. Outputs and references are always
    printed; the numeric comparison is only performed when the global run
    mode is NPU.

    Args:
        device_id: NPU device index. Tensors are placed on ``npu:<id>`` when
            running in NPU mode with a non-None ID, otherwise on the CPU.
    """
    separator = "=" * 60
    print(separator)
    print("Test: Basic Usage of arange Function")
    print(separator)

    on_npu = global_run_mode == pypto.RunMode.NPU
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'
    else:
        device = 'cpu'

    def _verify(actual, reference):
        # Results are compared only in NPU mode.
        if on_npu:
            assert_allclose(actual.cpu().numpy(), reference.cpu().numpy(), rtol=1e-3, atol=1e-3)

    # Case a: arange(end) with an int32 output.
    expected_a = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device=device)
    result_a = torch.empty(4, dtype=torch.int32, device=device)
    arange_end_kernel(result_a, end=4)
    print(f"Output a: {result_a}")
    print(f"Expected a: {expected_a}")
    _verify(result_a, expected_a)

    # Case b: arange(start, end) with a float32 output.
    expected_b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device)
    result_b = torch.empty(3, dtype=torch.float32, device=device)
    arange_start_end_kernel(result_b, start=1.0, end=4.0)
    print(f"Output b: {result_b}")
    print(f"Expected b: {expected_b}")
    _verify(result_b, expected_b)

    # Case c: arange(start, end, step) with a fractional step.
    expected_c = torch.tensor([1.0, 1.5, 2.0, 2.5, 3.0, 3.5], dtype=torch.float32, device=device)
    result_c = torch.empty(6, dtype=torch.float32, device=device)
    arange_start_end_step_kernel(result_c, start=1.0, end=4.0, step=0.5)
    print(f"Output c: {result_c}")
    print(f"Expected c: {expected_c}")
    _verify(result_c, expected_c)

    print("✓ Basic usage of arange function completed successfully")
def test_tensor_creation_with_datatypes(device_id=None):
    """Create one ``pypto.tensor`` per listed data type and print its info.

    Each tensor is created with shape ``[2, 3]`` and a name derived from the
    dtype's label; its ``name`` and ``dtype`` attributes are then printed.

    Args:
        device_id: Unused; accepted so every example shares one signature.
    """
    rule = "=" * 60
    print(rule)
    print("Test: Tensor Creation with Various Data Types")
    print(rule)

    # Label -> dtype, kept in the order the examples should be printed.
    dtypes_by_label = {
        "DT_INT4": pypto.DT_INT4,
        "DT_INT8": pypto.DT_INT8,
        "DT_INT16": pypto.DT_INT16,
        "DT_INT32": pypto.DT_INT32,
        "DT_INT64": pypto.DT_INT64,
        "DT_FP8": pypto.DT_FP8,
        "DT_FP16": pypto.DT_FP16,
        "DT_FP32": pypto.DT_FP32,
        "DT_BF16": pypto.DT_BF16,
        "DT_HF4": pypto.DT_HF4,
        "DT_HF8": pypto.DT_HF8,
        "DT_UINT8": pypto.DT_UINT8,
        "DT_UINT16": pypto.DT_UINT16,
        "DT_UINT32": pypto.DT_UINT32,
        "DT_UINT64": pypto.DT_UINT64,
        "DT_BOOL": pypto.DT_BOOL,
    }

    for dtype_name, element_type in dtypes_by_label.items():
        print(f"\nCreating tensor with data type: {dtype_name}")
        created = pypto.tensor([2, 3], element_type, f"tensor_{dtype_name}")
        print(f"Name: {created.name}")
        print(f"Data Type: {created.dtype}")

    print("✓ Tensor creation with various data types completed successfully")
# Kernel wrapped by pypto.frontend.jit: writes
# pypto.full((2, 2), fill_value, pypto.DT_FP32) into `out`, which is
# annotated as a 2x2 DT_FP32 tensor.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def full_float_kernel(out: pypto.Tensor((2, 2), pypto.DT_FP32),
                      fill_value):
    # Tile shape is configured before the op is issued.
    # NOTE(review): presumably (2, 8) is the per-dimension vector tile shape
    # for the 2-D output — confirm against the PyPTO docs.
    pypto.set_vec_tile_shapes(2, 8)
    out.move(pypto.full((2, 2), fill_value, pypto.DT_FP32))
# Kernel wrapped by pypto.frontend.jit: writes
# pypto.full((2, 2), fill_value, pypto.DT_INT32) into `out`, which is
# annotated as a 2x2 DT_INT32 tensor. The caller in this file passes a
# pypto.symbolic_scalar as `fill_value`.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def full_symbolic_scalar_kernel(out: pypto.Tensor((2, 2), pypto.DT_INT32),
                                fill_value):
    # Tile shape is configured before the op is issued.
    # NOTE(review): presumably (2, 8) is the per-dimension vector tile shape
    # for the 2-D output — confirm against the PyPTO docs.
    pypto.set_vec_tile_shapes(2, 8)
    out.move(pypto.full((2, 2), fill_value, pypto.DT_INT32))
def test_full_basic(device_id=None):
    """Run the two ``full`` kernels and compare them with fixed references.

    Case a fills a float32 2x2 tensor from a Python float; case b fills an
    int32 2x2 tensor from a ``pypto.symbolic_scalar``. Outputs and references
    are always printed; the numeric comparison is only performed when the
    global run mode is NPU.

    Args:
        device_id: NPU device index. Tensors are placed on ``npu:<id>`` when
            running in NPU mode with a non-None ID, otherwise on the CPU.
    """
    divider = "=" * 60
    print(divider)
    print("Test: Basic Usage of full Function")
    print(divider)

    npu_mode = global_run_mode == pypto.RunMode.NPU
    device = 'cpu'
    if npu_mode and device_id is not None:
        device = f'npu:{device_id}'

    def _compare(actual, reference):
        # Results are compared only in NPU mode.
        if not npu_mode:
            return
        assert_allclose(actual.cpu().numpy(), reference.cpu().numpy(), rtol=1e-3, atol=1e-3)

    # Case a: float fill value.
    expected_a = torch.tensor([[1.0, 1.0], [1.0, 1.0]], dtype=torch.float32, device=device)
    result_a = torch.empty((2, 2), dtype=torch.float32, device=device)
    full_float_kernel(result_a, fill_value=1.0)
    print(f"Output a: {result_a}")
    print(f"Expected a: {expected_a}")
    _compare(result_a, expected_a)

    # Case b: symbolic-scalar fill value.
    expected_b = torch.tensor([[1, 1], [1, 1]], dtype=torch.int32, device=device)
    result_b = torch.empty((2, 2), dtype=torch.int32, device=device)
    full_symbolic_scalar_kernel(result_b, fill_value=pypto.symbolic_scalar(1))
    print(f"Output b: {result_b}")
    print(f"Expected b: {expected_b}")
    _compare(result_b, expected_b)

    print("✓ Basic usage of full function completed successfully")
def test_basic_tensor_creation(device_id=None):
    """Create a ``[2, 3]`` DT_FP16 tensor, print its attributes, rename it.

    Prints the tensor's ``shape``, ``dtype``, ``dim``, ``format`` and
    ``name``, then assigns a new ``name`` and prints it again.

    Args:
        device_id: Unused; accepted so every example shares one signature.
    """
    banner = "=" * 60
    print(banner)
    print("Test: Basic Tensor Creation")
    print(banner)

    sample = pypto.tensor([2, 3], pypto.DT_FP16, "basic_tensor")

    # Read-only inspection of the freshly created tensor.
    print(f"Shape: {sample.shape}")
    print(f"Data Type: {sample.dtype}")
    print(f"Dimensions: {sample.dim}")
    print(f"Format: {sample.format}")
    print(f"Name: {sample.name}")

    # The name attribute is assignable after creation.
    sample.name = "new_name"
    print(f"New Name: {sample.name}")

    print("✓ Basic tensor creation completed successfully")
def test_tensor_creation_with_format(device_id=None):
    """Create a tensor with an explicit tile-op format and print its info.

    The tensor is ``[512, 32]`` DT_FP16, named "sparse_tensor", and created
    with ``pypto.TileOpFormat.TILEOP_NZ`` as the fourth argument.

    Args:
        device_id: Unused; accepted so every example shares one signature.
    """
    banner = "=" * 60
    print(banner)
    print("Test: Tensor Creation with Specific Format")
    print(banner)

    layout = pypto.TileOpFormat.TILEOP_NZ
    formatted = pypto.tensor([512, 32], pypto.DT_FP16, "sparse_tensor", layout)

    print(f"Shape: {formatted.shape}")
    print(f"Data Type: {formatted.dtype}")
    print(f"Dimensions: {formatted.dim}")
    print(f"Format: {formatted.format}")
    print(f"Name: {formatted.name}")

    print("✓ Tensor Creation with Specific Format completed successfully")
def main():
    """Run tensor creation operation examples from the command line.

    The optional positional ``example_id`` selects a single case; when it is
    omitted every registered case runs in sorted-ID order. ``--list`` prints
    the registry and returns without running anything. An unknown
    ``example_id`` prints the valid IDs and exits with status 1.

    In NPU mode the device ID is read from ``TILE_FWK_DEVICE_ID``; if it is
    missing or invalid the function returns without running any example.
    Any exception raised by an example is printed and re-raised.

    Usage:
        python creation_ops.py                             # Run all examples
        python creation_ops.py --list                      # List all available examples
        python creation_ops.py arange::test_arange_basic   # Run a specific case
    """
    parser = argparse.ArgumentParser(
        description="PyPTO Tensor Creation Operation Examples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s Run all examples
%(prog)s --list List all available examples
%(prog)s arange::test_arange_basic Run a specific case
"""
    )
    parser.add_argument(
        'example_id',
        type=str,
        nargs="?",
        help='Run a specific case (e.g., arange::test_arange_basic). If omitted, all cases run.'
    )
    parser.add_argument(
        '--list',
        action='store_true',
        help='List all available examples and exit'
    )
    parser.add_argument(
        '--run_mode',
        type=str,
        nargs='?',
        default='npu',
        # With nargs='?' a bare "--run_mode" yields `const`; without an
        # explicit const that is None, which matches neither choice.
        const='npu',
        choices=["npu", "sim"],
        help='Run mode, supports npu and sim.'
    )
    args = parser.parse_args()

    examples = {
        'arange::test_arange_basic': {
            'name': 'Test basic usage of arange function',
            'description': 'Basic usage of arange function example',
            'function': test_arange_basic,
        },
        'datatype::test_tensor_creation_with_datatypes': {
            'name': 'Test tensor creation with various data types',
            'description': 'Tensor creation with various data types example',
            'function': test_tensor_creation_with_datatypes,
        },
        'full::test_full_basic': {
            'name': 'Test basic usage of full function',
            'description': 'Basic usage of full function example',
            'function': test_full_basic,
        },
        'tensor::test_basic_tensor_creation': {
            'name': 'Test basic tensor creation',
            'description': 'Basic tensor creation example',
            'function': test_basic_tensor_creation,
        },
        'tensor::test_tensor_creation_with_format': {
            'name': 'Test tensor creation with specific format',
            'description': 'Tensor creation with specific format example',
            'function': test_tensor_creation_with_format,
        }
    }

    if args.list:
        print("\n" + "=" * 60)
        print("Available Examples")
        print("=" * 60 + "\n")
        for ex_id, ex_info in sorted(examples.items()):
            print(f" ID: {ex_id}")
            print(f" name: {ex_info['name']}")
            print(f" description: {ex_info['description']}\n")
        return

    device_id = None
    if args.example_id:
        if args.example_id not in examples:
            print(f"ERROR: Invalid case: {args.example_id}")
            print(f"Valid cases are: {', '.join(sorted(examples.keys()))}")
            print("\nUse --list to see all available examples.")
            sys.exit(1)
        examples_to_run = [(args.example_id, examples[args.example_id])]
    else:
        # sorted() over items() already yields a list of (id, info) pairs.
        examples_to_run = sorted(examples.items())

    print("\n" + "=" * 60)
    print("PyPTO Tensor Creation Operation Examples")
    print("=" * 60 + "\n")

    # Gate device setup on the run mode the kernels were actually decorated
    # with (global_run_mode, fixed at import time), not on args.run_mode.
    # The two are derived independently and can disagree (bare, abbreviated
    # or repeated --run_mode), which would leave NPU kernels running without
    # a selected device.
    if global_run_mode == pypto.RunMode.NPU:
        device_id = get_device_id()
        if device_id is None:
            return
        # Imported here so sim mode does not need torch_npu to be importable.
        # NOTE(review): presumably this import is what makes torch.npu
        # available — confirm against the torch_npu docs.
        import torch_npu
        torch.npu.set_device(device_id)
        print("Running examples that require NPU hardware...")
        print("(Make sure CANN environment is configured and NPU is available)\n")

    try:
        for ex_id, ex_info in examples_to_run:
            print(f"Running Example {ex_id}: {ex_info['name']}")
            ex_info['function'](device_id)
        # The summary banner is only shown for multi-example runs.
        if len(examples_to_run) > 1:
            print("=" * 60)
            print("All creation tests passed!")
            print("=" * 60)
    except Exception as e:
        print(f"\nError: {e}")
        raise
if __name__ == "__main__":
main()