"""
kernel_unordered_input example for PyPTO.
This example demonstrates:
- Running an attention module with kernel_unordered_input.
- Running a simple element-wise add/mul op with kernel_unordered_input.
"""
import os
import sys
import argparse
from dataclasses import dataclass
from typing import Optional, Tuple
import pypto
import torch
import numpy as np
from numpy.testing import assert_allclose
def _peek_run_mode_from_argv(default: str = "npu") -> str:
"""Read run_mode early so module-level decorators can use it."""
for idx, arg in enumerate(sys.argv):
if arg == "--run_mode" and idx + 1 < len(sys.argv):
value = sys.argv[idx + 1]
if value in ("npu", "sim"):
return value
if arg.startswith("--run_mode="):
value = arg.split("=", 1)[1]
if value in ("npu", "sim"):
return value
return default
# Resolved once at import time: the @pypto.frontend.jit decorators in this
# module read this value when they are applied, before main() parses args.
global_run_mode = (
    pypto.RunMode.SIM
    if _peek_run_mode_from_argv("npu") == "sim"
    else pypto.RunMode.NPU
)
def get_device_id():
    """
    Read the NPU device index from the TILE_FWK_DEVICE_ID environment variable.
    Returns:
        int: The parsed device ID, or None when the variable is unset or is not
        an integer (a hint or an error message is printed in either case).
    """
    raw_value = os.environ.get('TILE_FWK_DEVICE_ID')
    if raw_value is None:
        print("Please set the environment variable TILE_FWK_DEVICE_ID before running:")
        print(" export TILE_FWK_DEVICE_ID=0")
        return None
    try:
        return int(raw_value)
    except ValueError:
        print(f"ERROR: TILE_FWK_DEVICE_ID must be an integer, got: {raw_value}")
        return None
@dataclass
class AttentionConfig:
    """Configuration for attention operations.

    Field order defines the positional constructor signature, so new fields
    should be appended at the end to keep existing callers working.
    """
    # Number of attention heads.
    num_heads: int = 8
    # Size of each head's feature dimension.
    head_dim: int = 64
    # Scaling factor applied to Q @ K^T; None lets the consumer choose a
    # default (the attention kernel in this example falls back to
    # 1/sqrt(head dim)).
    scale: Optional[float] = None
    # PyPTO data type; the attention kernel in this example passes it as the
    # matmul out_dtype.
    dtype: pypto.DataType = pypto.DT_FP32
    # Not read by any code visible in this example.
    # NOTE(review): confirm whether this flag is still needed.
    use_dynamic_shape: bool = False
def scaled_dot_product_attention_golden(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Eager PyTorch reference for scaled dot-product attention.

    Computes ``softmax(q @ k^T * scale [+ attn_mask]) @ v``, with the softmax
    taken over the last dimension. The mask, when given, is added to the
    scaled scores, so it must broadcast against them.

    Args:
        q: Query tensor; attention is taken over its last two dims.
        k: Key tensor; its last two dims are swapped before the matmul.
        v: Value tensor multiplied by the attention weights.
        scale: Multiplier applied to the raw scores.
        attn_mask: Optional additive mask (e.g. ``-inf`` to block positions).

    Returns:
        The attended values, ``attn_weights @ v``.
    """
    key_t = k.transpose(-2, -1)
    logits = torch.matmul(q, key_t) * scale
    logits = logits if attn_mask is None else logits + attn_mask
    probs = torch.softmax(logits, dim=-1)
    return torch.matmul(probs, v)
def scaled_dot_product_attention_core(q: pypto.Tensor, k: pypto.Tensor, v: pypto.Tensor,
                                      scale: float, dtype: pypto.DataType) -> pypto.Tensor:
    """Build ``softmax(q @ k^T * scale) @ v`` from PyPTO ops.

    ``k`` is transposed on axes 2 and 3, so the inputs are expected to be
    rank-4 (presumably (batch, heads, seq, head_dim) — TODO confirm).
    ``dtype`` is the output dtype of both matmuls. No mask is applied.
    """
    logits = pypto.matmul(q, pypto.transpose(k, 2, 3), out_dtype=dtype) * scale
    probs = pypto.softmax(logits, dim=-1)
    return pypto.matmul(probs, v, out_dtype=dtype)
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def scaled_dot_product_attention_kernel(
    q: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    k: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    v: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    output_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    config: AttentionConfig,
    tile: int):
    """Scaled dot-product attention with dynamic batch size.

    The batch axis (axis 0) is processed in chunks of ``tile`` rows. Each
    chunk views q/k/v as ``[tile, heads, seq_len, head_dim]`` starting at the
    chunk's batch offset. ``valid_shape`` marks how many rows really exist,
    since the last chunk may be short. The chunk's result is assembled into
    ``output_tensor`` at that same offset.

    Args:
        q: Query tensor, viewed as [tile, heads, q_len, head_dim] per chunk.
        k: Key tensor, viewed as [tile, heads, kv_len, head_dim] per chunk.
        v: Value tensor, viewed like ``k``.
        output_tensor: Destination written chunk-by-chunk via pypto.assemble.
        config: ``num_heads``/``head_dim`` size the views and must match the
            tensors. ``scale`` overrides the default 1/sqrt(head_dim).
            ``dtype`` is the matmul output dtype.
        tile: Number of batch rows processed per loop iteration.
    """
    cube_tiling = 64
    pypto.set_cube_tile_shapes(
        [cube_tiling, cube_tiling],
        [cube_tiling, cube_tiling],
        [cube_tiling, cube_tiling])
    bs = q.shape[0]
    # Head count and head size come from the config. They were previously
    # hard-coded to 8 / 64, which silently ignored config.num_heads and
    # config.head_dim; the AttentionConfig defaults reproduce the old values.
    head = config.num_heads
    dim = config.head_dim
    q_len = q.shape[2]
    kv_len = k.shape[2]
    scale = config.scale if config.scale is not None else (1.0 / (dim ** 0.5))
    # Ceil-divide so a trailing partial chunk is still processed.
    b_loop = (bs + tile - 1) // tile
    for bs_idx in pypto.loop(b_loop):
        b_offset = bs_idx * tile
        # NOTE(review): builtin min() on loop-dependent values — confirm PyPTO
        # handles this symbolically for the dynamic batch size.
        b_offset_end = min(b_offset + tile, bs)
        q_view = pypto.view(q, [tile, head, q_len, dim], [b_offset, 0, 0, 0],
                            valid_shape=[b_offset_end - b_offset, head, q_len, dim]
                            )
        k_view = pypto.view(k, [tile, head, kv_len, dim], [b_offset, 0, 0, 0],
                            valid_shape=[b_offset_end - b_offset, head, kv_len, dim]
                            )
        v_view = pypto.view(v, [tile, head, kv_len, dim], [b_offset, 0, 0, 0],
                            valid_shape=[b_offset_end - b_offset, head, kv_len, dim]
                            )
        # Vector tile shape is a tiling choice.
        # TODO confirm it suits configs other than 8 heads x 64 dims.
        pypto.set_vec_tile_shapes(1, 8, 16, 64)
        res = scaled_dot_product_attention_core(q_view, k_view, v_view, scale, config.dtype)
        pypto.assemble(res, [b_offset, 0, 0, 0], output_tensor)
def test_unordered_input_attention(device_id: int = None, dynamic: bool = True) -> None:
    """Test attention with kenel_unordered_input.

    Runs the PyPTO attention kernel on random float32 inputs of shape
    (8, 8, 64, 64) and, in NPU run mode, compares the result with the PyTorch
    golden implementation.

    Args:
        device_id: NPU index. Tensors go to ``npu:<device_id>`` only in NPU
            run mode with a device id, otherwise they stay on CPU.
        dynamic: Unused; kept so the example runner can call every test
            with the same signature.
    """
    print("=" * 60)
    print("Test: kenel_unordered_input Scaled Dot-Product Attention")
    print("=" * 60)
    device = f'npu:{device_id}' if global_run_mode == pypto.RunMode.NPU and device_id is not None else 'cpu'
    num_heads, head_dim = 8, 64
    batch_size, seq_len_q, seq_len_kv = 8, 64, 64
    dtype = torch.float32
    q_torch = torch.randn(batch_size, num_heads, seq_len_q, head_dim,
                          dtype=dtype, device=device)
    k_torch = torch.randn(batch_size, num_heads, seq_len_kv, head_dim,
                          dtype=dtype, device=device)
    v_torch = torch.randn(batch_size, num_heads, seq_len_kv, head_dim,
                          dtype=dtype, device=device)
    config = AttentionConfig(num_heads=num_heads, head_dim=head_dim,
                             dtype=pypto.DT_FP32, use_dynamic_shape=True)
    out_torch = torch.empty(batch_size, num_heads, seq_len_q, head_dim,
                            dtype=dtype, device=device)
    # tile == batch_size, so the kernel handles the whole batch in one chunk.
    scaled_dot_product_attention_kernel(q_torch, k_torch, v_torch, out_torch, config, batch_size)
    # Mirror the kernel's scale selection: config.scale when set, otherwise
    # 1/sqrt(head_dim).
    scale = config.scale if config.scale is not None else 1.0 / (head_dim ** 0.5)
    golden = scaled_dot_product_attention_golden(q_torch, k_torch, v_torch, scale)
    print(f"Batch={batch_size}, SeqQ={seq_len_q}, SeqKV={seq_len_kv}")
    print(f"Input shape: {q_torch.shape}")
    print(f"Output shape: {out_torch.shape}")
    # Numerical comparison is only done in NPU mode. SIM mode presumably does
    # not produce comparable output values — confirm.
    if global_run_mode == pypto.RunMode.NPU:
        assert_allclose(np.array(out_torch.cpu()), np.array(golden.cpu()), rtol=3e-3, atol=3e-3)
    print("✓ Attention (kenel_unordered_input) passed for the test case")
    print()
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def op_unordered_input_kernel(
    a: pypto.Tensor([], pypto.DT_FP32),
    b: pypto.Tensor([], pypto.DT_FP32),
    out1: pypto.Tensor([], pypto.DT_FP32),
    out2: pypto.Tensor([], pypto.DT_FP32)):
    """Element-wise kernel producing ``a + b`` in ``out1`` and ``a * b`` in ``out2``.

    The empty shape list in the ``pypto.Tensor`` annotations presumably leaves
    the shape unconstrained — TODO confirm against the PyPTO docs. The caller
    in this example passes 2-D float32 tensors.
    """
    # Vector tile shape for the 2-D element-wise ops below.
    pypto.set_vec_tile_shapes(16, 16)
    # Statement order is kept as-is; this is the traced kernel body.
    out1.move(a + b)
    out2.move(a * b)
def test_unordered_input_op(device_id: int = None, dynamic: bool = False) -> None:
    """Run the add/mul kernel on small random tensors and check both outputs.

    In NPU run mode each kernel output is compared with its PyTorch golden;
    in every mode the outputs and goldens are printed.

    Args:
        device_id: NPU index, used only in NPU run mode; otherwise tensors
            live on CPU.
        dynamic: Not used; kept so the runner can call every example uniformly.
    """
    banner = "=" * 60
    print(banner)
    print("Test: OP with kenel_unordered_input")
    print(banner)
    on_npu = global_run_mode == pypto.RunMode.NPU
    device = f'npu:{device_id}' if on_npu and device_id is not None else 'cpu'
    shape = (3, 2)
    tensor_kwargs = {"dtype": torch.float32, "device": device}
    lhs = torch.rand(shape, **tensor_kwargs)
    rhs = torch.rand(shape, **tensor_kwargs)
    sum_out = torch.empty(shape, **tensor_kwargs)
    prod_out = torch.empty(shape, **tensor_kwargs)
    op_unordered_input_kernel(lhs, rhs, sum_out, prod_out)
    # (kernel output, golden) pairs, both moved to CPU for comparison/printing.
    pairs = (
        (sum_out.cpu(), torch.add(lhs, rhs).cpu()),
        (prod_out.cpu(), torch.mul(lhs, rhs).cpu()),
    )
    if on_npu:
        for got, want in pairs:
            assert_allclose(np.array(got), np.array(want), rtol=1e-3, atol=1e-3)
    for idx, (got, want) in enumerate(pairs, start=1):
        print(f"Output{idx}: {got}")
        print(f"Expected{idx}: {want}")
    print("✓ OP with kenel_unordered_input passed for the test case")
    print()
def main():
    """Run the unordered-input kernel examples.
    Usage:
        python <script>.py                                               # Run all examples
        python <script>.py unordered_input_op::test_unordered_input_op   # Run one example
        python <script>.py --list                                        # List all available examples
        python <script>.py --run_mode sim                                # Run in simulator mode
    """
    parser = argparse.ArgumentParser(
        description="PyPTO Full Function Examples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s Run all examples
%(prog)s unordered_input_op::test_unordered_input_op
Run example unordered_input_op::test_unordered_input_op
%(prog)s --list List all available examples
"""
    )
    parser.add_argument(
        'example_id',
        type=str,
        nargs='?',
        help='Example ID to run (e.g. unordered_input_op::test_unordered_input_op; '
             'see --list). If not specified, all examples will run.'
    )
    parser.add_argument(
        '--list',
        action='store_true',
        help='List all available examples and exit'
    )
    # No nargs='?': a bare "--run_mode" used to parse as None. That skipped
    # NPU device setup below while the kernels (configured at import time)
    # still targeted the NPU.
    parser.add_argument(
        '--run_mode',
        type=str,
        default='npu',
        choices=["npu", "sim"],
        help='Run mode, supports npu and sim.'
    )
    args = parser.parse_args()
    # The JIT kernels were configured from sys.argv at import time, which can
    # disagree with argparse (abbreviated or repeated flags). Refuse to run
    # in an inconsistent state.
    kernel_run_mode = "sim" if global_run_mode == pypto.RunMode.SIM else "npu"
    if args.run_mode != kernel_run_mode:
        parser.error(
            f"--run_mode resolved to '{args.run_mode}' but the kernels were configured "
            f"for '{kernel_run_mode}'; pass it exactly once as '--run_mode <mode>' "
            f"or '--run_mode=<mode>'."
        )
    # NOTE(review): 'requires_npu' is not consulted anywhere below.
    examples = {
        'unordered_input_attention::test_unordered_input_attention': {
            'name': 'Test attention with kenel_unordered_input',
            'description': 'Attention with kenel_unordered_input example',
            'function': test_unordered_input_attention,
            'requires_npu': True
        },
        'unordered_input_op::test_unordered_input_op': {
            'name': 'Test op with kenel_unordered_input',
            'description': 'OP with kenel_unordered_input example',
            'function': test_unordered_input_op,
            'requires_npu': True
        }
    }
    if args.list:
        print("\n" + "=" * 60)
        print("Available Examples")
        print("=" * 60 + "\n")
        for ex_id, ex_info in sorted(examples.items()):
            print(f" ID: {ex_id}")
            print(f" name: {ex_info['name']}")
            print(f" description: {ex_info['description']}\n")
        return
    if args.example_id is not None:
        if args.example_id not in examples:
            print(f"ERROR: Invalid example ID: {args.example_id}")
            print(f"Valid example IDs are: {', '.join(map(str, sorted(examples.keys())))}")
            print("\nUse --list to see all available examples.")
            sys.exit(1)
    print("\n" + "=" * 60)
    print("PyPTO Dynamic Function Examples")
    print("=" * 60 + "\n")
    device_id = None
    examples_to_run = []
    if args.example_id is not None:
        examples_to_run = [(args.example_id, examples[args.example_id])]
    else:
        examples_to_run = list(examples.items())
    if args.run_mode == "npu":
        device_id = get_device_id()
        if device_id is None:
            return
        # Imported here so SIM runs don't need torch_npu installed; presumably
        # the import is what makes torch.npu available — confirm.
        import torch_npu  # noqa: F401
        torch.npu.set_device(device_id)
        print("Running examples that require NPU hardware...")
        print("(Make sure CANN environment is configured and NPU is available)\n")
    try:
        for ex_id, ex_info in examples_to_run:
            print(f"Running Example {ex_id}: {ex_info['name']}")
            ex_info['function'](device_id)
        if len(examples_to_run) > 1:
            print("=" * 60)
            print("All kenel_unordered_input tests passed!")
            print("=" * 60)
    except Exception as e:
        # Report, then re-raise so the process still exits with a failure.
        print(f"\nError: {e}")
        raise
# Script entry point: parse CLI arguments and run the selected examples.
if __name__ == "__main__":
    main()