"""
Matrix Multiplication (matmul) Operation Examples for PyPTO
This file contains all matrix multiplication examples merged into a single file.
You can run all examples or select specific ones using command-line arguments.
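Usage:
    python matmul_ops.py                            # Run all examples
    python matmul_ops.py --list                     # List all available examples
    python matmul_ops.py matmul::test_matmul_basic  # Run a specific case
    python matmul_ops.py --run_mode sim             # Run all examples in simulation mode
NPU mode (the default) requires TILE_FWK_DEVICE_ID to be set to the device index.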
"""
import argparse
import os
import sys
import pypto
import torch
import numpy as np
from numpy.testing import assert_allclose
def _peek_run_mode_from_argv(default: str = "npu") -> str:
"""Read run_mode early so module-level decorators can use it."""
for idx, arg in enumerate(sys.argv):
if arg == "--run_mode" and idx + 1 < len(sys.argv):
value = sys.argv[idx + 1]
if value in ("npu", "sim"):
return value
if arg.startswith("--run_mode="):
value = arg.split("=", 1)[1]
if value in ("npu", "sim"):
return value
return default
# Resolved at import time: the @pypto.frontend.jit decorators below read this
# value when their functions are defined, i.e. before main() parses arguments.
global_run_mode = (
    pypto.RunMode.SIM
    if _peek_run_mode_from_argv("npu") == "sim"
    else pypto.RunMode.NPU
)
def get_device_id():
    """
    Read the NPU device index from the TILE_FWK_DEVICE_ID environment variable.

    A hint or error message is printed when the variable cannot be used.

    Returns:
        int: The parsed device index, or None when the variable is unset or
        does not parse as an integer.
    """
    raw_value = os.environ.get('TILE_FWK_DEVICE_ID')
    if raw_value is None:
        print("Please set the environment variable TILE_FWK_DEVICE_ID before running:")
        print(" export TILE_FWK_DEVICE_ID=0")
        return None
    try:
        return int(raw_value)
    except ValueError:
        print(f"ERROR: TILE_FWK_DEVICE_ID must be an integer, got: {raw_value}")
        return None
# JIT kernel computing out = a @ b. It is used by the basic and
# transposition examples with 2-D inputs. The run mode is fixed when the
# decorator is evaluated at import time (global_run_mode).
# NOTE(review): the empty shape list in pypto.Tensor([], ...) presumably leaves
# the shape unspecified so it is taken from the runtime argument - confirm
# against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # NOTE(review): the three tile-shape pairs are assumed to configure the
    # cube-unit tiling (likely M, K, N) - confirm against pypto docs.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # Third argument presumably selects the FP32 result dtype; the product is
    # written in place into the caller-provided `out` tensor.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32)
def test_matmul_basic(device_id: int = None):
    """Test basic matrix multiplication.

    Multiplies two hard-coded 2x2 matrices with matmul_kernel. In NPU mode
    the result is compared against a hand-computed product. In SIM mode the
    tensors are only printed.

    Args:
        device_id: NPU index used to place the tensors in NPU mode. When it
            is None, or the run mode is SIM, the tensors are created on CPU.
    """
    banner = "=" * 60
    print(banner)
    print("Test: Basic Matrix Multiplication")
    print(banner)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = f'npu:{device_id}' if on_npu and device_id is not None else 'cpu'
    opts = {"dtype": torch.float32, "device": device}

    lhs = torch.tensor([[1, 2], [3, 4]], **opts)
    rhs = torch.tensor([[5, 6], [7, 8]], **opts)
    reference = torch.tensor([[19, 22], [43, 50]], **opts)
    result = torch.empty((lhs.shape[0], rhs.shape[1]), **opts)

    matmul_kernel(lhs, rhs, result)

    if on_npu:
        assert_allclose(result.cpu().numpy(), reference.cpu().numpy(), rtol=1e-3, atol=1e-3)
    print(f"Output: {result}")
    print(f"Expected: {reference}")
    print("✓ Basic matrix multiplication completed successfully")
# JIT kernel computing a batched out = a @ b. It is used by the batch example
# with 3-D (batch, M, K) x (batch, K, N) inputs. The body is identical to
# matmul_kernel; pypto.matmul presumably handles the leading batch dim itself
# - confirm.
# NOTE(review): pypto.Tensor([], ...) presumably means "shape taken from the
# runtime argument" - confirm against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_batch_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # NOTE(review): tile-shape arguments assumed to be cube-unit tiling - confirm.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # Write the FP32 product into the caller-provided `out` tensor in place.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32)
def test_matmul_batch(device_id: int = None):
    """Test batch matrix multiplication.

    Multiplies two stacks of 2x2 matrices batch-by-batch with
    matmul_batch_kernel. In NPU mode the result is compared against
    hand-computed products. In SIM mode the tensors are only printed.

    Args:
        device_id: NPU index used to place the tensors in NPU mode. When it
            is None, or the run mode is SIM, the tensors are created on CPU.
    """
    print("=" * 60)
    print("Test: Batch Matrix Multiplication")
    print("=" * 60)
    device = f'npu:{device_id}' if global_run_mode == pypto.RunMode.NPU and device_id is not None else 'cpu'
    dtype = torch.float32
    a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype, device=device)
    b = torch.tensor([[[5, 6], [7, 8]], [[1, 2], [3, 4]]], dtype=dtype, device=device)
    expected = torch.tensor([[[19, 22], [43, 50]], [[23, 34], [31, 46]]], dtype=dtype, device=device)
    # The output is (batch, M, N): keep a's leading dims and take N from b.
    # Reusing a.shape only worked because these matrices happen to be square.
    out = torch.empty((*a.shape[:-1], b.shape[-1]), dtype=dtype, device=device)
    matmul_batch_kernel(a, b, out)
    if global_run_mode == pypto.RunMode.NPU:
        assert_allclose(out.cpu().numpy(), expected.cpu().numpy(), rtol=1e-3, atol=1e-3)
    print(f"Output: {out}")
    print(f"Expected: {expected}")
    print("✓ Batch matrix multiplication completed successfully")
# JIT kernel computing out = a @ b with batch broadcasting. It is used by the
# broadcast example, where `a` has batch size 1 and `b` has batch size 2. The
# body is identical to matmul_kernel; the batch broadcasting is presumably
# performed inside pypto.matmul - confirm.
# NOTE(review): pypto.Tensor([], ...) presumably means "shape taken from the
# runtime argument" - confirm against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_broadcast_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # NOTE(review): tile-shape arguments assumed to be cube-unit tiling - confirm.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # Write the FP32 product into the caller-provided `out` tensor in place.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32)
def test_matmul_broadcast(device_id: int = None):
    """Test batch matrix multiplication with broadcasting.

    `a` holds a single 2x2 matrix (batch 1) that is broadcast against the two
    matrices in `b`. In NPU mode the result is compared against
    hand-computed products. In SIM mode the tensors are only printed.

    Args:
        device_id: NPU index used to place the tensors in NPU mode. When it
            is None, or the run mode is SIM, the tensors are created on CPU.
    """
    print("=" * 60)
    print("Test: Batch Matrix Multiplication with Broadcasting")
    print("=" * 60)
    device = f'npu:{device_id}' if global_run_mode == pypto.RunMode.NPU and device_id is not None else 'cpu'
    dtype = torch.float32
    a = torch.tensor([[[1, 2], [3, 4]]], dtype=dtype, device=device)
    b = torch.tensor([[[5, 6], [7, 8]], [[1, 2], [3, 4]]], dtype=dtype, device=device)
    expected = torch.tensor([[[19, 22], [43, 50]], [[7, 10], [15, 22]]], dtype=dtype, device=device)
    # The output batch dims are the broadcast of both inputs' batch dims, and
    # the matrix dims are (M, N) = (rows of a, cols of b). Using b.shape only
    # worked because a's batch is 1 and the matrices are square.
    out_shape = (*torch.broadcast_shapes(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
    out = torch.empty(out_shape, dtype=dtype, device=device)
    matmul_broadcast_kernel(a, b, out)
    if global_run_mode == pypto.RunMode.NPU:
        assert_allclose(out.cpu().numpy(), expected.cpu().numpy(), rtol=1e-3, atol=1e-3)
    print(f"Output: {out}")
    print(f"Expected: {expected}")
    print("✓ Batch matrix multiplication with broadcasting completed successfully")
# JIT kernel computing out = a @ b^T. `b` is passed already transposed, as
# (N, K); the transposition example passes a (2, 3) `b` and expects a (2, 2)
# result.
# NOTE(review): pypto.Tensor([], ...) presumably means "shape taken from the
# runtime argument" - confirm against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_trans_right_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # NOTE(review): tile-shape arguments assumed to be cube-unit tiling - confirm.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # b_trans=True makes matmul treat `b` as transposed.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32, b_trans=True)
# JIT kernel computing out = a^T @ b. `a` is passed already transposed, as
# (K, M); the transposition example passes a (3, 2) `a` and expects a (2, 2)
# result.
# NOTE(review): pypto.Tensor([], ...) presumably means "shape taken from the
# runtime argument" - confirm against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_trans_left_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # NOTE(review): tile-shape arguments assumed to be cube-unit tiling - confirm.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # a_trans=True makes matmul treat `a` as transposed.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32, a_trans=True)
def test_matmul_trans(device_id: int = None):
    """Test matrix multiplication with transposition.

    The same product [[58, 64], [139, 154]] is computed three ways:
      * basic:       A (2x3) @ B (3x2) with matmul_kernel
      * trans right: A @ B where B is supplied transposed (2x3)
      * trans left:  A @ B where A is supplied transposed (3x2)
    In NPU mode each result is compared against the reference. In SIM mode
    the tensors are only printed.

    Args:
        device_id: NPU index used to place the tensors in NPU mode. When it
            is None, or the run mode is SIM, the tensors are created on CPU.
    """
    print("=" * 60)
    print("Test: Matrix Multiplication with Transposition")
    print("=" * 60)

    on_npu = global_run_mode == pypto.RunMode.NPU
    device = f'npu:{device_id}' if on_npu and device_id is not None else 'cpu'
    dtype = torch.float32

    def make(rows):
        return torch.tensor(rows, dtype=dtype, device=device)

    def blank(shape):
        return torch.empty(shape, dtype=dtype, device=device)

    def check(result, reference, label):
        # Numeric comparison only in NPU mode; always echo both tensors.
        if on_npu:
            assert_allclose(result.cpu().numpy(), reference.cpu().numpy(), rtol=1e-3, atol=1e-3)
        print(f"Output {label}: {result}")
        print(f"Expected {label}: {reference}")

    lhs = make([[1, 2, 3],
                [4, 5, 6]])
    rhs = make([[7, 8],
                [9, 10],
                [11, 12]])
    reference = make([[58, 64],
                      [139, 154]])

    # Plain A @ B.
    result = blank((lhs.shape[0], rhs.shape[1]))
    matmul_kernel(lhs, rhs, result)
    check(result, reference, "basic")

    # B supplied as B^T (N x K).
    rhs_t = make([[7, 9, 11],
                  [8, 10, 12]])
    result = blank((lhs.shape[0], rhs_t.shape[0]))
    matmul_trans_right_kernel(lhs, rhs_t, result)
    check(result, reference, "trans right")

    # A supplied as A^T (K x M).
    lhs_t = make([[1, 4],
                  [2, 5],
                  [3, 6]])
    result = blank((lhs_t.shape[1], rhs.shape[1]))
    matmul_trans_left_kernel(lhs_t, rhs, result)
    check(result, reference, "trans left")

    print("✓ Matrix multiplication with transposition completed successfully")
# JIT kernel computing out = a @ b + bias. In the bias example `bias` is a
# single (1, N) row that is added to every output row.
# NOTE(review): pypto.Tensor([], ...) presumably means "shape taken from the
# runtime argument" - confirm against the pypto frontend docs.
@pypto.frontend.jit(runtime_options={"run_mode": global_run_mode})
def matmul_bias_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP32),
        bias: pypto.Tensor([], pypto.DT_FP32),
        out: pypto.Tensor([], pypto.DT_FP32)):
    # The bias is passed to matmul through extend_params under "bias_tensor".
    extend_params = {"bias_tensor": bias}
    # NOTE(review): tile-shape arguments assumed to be cube-unit tiling - confirm.
    pypto.set_cube_tile_shapes([32, 32], [64, 64], [64, 64])
    # Write the FP32 result into the caller-provided `out` tensor in place.
    out[:] = pypto.matmul(a, b, pypto.DT_FP32, extend_params=extend_params)
def test_matmul_bias(device_id: int = None):
    """Test matrix multiplication with bias.

    Computes a @ b + bias, where `bias` is a (1, N) row added to every output
    row. In NPU mode the result is compared against a hand-computed value.
    In SIM mode the tensors are only printed.

    Args:
        device_id: NPU index used to place the tensors in NPU mode. When it
            is None, or the run mode is SIM, the tensors are created on CPU.
    """
    print("=" * 60)
    print("Test: Matrix Multiplication with Bias")
    print("=" * 60)
    device = f'npu:{device_id}' if global_run_mode == pypto.RunMode.NPU and device_id is not None else 'cpu'
    dtype = torch.float32
    a = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device)
    b = torch.tensor([[5, 6], [7, 8]], dtype=dtype, device=device)
    bias = torch.tensor([[1, 2]], dtype=dtype, device=device)
    expected = torch.tensor([[20, 24], [44, 52]], dtype=dtype, device=device)
    # The output is (M, N) = (rows of a, cols of b). Using b.shape only worked
    # because these matrices happen to be square.
    out = torch.empty((a.shape[0], b.shape[1]), dtype=dtype, device=device)
    matmul_bias_kernel(a, b, bias, out)
    if global_run_mode == pypto.RunMode.NPU:
        assert_allclose(out.cpu().numpy(), expected.cpu().numpy(), rtol=1e-3, atol=1e-3)
    print(f"Output: {out}")
    print(f"Expected: {expected}")
    print("✓ Matrix multiplication with bias completed successfully")
def main():
    """Run matrix multiplication examples.

    Usage:
        python matmul_ops.py                            # Run all examples
        python matmul_ops.py --list                     # List all available examples
        python matmul_ops.py matmul::test_matmul_basic  # Run a specific case
        python matmul_ops.py --run_mode sim             # Run in simulation mode

    In NPU mode (the default), the device index is read from
    TILE_FWK_DEVICE_ID. If it is missing or not an integer, a hint is
    printed and no example runs.

    Exits with status 1 in two cases: the case id is unknown, or the parsed
    run mode disagrees with the mode the JIT kernels were configured with at
    import time.
    """
    parser = argparse.ArgumentParser(
        description="PyPTO Matrix Multiplication (matmul) Operation Examples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s Run all examples
%(prog)s --list List all available examples
%(prog)s matmul::test_matmul_basic Run a specific case
"""
    )
    parser.add_argument(
        'example_id',
        type=str,
        nargs="?",
        help='Run a specific case (e.g., matmul::test_matmul_basic). If omitted, all cases run.'
    )
    parser.add_argument(
        '--list',
        action='store_true',
        help='List all available examples and exit'
    )
    # No nargs="?": a bare "--run_mode" used to parse as None, which skipped
    # device setup while the kernels were still configured for NPU.
    parser.add_argument(
        "--run_mode", "--run-mode",
        type=str, default="npu", choices=["npu", "sim"],
        help="Run mode, supports npu and sim."
    )
    args = parser.parse_args()

    examples = {
        'matmul::test_matmul_basic': {
            'name': 'Test basic matrix multiplication',
            'description': 'Basic matrix multiplication example',
            'function': test_matmul_basic
        },
        'matmul::test_matmul_batch': {
            'name': 'Test batch matrix multiplication',
            'description': 'Batch matrix multiplication example',
            'function': test_matmul_batch
        },
        'matmul::test_matmul_broadcast': {
            'name': 'Test batch matrix multiplication with broadcasting',
            'description': 'Batch matrix multiplication with broadcasting example',
            'function': test_matmul_broadcast
        },
        'matmul::test_matmul_trans': {
            'name': 'Test matrix multiplication with transposition',
            'description': 'Matrix multiplication with transposition example',
            'function': test_matmul_trans
        },
        'matmul::test_matmul_bias': {
            'name': 'Test matrix multiplication with bias',
            'description': 'Matrix multiplication with bias example',
            'function': test_matmul_bias
        }
    }

    if args.list:
        print("\n" + "=" * 60)
        print("Available Examples")
        print("=" * 60 + "\n")
        for case_key, ex_info in sorted(examples.items()):
            print(f" {case_key}")
            print(f" {ex_info['name']}")
            print(f" {ex_info['description']}\n")
        return

    if args.example_id:
        if args.example_id not in examples:
            print(f"ERROR: Invalid case: {args.example_id}")
            print(f"Valid cases are: {', '.join(sorted(examples.keys()))}")
            print("\nUse --list to see all available examples.")
            sys.exit(1)
        examples_to_run = [(args.example_id, examples[args.example_id])]
    else:
        examples_to_run = sorted(examples.items())

    # The @pypto.frontend.jit decorators captured global_run_mode at import
    # time from a pre-scan of sys.argv, and it cannot be changed now. Refuse
    # to run rather than mix modes if argparse resolved something different
    # (e.g. an abbreviated flag such as "--run_m sim").
    requested_mode = pypto.RunMode.SIM if args.run_mode == "sim" else pypto.RunMode.NPU
    if requested_mode != global_run_mode:
        print(f"ERROR: run mode '{args.run_mode}' does not match the mode the kernels were configured with at import time.")
        print("Pass it as '--run_mode <mode>' or '--run_mode=<mode>'.")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("PyPTO Matrix Multiplication (matmul) Operation Examples")
    print("=" * 60 + "\n")

    device_id = None
    if args.run_mode == "npu":
        device_id = get_device_id()
        if device_id is None:
            # get_device_id() already printed a hint; nothing can run on NPU.
            return
        # Imported for its side effects only; torch.npu below presumably
        # depends on it being loaded.
        import torch_npu  # noqa: F401
        torch.npu.set_device(device_id)

    try:
        for _case_key, ex_info in examples_to_run:
            ex_info['function'](device_id)
        if len(examples_to_run) > 1:
            print("=" * 60)
            print("All matmul tests passed!")
            print("=" * 60)
    except Exception as e:
        print(f"\nError: {e}")
        raise


if __name__ == "__main__":
    main()