"""
Dynamic Shape Examples for PyPTO
This file demonstrates the usage of PyPTO's dynamic shape feature, which allows
kernels to handle inputs with shapes that are not known at compile time.
Key Concepts:
1. Define dynamic dimensions using pypto.DYNAMIC
2. Only make necessary dimensions dynamic; keep others as concrete values or Ellipsis
3. Use pypto.view / pypto.assemble with pypto.loop for explicit tiling and
boundary management on dynamic dimensions
Examples included:
- dynamic_mul: Basic dynamic batch dimension with view/assemble tiling
- dynamic_partial: Partial dynamic dimensions (only batch is dynamic, others concrete)
- dynamic_attention: Multi-head attention with dynamic batch size
- dynamic_multi_dim: Multiple dynamic dimensions in a single kernel
Usage:
python dynamic.py # Run all examples
python dynamic.py --list # List all available examples
python dynamic.py dynamic_mul::test_dynamic_mul # Run a specific case
"""
import os
import sys
import argparse
from dataclasses import dataclass
from typing import Optional
import pypto
import torch
import numpy as np
from numpy.testing import assert_allclose
def _peek_run_mode_from_argv(default: str = "npu", argv: Optional[list] = None) -> str:
    """Read ``--run_mode`` from the command line before argparse runs.

    Module-level decorators need the run mode at import time, i.e. before
    ``main()`` has parsed anything, so the raw argument list is scanned here.

    Both ``--run_mode VALUE`` and ``--run_mode=VALUE`` are recognised. When
    the option is given more than once, the LAST valid occurrence wins, which
    is what argparse's default "store" action does as well; returning the
    first occurrence would let the decorators and ``main()`` disagree.

    Args:
        default: Value returned when no valid ``--run_mode`` is present.
        argv: Argument list to scan. Defaults to ``sys.argv``.

    Returns:
        ``"npu"`` or ``"sim"`` taken from the arguments, otherwise ``default``.
    """
    tokens = sys.argv if argv is None else list(argv)
    valid_modes = ("npu", "sim")
    selected = default
    for idx, token in enumerate(tokens):
        if token == "--run_mode":
            # Space-separated form: the value is the following token, if any.
            if idx + 1 < len(tokens) and tokens[idx + 1] in valid_modes:
                selected = tokens[idx + 1]
        elif token.startswith("--run_mode="):
            candidate = token.split("=", 1)[1]
            if candidate in valid_modes:
                selected = candidate
    return selected
# Resolved once at import time: the @pypto.frontend.jit decorators below read
# this value when the module is loaded, before main() parses any arguments.
global_run_mode = (
    pypto.RunMode.SIM
    if _peek_run_mode_from_argv("npu") == "sim"
    else pypto.RunMode.NPU
)
def get_device_id():
    """
    Read the NPU device ID from the TILE_FWK_DEVICE_ID environment variable.

    A hint is printed when the variable is missing, and an error message is
    printed when its value cannot be converted with ``int()``.

    Returns:
        int: The parsed device ID, or None when the variable is unset or is
        not an integer.
    """
    raw_value = os.environ.get('TILE_FWK_DEVICE_ID')
    if raw_value is None:
        print("Please set the environment variable TILE_FWK_DEVICE_ID before running:")
        print(" export TILE_FWK_DEVICE_ID=0")
        return None

    try:
        parsed = int(raw_value)
    except ValueError:
        print(f"ERROR: TILE_FWK_DEVICE_ID must be an integer, got: {raw_value}")
        return None
    return parsed
# Kernel: multiplies `x` by 2.0 and writes the product into `output`, walking
# the dynamic first dimension in chunks of `tile_b` rows.
#   x, output: FP16 tensors declared as [DYNAMIC, 128].
#   tile_b:    number of rows covered by each loop iteration.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def dynamic_mul_kernel(
    x: pypto.Tensor([pypto.DYNAMIC, 128], pypto.DT_FP16),
    output: pypto.Tensor([pypto.DYNAMIC, 128], pypto.DT_FP16),
    tile_b: int):
    # Size of the dynamic dimension, taken from the input's shape.
    batch_size_dyn = x.shape[0]
    # Ceiling division: number of tiles needed to cover every row.
    b_loop = (batch_size_dyn + tile_b - 1) // tile_b
    for idx in pypto.loop(b_loop):
        b_offset = idx * tile_b
        # Clamp the tile end to the batch size, so the last tile may be partial.
        b_offset_end = min(b_offset + tile_b, batch_size_dyn)
        # Rows actually inside the tensor for this tile; the column extent is fixed at 128.
        valid_shape = [b_offset_end - b_offset, 128]
        # Window of nominal shape [tile_b, 128] starting at row b_offset.
        # NOTE(review): valid_shape presumably marks the in-bounds part of the
        # window -- confirm against the pypto.view documentation.
        x_view = pypto.view(x, [tile_b, 128], [b_offset, 0],
                            valid_shape=valid_shape)
        # NOTE(review): presumably sets the vector tile shape applied to the
        # following op -- confirm against the PyPTO documentation.
        pypto.set_vec_tile_shapes(1, 128)
        result = pypto.mul(x_view, 2.0)
        # Place the tile result into `output` at the offset it was read from.
        pypto.assemble(result, [b_offset, 0], output)
def test_dynamic_mul(device_id: int = None):
    """Run dynamic_mul_kernel for several batch sizes and check it against ``x * 2.0``.

    Tensors are placed on ``npu:<device_id>`` only when the run mode is NPU
    and a device ID was supplied; otherwise they live on the CPU. The
    numerical comparison is performed in NPU mode only.
    """
    separator = "=" * 60
    print(separator)
    print("Test: Dynamic Mul (basic view/assemble tiling)")
    print(separator)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = 'cpu'
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'

    for bs in (8, 16):
        x = torch.randn(bs, 128, dtype=torch.float16, device=device)
        result = torch.zeros(bs, 128, dtype=torch.float16, device=device)

        # The batch size itself is passed as the tile size.
        dynamic_mul_kernel(x, result, bs)
        if on_npu:
            torch.npu.synchronize()

        golden = x * 2.0
        if on_npu:
            assert_allclose(
                result.cpu().numpy(),
                golden.cpu().numpy(),
                rtol=1e-3,
                atol=1e-3,
            )
        print(f" batch_size={bs}: Input shape {x.shape} -> Output shape {result.shape}")

    print("✓ Dynamic mul passed for all batch sizes")
    print()
def softmax_core(input_tensor: pypto.Tensor) -> pypto.Tensor:
    """Apply ``pypto.softmax`` over the final axis of ``input_tensor``."""
    normalized = pypto.softmax(input_tensor, dim=-1)
    return normalized
# Kernel: softmax over the last axis, processed one leading-dimension entry
# per loop iteration. Only the first dimension is declared DYNAMIC; the rest
# are left as Ellipsis in the annotation.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def softmax_kernel(
    input_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    output_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32)):
    # One entry of the leading dimension per iteration.
    tile_b = 1
    # The four-way unpack means the input is expected to be 4-D.
    bs_dyn, seqlen, head, dim = input_tensor.shape
    # Floor division; exact here because tile_b is 1.
    b_loop = bs_dyn // tile_b
    # NOTE(review): presumably the vector tile shape for the ops below --
    # confirm against the PyPTO documentation.
    pypto.set_vec_tile_shapes(1, 4, 1, 64)
    # NOTE(review): positional arguments look like (start, stop, step) -- confirm.
    for idx in pypto.loop(0, b_loop, 1, name="LOOP_L0_bIdx", idx_name="idx"):
        b_offset = idx * tile_b
        b_offset_end = (idx + 1) * tile_b
        # Slice out rows [b_offset, b_offset_end) with the full extent of the other axes.
        input_view = input_tensor[b_offset:b_offset_end, :seqlen, :head, :dim]
        softmax_out = softmax_core(input_view)
        # NOTE(review): the destination slice is open-ended (`b_offset:`) rather
        # than `b_offset:b_offset_end`; presumably the tile is written at that
        # offset -- confirm.
        output_tensor[b_offset:, ...] = softmax_out
def test_dynamic_partial(device_id: int = None):
    """Run softmax_kernel for several batch sizes with fixed seqlen/head/dim.

    The output is compared with ``torch.softmax(x, dim=-1)`` in NPU mode only.
    Tensors live on ``npu:<device_id>`` when the run mode is NPU and a device
    ID was supplied, otherwise on the CPU.
    """
    separator = "=" * 60
    print(separator)
    print("Test: Partial Dynamic Dimensions (softmax)")
    print(separator)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = 'cpu'
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'

    seqlen, head, dim = 32, 1, 256
    for bs in (8, 32):
        shape = (bs, seqlen, head, dim)
        x = torch.rand(shape, dtype=torch.float32, device=device)
        y = torch.zeros(shape, dtype=torch.float32, device=device)

        softmax_kernel(x, y)
        if on_npu:
            torch.npu.synchronize()

        golden = torch.softmax(x, dim=-1)
        if on_npu:
            assert_allclose(
                y.cpu().numpy(),
                golden.cpu().numpy(),
                rtol=1e-3,
                atol=1e-3,
            )
        print(f" batch_size={bs}: Input shape {x.shape} -> Output shape {y.shape}")

    print("✓ Partial dynamic softmax passed for all batch sizes")
    print()
@dataclass
class AttentionConfig:
    """Configuration for attention operations."""
    # Number of attention heads; attention_kernel uses it as the head extent of its views.
    num_heads: int = 8
    # Size of each head; attention_kernel also derives the default scale from it.
    head_dim: int = 64
    # Multiplier applied to the attention scores. When None, attention_kernel
    # falls back to 1 / sqrt(head_dim).
    scale: Optional[float] = None
    # Passed as out_dtype to the matmuls in scaled_dot_product_attention_core.
    dtype: pypto.DataType = pypto.DT_FP32
def scaled_dot_product_attention_golden(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    scale: float, attn_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Reference scaled dot-product attention written with plain PyTorch ops.

    Computes ``softmax(q @ k^T * scale [+ attn_mask]) @ v``. The transpose
    swaps the last two axes of ``k`` and the softmax runs over the last axis.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        scale: Factor the raw scores are multiplied by.
        attn_mask: Optional tensor added to the scaled scores before softmax.

    Returns:
        The attention output tensor.
    """
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        logits = logits + attn_mask
    return torch.matmul(torch.softmax(logits, dim=-1), v)
def scaled_dot_product_attention_core(
    q: pypto.Tensor, k: pypto.Tensor, v: pypto.Tensor,
    scale: float, dtype: pypto.DataType
) -> pypto.Tensor:
    """Build the attention computation from PyPTO ops.

    Transposes axes 2 and 3 of ``k``, multiplies ``q`` by the result, scales
    the scores, applies softmax over the last axis and multiplies by ``v``.
    ``dtype`` is passed to both matmuls as ``out_dtype``.
    """
    keys_transposed = pypto.transpose(k, 2, 3)
    raw_scores = pypto.matmul(q, keys_transposed, out_dtype=dtype)
    probabilities = pypto.softmax(raw_scores * scale, dim=-1)
    return pypto.matmul(probabilities, v, out_dtype=dtype)
# Kernel: scaled dot-product attention tiled along the dynamic first dimension.
#   q, k, v, output_tensor: FP32 tensors whose first dimension is DYNAMIC.
#   config: supplies the head count, head size, optional scale and matmul dtype.
#   tile:   number of leading-dimension entries covered per loop iteration.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def attention_kernel(
    q: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    k: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    v: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    output_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32),
    config: AttentionConfig,
    tile: int):
    """Scaled dot-product attention with dynamic batch size."""
    bs_dyn = q.shape[0]
    # Head count and head size come from the config, not from the tensor shapes.
    head = config.num_heads
    dim = config.head_dim
    # Axis 2 carries the sequence length; the views below use the layout
    # [batch, head, seq, dim].
    q_len = q.shape[2]
    kv_len = k.shape[2]
    # Default scale is 1 / sqrt(dim) unless the config overrides it.
    scale = config.scale if config.scale is not None else (1.0 / (dim ** 0.5))
    cube_tiling = 64
    # NOTE(review): presumably the tile shapes used by the matmuls below --
    # confirm the meaning of the three pairs against the PyPTO documentation.
    pypto.set_cube_tile_shapes(
        [cube_tiling, cube_tiling],
        [cube_tiling, cube_tiling],
        [cube_tiling, cube_tiling])
    # Ceiling division: number of tiles needed to cover the dynamic dimension.
    bs_loop = (bs_dyn + tile - 1) // tile
    for bss_idx in pypto.loop(bs_loop):
        bs_offset = bss_idx * tile
        # Clamp to the real size so the final tile may be partial.
        bs_offset_end = min(bs_offset + tile, bs_dyn)
        # Each view has a nominal extent of `tile` on the first axis and a
        # valid extent of (bs_offset_end - bs_offset).
        q_view = pypto.view(
            q, [tile, head, q_len, dim], [bs_offset, 0, 0, 0],
            valid_shape=[bs_offset_end - bs_offset, head, q_len, dim]
        )
        k_view = pypto.view(
            k, [tile, head, kv_len, dim], [bs_offset, 0, 0, 0],
            valid_shape=[bs_offset_end - bs_offset, head, kv_len, dim]
        )
        v_view = pypto.view(
            v, [tile, head, kv_len, dim], [bs_offset, 0, 0, 0],
            valid_shape=[bs_offset_end - bs_offset, head, kv_len, dim]
        )
        # NOTE(review): presumably the vector tile shape for the element-wise
        # and softmax ops in the core -- confirm.
        pypto.set_vec_tile_shapes(1, 8, 16, 64)
        res = scaled_dot_product_attention_core(
            q_view, k_view, v_view, scale, config.dtype
        )
        # Write the tile result into the output at the offset it was read from.
        pypto.assemble(res, [bs_offset, 0, 0, 0], output_tensor)
def test_dynamic_attention(device_id: int = None):
    """Run attention_kernel on several (batch, seq_q, seq_kv) combinations.

    In NPU mode the kernel output is compared with the PyTorch reference
    ``scaled_dot_product_attention_golden`` and the maximum absolute
    difference is printed. Tensors live on ``npu:<device_id>`` when the run
    mode is NPU and a device ID was supplied, otherwise on the CPU.
    """
    separator = "=" * 60
    print(separator)
    print("Test: Dynamic Scaled Dot-Product Attention")
    print(separator)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = 'cpu'
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'

    num_heads, head_dim = 8, 64
    config = AttentionConfig(
        num_heads=num_heads, head_dim=head_dim, dtype=pypto.DT_FP32
    )
    dtype = torch.float32
    scale = 1.0 / (head_dim ** 0.5)

    test_cases = (
        (2, 16, 16),
        (4, 32, 32),
        (8, 64, 64),
    )
    for batch_size, seq_len_q, seq_len_kv in test_cases:
        q_shape = (batch_size, num_heads, seq_len_q, head_dim)
        kv_shape = (batch_size, num_heads, seq_len_kv, head_dim)
        q = torch.randn(q_shape, dtype=dtype, device=device)
        k = torch.randn(kv_shape, dtype=dtype, device=device)
        v = torch.randn(kv_shape, dtype=dtype, device=device)
        out = torch.empty(q_shape, dtype=dtype, device=device)

        # The whole batch is passed as the tile size.
        attention_kernel(q, k, v, out, config, batch_size)
        if on_npu:
            torch.npu.synchronize()

        golden = scaled_dot_product_attention_golden(q, k, v, scale).cpu()
        if on_npu:
            out_cpu = out.cpu()
            max_diff = (out_cpu - golden).abs().max().item()
            print(f" Batch={batch_size}, SeqQ={seq_len_q}, SeqKV={seq_len_kv}, "
                  f"Max diff: {max_diff:.6f}")
            assert_allclose(out_cpu.numpy(), golden.numpy(), rtol=3e-3, atol=3e-3)
        print(f" Input shape: {q.shape} -> Output shape: {out.shape}")

    print("✓ Dynamic attention passed for all test cases")
    print()
# Kernel: element-wise x + y written into `output`, tiled over two dynamic
# dimensions with a nested loop (rows outside, columns inside).
#   x, y, output: FP16 tensors declared as [DYNAMIC, DYNAMIC].
#   tile_b, tile_h: tile extents along the first and second dimension.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def dynamic_add_kernel(
    x: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP16),
    y: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP16),
    output: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC], pypto.DT_FP16),
    tile_b: int,
    tile_h: int):
    # Both extents are read from x; y and output are indexed with the same offsets.
    batch_dyn = x.shape[0]
    hidden_dyn = x.shape[1]
    # Ceiling division: number of row tiles.
    b_loop = (batch_dyn + tile_b - 1) // tile_b
    for b_idx in pypto.loop(b_loop):
        b_offset = b_idx * tile_b
        # Clamp so the last row tile may be partial.
        b_offset_end = min(b_offset + tile_b, batch_dyn)
        valid_b = b_offset_end - b_offset
        # Ceiling division: number of column tiles.
        h_loop = (hidden_dyn + tile_h - 1) // tile_h
        for h_idx in pypto.loop(h_loop):
            h_offset = h_idx * tile_h
            # Clamp so the last column tile may be partial.
            h_offset_end = min(h_offset + tile_h, hidden_dyn)
            valid_h = h_offset_end - h_offset
            # Matching windows into x and y: nominal shape [tile_b, tile_h],
            # valid extent [valid_b, valid_h].
            x_view = pypto.view(
                x, [tile_b, tile_h], [b_offset, h_offset],
                valid_shape=[valid_b, valid_h]
            )
            y_view = pypto.view(
                y, [tile_b, tile_h], [b_offset, h_offset],
                valid_shape=[valid_b, valid_h]
            )
            # NOTE(review): presumably the vector tile shape for the add below -- confirm.
            pypto.set_vec_tile_shapes(tile_b, tile_h)
            result = pypto.add(x_view, y_view)
            # Write the tile back at the offsets it was read from.
            pypto.assemble(result, [b_offset, h_offset], output)
def test_dynamic_multi_dim(device_id: int = None):
    """Run dynamic_add_kernel on inputs whose two dimensions both vary.

    The output is compared with ``x + y`` in NPU mode only. Tensors live on
    ``npu:<device_id>`` when the run mode is NPU and a device ID was
    supplied, otherwise on the CPU.
    """
    separator = "=" * 60
    print(separator)
    print("Test: Multiple Dynamic Dimensions (add)")
    print(separator)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = 'cpu'
    if on_npu and device_id is not None:
        device = f'npu:{device_id}'

    test_cases = (
        (8, 64),
        (16, 128),
    )
    for bs, hs in test_cases:
        x = torch.randn(bs, hs, dtype=torch.float16, device=device)
        y = torch.randn(bs, hs, dtype=torch.float16, device=device)
        result = torch.zeros(bs, hs, dtype=torch.float16, device=device)

        # The full extents are passed as the tile sizes.
        dynamic_add_kernel(x, y, result, bs, hs)
        if on_npu:
            torch.npu.synchronize()

        golden = x + y
        if on_npu:
            assert_allclose(
                result.cpu().numpy(),
                golden.cpu().numpy(),
                rtol=1e-3,
                atol=1e-3,
            )
        print(f" batch={bs}, hidden={hs}: "
              f"Input shapes {x.shape}, {y.shape} -> Output shape {result.shape}")

    print("✓ Multiple dynamic dimensions passed for all test cases")
    print()
def main():
    """Run dynamic shape examples.

    Parses the command line, optionally lists the available examples, and
    otherwise runs either the requested example or all of them.

    The run mode that actually takes effect is the module-level
    ``global_run_mode``: the JIT decorators and the test functions read it,
    and it is fixed at import time. NPU device setup is therefore keyed on
    ``global_run_mode`` rather than on ``args.run_mode``, and a warning is
    printed when the two disagree (for example when ``--run_mode`` is given
    without a value, which argparse turns into ``None``).

    Exits with status 1 when the example ID is unknown or when NPU mode is
    active and no valid ``TILE_FWK_DEVICE_ID`` is available. Exceptions raised
    by an example are reported and re-raised.

    Usage:
        python dynamic.py                                   # Run all examples
        python dynamic.py dynamic_mul::test_dynamic_mul     # Run a specific case
        python dynamic.py --list                            # List all available examples
    """
    parser = argparse.ArgumentParser(
        description="PyPTO Dynamic Shape Examples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s Run all examples
%(prog)s dynamic_mul::test_dynamic_mul Run a specific case
%(prog)s --list List all available examples
"""
    )
    parser.add_argument(
        'example_id',
        type=str,
        nargs='?',
        help='Example ID to run. If not specified, all examples will run.'
    )
    parser.add_argument(
        '--list',
        action='store_true',
        help='List all available examples and exit'
    )
    parser.add_argument(
        '--run_mode',
        type=str,
        nargs='?',
        default='npu',
        choices=["npu", "sim"],
        help='Run mode, supports npu and sim.'
    )
    args = parser.parse_args()

    examples = {
        'dynamic_mul::test_dynamic_mul': {
            'name': 'Basic dynamic batch mul',
            'description': 'Demonstrates basic view/assemble tiling with a single dynamic batch dimension',
            'function': test_dynamic_mul,
        },
        'dynamic_partial::test_dynamic_partial': {
            'name': 'Partial dynamic dimensions (softmax)',
            'description': 'Only batch is dynamic; seqlen, head, dim stay concrete',
            'function': test_dynamic_partial,
        },
        'dynamic_attention::test_dynamic_attention': {
            'name': 'Dynamic attention',
            'description': 'Scaled dot-product attention with dynamic batch size, view/assemble on 4D tensors',
            'function': test_dynamic_attention,
        },
        'dynamic_multi_dim::test_dynamic_multi_dim': {
            'name': 'Multiple dynamic dimensions',
            'description': 'Both batch and hidden dimensions are dynamic, nested view/assemble loops',
            'function': test_dynamic_multi_dim,
        },
    }

    if args.list:
        print("\n" + "=" * 60)
        print("Available Examples")
        print("=" * 60 + "\n")
        for ex_id, ex_info in sorted(examples.items()):
            print(f" ID: {ex_id}")
            print(f" name: {ex_info['name']}")
            print(f" description: {ex_info['description']}\n")
        return

    if args.example_id is not None and args.example_id not in examples:
        print(f"ERROR: Invalid example ID: {args.example_id}")
        print(f"Valid example IDs are: {', '.join(sorted(examples.keys()))}")
        print("\nUse --list to see all available examples.")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("PyPTO Dynamic Shape Examples")
    print("=" * 60 + "\n")

    if args.example_id is not None:
        examples_to_run = [(args.example_id, examples[args.example_id])]
    else:
        examples_to_run = list(examples.items())

    # The kernels were decorated with global_run_mode when this module was
    # imported and the tests branch on it too, so it -- not args.run_mode --
    # decides whether an NPU device is needed.
    run_on_npu = global_run_mode == pypto.RunMode.NPU
    effective_mode = "npu" if run_on_npu else "sim"
    if args.run_mode != effective_mode:
        print(f"WARNING: --run_mode was parsed as {args.run_mode!r}, but the kernels "
              f"were set up for '{effective_mode}'; continuing in '{effective_mode}' mode.")

    device_id = None
    if run_on_npu:
        device_id = get_device_id()
        if device_id is None:
            # get_device_id() has already explained the problem; report the
            # failure through the exit status like the invalid-ID path above.
            sys.exit(1)
        # Imported for its side effect; torch.npu is used on the next line.
        import torch_npu  # noqa: F401
        torch.npu.set_device(device_id)
        print("Running examples that require NPU hardware...")
        print("(Make sure CANN environment is configured and NPU is available)\n")

    try:
        for ex_id, ex_info in examples_to_run:
            print(f"Running Example {ex_id}: {ex_info['name']}")
            ex_info['function'](device_id)
        if len(examples_to_run) > 1:
            print("=" * 60)
            print("All dynamic shape tests passed!")
            print("=" * 60)
    except Exception as e:
        # Print a short summary, then re-raise so the traceback and a
        # non-zero exit status are preserved.
        print(f"\nError: {e}")
        raise
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()