al.fixpipe API Documentation

1. Hardware Background

A5 hardware adds a direct data path from L0C to UB. As a temporary solution, this path is exposed to users through an explicit front-end call, `al.fixpipe`.

2. Interface Description

Python
def fixpipe(
    src: tl.tensor,
    dst: bl.buffer,
    dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND,
    dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL,
    pre_quant_mode: FixpipePreQuantMode = FixpipePreQuantMode.NO_QUANT,
    pre_relu_mode: FixpipePreReluMode = FixpipePreReluMode.NO_RELU,
    _builder=None,
) -> None:
    """Move a tensor from L0C into a UB buffer through the fixpipe data path.

    Args:
        src: Source tensor. Must reside in L0C, i.e. be the result of a dot.
        dst: Destination buffer. Must be allocated in UB; it is written in
            place.
        dma_mode: HIVM data-movement mode: ``NZ2DN``, ``NZ2ND`` (default) or
            ``NZ2NZ``.
        dual_dst_mode: Dual-destination mode. Anything other than ``NO_DUAL``
            (default) may only be used together with ``dma_mode=NZ2ND``.
        pre_quant_mode: Quantization / type-conversion mode (default
            ``NO_QUANT``).
        pre_relu_mode: Activation-function mode (default ``NO_RELU``).
        _builder: Supplied automatically by the JIT; do not pass it.

    Returns:
        None. The result is written into ``dst``.
    """

class FixpipeDMAMode(enum.Enum):
    """Data-movement mode for ``al.fixpipe`` (``dma_mode``); default ``NZ2ND``.

    Each member wraps the matching ``ascend_ir.FixpipeDMAMode`` value.
    """

    # NOTE(review): member names appear to read "<source layout>2<destination
    # layout>" (e.g. NZ -> ND); confirm against the HIVM documentation.
    NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN
    NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND
    NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ
class FixpipeDualDstMode(enum.Enum):
    """Dual-destination mode for ``al.fixpipe``; default ``NO_DUAL``.

    Per the parameter table, modes other than ``NO_DUAL`` may only be used
    with ``dma_mode=NZ2ND``. Each member wraps the matching
    ``ascend_ir.FixpipeDualDstMode`` value.
    """

    NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL
    # NOTE(review): presumably split the output by columns / rows across two
    # destinations; confirm against the HIVM documentation.
    COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT
    ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT
class FixpipePreQuantMode(enum.Enum):
    """Quantization / type-conversion mode for ``al.fixpipe``; default ``NO_QUANT``.

    Each member wraps the matching ``ascend_ir.FixpipePreQuantMode`` value.
    """

    NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT
    # NOTE(review): names appear to encode "<source dtype>2<target dtype>"
    # (e.g. F32 -> BF16, S32 -> I8); confirm against the HIVM documentation.
    F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16
    F322F16 = ascend_ir.FixpipePreQuantMode.F322F16
    S322I8 = ascend_ir.FixpipePreQuantMode.S322I8
class FixpipePreReluMode(enum.Enum):
    """Activation-function mode for ``al.fixpipe``; default ``NO_RELU``.

    Each member wraps the matching ``ascend_ir.FixpipePreReluMode`` value.
    Members are listed alphabetically.
    """

    # NOTE(review): presumably leaky / normal / parametric ReLU variants;
    # confirm the exact semantics against the HIVM documentation.
    LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU
    NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU
    NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU
    P_RELU = ascend_ir.FixpipePreReluMode.P_RELU

2.1 Parameters

| Parameter | Type | Description |
|---|---|---|
| `src` | `tl.tensor` | Source tensor. It must reside in the L0C memory region. |
| `dst` | `bl.buffer` | Destination buffer. It must reside in the UB memory region. |
| `dma_mode` | `al.FixpipeDMAMode` | HIVM data-movement mode. Allowed values: `NZ2DN`, `NZ2ND` (default), `NZ2NZ`. |
| `dual_dst_mode` | `al.FixpipeDualDstMode` | Dual-destination mode control. Default: `NO_DUAL`. Other values can only be enabled when `dma_mode` is `NZ2ND` (normal mode). |
| `pre_quant_mode` | `al.FixpipePreQuantMode` | Quantization / type-conversion mode. Default: `NO_QUANT`. |
| `pre_relu_mode` | `al.FixpipePreReluMode` | Activation-function mode. Default: `NO_RELU`. |
| `_builder` | - | Passed automatically by the JIT; do not set it manually. |

2.2 Return Value

None. The result is written in place into the `dst` buffer passed by the caller.

3. Constraints

  • `fixpipe` only supports data movement from L0C to UB.

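  • `src` must be the result produced by a dot operation, so that it resides in L0C. Note that this is not enforced when generating TTIR. The example in Section 4 passes a `tl.load` result and still lowers to `hivm.hir.fixpipe` (see Section 5).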

  • `dst` must be a buffer whose memory scope is UB, e.g. one created with `bl.alloc(..., al.ascend_address_space.UB)`.

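4. Example Usage

The script below compiles a kernel that calls `al.fixpipe` down to TTIR and prints the resulting MLIR. It only compiles the kernel; nothing is launched on a device.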

Python
import os

# PyTorch reads TORCH_DEVICE_BACKEND_AUTOLOAD while `import torch` runs, so it
# must be set before the imports below, which may pull torch in transitively.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton  # noqa: E402
import triton.language as tl  # noqa: E402
import triton.extension.buffer.language as bl  # noqa: E402
import triton.language.extra.cann.extension as al  # noqa: E402
from triton.compiler.compiler import ASTSource  # noqa: E402
from triton.compiler.code_generator import ast_to_ttir  # noqa: E402
from triton._C.libtriton import ir  # noqa: E402
from triton._C.libtriton.ascend import ir as ascend_ir  # noqa: E402


class Options:
    """Compile options handed to ``ast_to_ttir`` by ``compile_kernel``.

    All settings are plain class attributes; ``compile_kernel`` instantiates
    the class once per compilation.
    """

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)

    enable_fp_fusion = True
    debug = False

    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants):
    """Lower a JIT kernel to Triton IR and return the module as MLIR text.

    Args:
        kernel: ``@triton.jit`` function to compile.
        signature: Mapping from argument name to Triton type string,
            e.g. ``{"A_ptr": "*fp32"}``.
        constants: Mapping from ``tl.constexpr`` argument name to its value.

    Returns:
        ``str`` of the module produced by ``ast_to_ttir``.
    """
    source = ASTSource(kernel, signature, constants)

    ctx = ir.context()
    # Core Triton dialects first, then the Ascend-specific ones.
    ir.load_dialects(ctx)
    ascend_ir.load_dialects(ctx)

    # The two trailing empty dicts supply no extra tables to ast_to_ttir.
    return str(ast_to_ttir(kernel, source, ctx, Options(), {}, {}))


@triton.jit
def fixpipe(
    A_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
):
    """Compile-only demo: load an [M, K] fp32 tile and fixpipe it into UB.

    Args:
        A_ptr: Pointer to the fp32 input, laid out with a row stride of ``K``.
        M, N, K: Tile sizes. ``dst`` is allocated as ``[M, N]`` while ``src``
            is the loaded ``[M, K]`` tile, so ``N`` must equal ``K``.

    NOTE: for brevity ``src`` is a ``tl.load`` result rather than the result
    of a dot, as the Constraints section requires. TTIR generation still emits
    ``hivm.hir.fixpipe`` for it (see the Compilation Output section).
    """
    # Make the dst/src shape requirement explicit and fail at compile time
    # instead of silently relying on N == K.
    tl.static_assert(N == K, "fixpipe demo: dst [M, N] must match src [M, K] (N == K)")

    row_matmul = tl.program_id(0)

    offs_i = tl.arange(0, M)[:, None]  # [M, 1] (row axis)
    offs_k = tl.arange(0, K)  # [K]

    # NOTE(review): the program id is added to the row index without scaling
    # by M, so tiles would overlap for a grid > 1. It is harmless here because
    # the kernel is only compiled; use `row_matmul * M + offs_i` for real launches.
    a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :]
    a_vals = tl.load(a_ptrs)  # [M, K]

    ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB)
    al.fixpipe(a_vals, ub, dual_dst_mode=al.FixpipeDualDstMode.NO_DUAL)

def test_fixpipe(M=16, K=16, N=16):
    """Compile the ``fixpipe`` demo kernel and check that fixpipe was lowered.

    The default sizes let pytest collect and run this function directly
    (pytest does not treat parameters with default values as fixtures).

    Args:
        M, K, N: Tile sizes forwarded as ``tl.constexpr`` values.
    """
    mlir = compile_kernel(
        fixpipe,
        {"A_ptr": "*fp32"},
        {"M": M, "K": K, "N": N},
    )
    # A non-empty module is not enough: make sure the al.fixpipe call actually
    # produced the HIVM op.
    assert "hivm.hir.fixpipe" in mlir, mlir
    print(mlir)

# ---- Manual entry point: run the compile check once with 16x16x16 tiles ----
if __name__ == "__main__":
    test_fixpipe(M=16, K=16, N=16)

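5. Compilation Output

Running the example prints the TTIR module below. The `loc(...)` entries record the file path and line:column positions of the author's local copy of the example (`/home/linxin/...`), so they will differ in your environment.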

Plain Text
module {
tt.func public @fixpipe(%arg0: !tt.ptr<f32> loc("/home/linxin/triton-test/fixpipe.py":35:0)) attributes {noinline = false} {
%0 = tt.get_program_id x : i32 loc(#loc1)
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2)
%2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> loc(#loc3)
%3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc4)
%4 = tt.splat %0 : i32 -> tensor<16x1xi32> loc(#loc5)
%5 = arith.addi %4, %2 : tensor<16x1xi32> loc(#loc5)
%c16_i32 = arith.constant 16 : i32 loc(#loc6)
%c16_i32_0 = arith.constant 16 : i32 loc(#loc6)
%cst = arith.constant dense<16> : tensor<16x1xi32> loc(#loc6)
%6 = arith.muli %5, %cst : tensor<16x1xi32> loc(#loc6)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>> loc(#loc7)
%8 = tt.addptr %7, %6 : tensor<16x1x!tt.ptr<f32>>, tensor<16x1xi32> loc(#loc7)
%9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc8)
%10 = tt.broadcast %8 : tensor<16x1x!tt.ptr<f32>> -> tensor<16x16x!tt.ptr<f32>> loc(#loc9)
%11 = tt.broadcast %9 : tensor<1x16xi32> -> tensor<16x16xi32> loc(#loc9)
%12 = tt.addptr %10, %11 : tensor<16x16x!tt.ptr<f32>>, tensor<16x16xi32> loc(#loc9)
%13 = tt.load %12 : tensor<16x16x!tt.ptr<f32>> loc(#loc10)
%alloc = memref.alloc() : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
annotation.mark %alloc {effects = ["write", "read"]} : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
hivm.hir.fixpipe {dma_mode = #hivm.dma_mode<nz2nd>} ins(%13 : tensor<16x16xf32>) outs(%alloc : memref<16x16xf32, #hivm.address_space<ub>>) dual_dst_mode = <NO_DUAL> loc(#loc12)
tt.return loc(#loc13)
} loc(#loc)
} loc(#loc)
#loc1 = loc("/home/linxin/triton-test/fixpipe.py":42:31)
#loc2 = loc("/home/linxin/triton-test/fixpipe.py":44:26)
#loc3 = loc("/home/linxin/triton-test/fixpipe.py":44:43)
#loc4 = loc("/home/linxin/triton-test/fixpipe.py":45:26)
#loc5 = loc("/home/linxin/triton-test/fixpipe.py":47:35)
#loc6 = loc("/home/linxin/triton-test/fixpipe.py":47:45)
#loc7 = loc("/home/linxin/triton-test/fixpipe.py":47:21)
#loc8 = loc("/home/linxin/triton-test/fixpipe.py":47:56)
#loc9 = loc("/home/linxin/triton-test/fixpipe.py":47:49)
#loc10 = loc("/home/linxin/triton-test/fixpipe.py":48:21)
#loc11 = loc("/home/linxin/triton-test/fixpipe.py":50:38)
#loc12 = loc("/home/linxin/triton-test/fixpipe.py":51:23)
#loc13 = loc("/home/linxin/triton-test/fixpipe.py":51:4)