al.fixpipe API Documentation
1. Hardware Background
The A5 hardware adds a data path from L0C to UB. As a temporary solution, this path is enabled by exposing it through an explicit front-end call, `al.fixpipe`.
2. Interface Description
| Python def fixpipe( src: tl.tensor, dst: bl.buffer, dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, pre_quant_mode: FixpipePreQuantMode = FixpipePreQuantMode.NO_QUANT, pre_relu_mode: FixpipePreReluMode = FixpipePreReluMode.NO_RELU, _builder=None, ) -> None: class FixpipeDMAMode(enum.Enum): NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ class FixpipeDualDstMode(enum.Enum): NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT class FixpipePreQuantMode(enum.Enum): NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 class FixpipePreReluMode(enum.Enum): LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU P_RELU = ascend_ir.FixpipePreReluMode.P_RELU |
2.1 Parameters
| Parameter | Type | Description |
| src | tl.tensor | Source tensor. It must reside in the L0C memory region |
| dst | bl.buffer | Destination buffer. It must reside in the UB memory region |
| dma_mode | al.FixpipeDMAMode | HIVM data-movement mode. Allowed values: `NZ2DN`, `NZ2ND`, `NZ2NZ` |
| dual_dst_mode | al.FixpipeDualDstMode | Dual-destination mode control. Can only be enabled in `NZ2ND` / normal mode |
| pre_quant_mode | al.FixpipePreQuantMode | Quantization / type-conversion mode |
| pre_relu_mode | al.FixpipePreReluMode | Activation-function mode |
| _builder | - | Automatically passed by JIT |
2.2 Return Value
No return value; the result is written directly into the `dst` buffer passed in by the caller.
3. Constraints
- `fixpipe` only supports data movement from L0C to UB.
- `src` must be the result produced by `dot`.
- `dst` must be a buffer whose memscope is UB.
4. Example Usage
| Python import os import triton import triton.language as tl import triton.extension.buffer.language as bl import triton.language.extra.cann.extension as al from triton.compiler.compiler import ASTSource from triton.compiler.code_generator import ast_to_ttir from triton._C.libtriton import ir from triton._C.libtriton.ascend import ir as ascend_ir os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" class Options: num_warps = 4 num_stages = 3 num_ctas = 1 cluster_dims = (1, 1, 1) enable_fp_fusion = True debug = False arch = "Ascend910_95" def compile_kernel(kernel, signature, constants): """Helper to compile a kernel to MLIR.""" src = ASTSource(kernel, signature, constants) context = ir.context() ir.load_dialects(context) ascend_ir.load_dialects(context) module = ast_to_ttir(kernel, src, context, Options(), {}, {}) return str(module) @triton.jit def fixpipe( A_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, ): row_matmul = tl.program_id(0) offs_i = tl.arange(0, tl.constexpr(M))[:, None] # [M,1] (row axis) offs_k = tl.arange(0, K) # [K] a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :] a_vals = tl.load(a_ptrs) # [M, K] ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB) al.fixpipe(a_vals, ub, dual_dst_mode=al.FixpipeDualDstMode.NO_DUAL) def test_fixpipe(M, K, N): mlir = compile_kernel( fixpipe, { "A_ptr": "*fp32", }, {"M": M, "K": K, "N": N}, ) assert len(mlir) > 0 print(mlir) # ============== Main for manual testing ============== if __name__ == "__main__": test_fixpipe(16, 16, 16) |
5. Compilation Output
| Plain Text module { tt.func public @fixpipe(%arg0: !tt.ptr<f32> loc("/home/linxin/triton-test/fixpipe.py":35:0)) attributes {noinline = false} { %0 = tt.get_program_id x : i32 loc(#loc1) %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> loc(#loc3) %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc4) %4 = tt.splat %0 : i32 -> tensor<16x1xi32> loc(#loc5) %5 = arith.addi %4, %2 : tensor<16x1xi32> loc(#loc5) %c16_i32 = arith.constant 16 : i32 loc(#loc6) %c16_i32_0 = arith.constant 16 : i32 loc(#loc6) %cst = arith.constant dense<16> : tensor<16x1xi32> loc(#loc6) %6 = arith.muli %5, %cst : tensor<16x1xi32> loc(#loc6) %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>> loc(#loc7) %8 = tt.addptr %7, %6 : tensor<16x1x!tt.ptr<f32>>, tensor<16x1xi32> loc(#loc7) %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc8) %10 = tt.broadcast %8 : tensor<16x1x!tt.ptr<f32>> -> tensor<16x16x!tt.ptr<f32>> loc(#loc9) %11 = tt.broadcast %9 : tensor<1x16xi32> -> tensor<16x16xi32> loc(#loc9) %12 = tt.addptr %10, %11 : tensor<16x16x!tt.ptr<f32>>, tensor<16x16xi32> loc(#loc9) %13 = tt.load %12 : tensor<16x16x!tt.ptr<f32>> loc(#loc10) %alloc = memref.alloc() : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11) annotation.mark %alloc {effects = ["write", "read"]} : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11) hivm.hir.fixpipe {dma_mode = #hivm.dma_mode<nz2nd>} ins(%13 : tensor<16x16xf32>) outs(%alloc : memref<16x16xf32, #hivm.address_space<ub>>) dual_dst_mode = <NO_DUAL> loc(#loc12) tt.return loc(#loc13) } loc(#loc) } loc(#loc) #loc1 = loc("/home/linxin/triton-test/fixpipe.py":42:31) #loc2 = loc("/home/linxin/triton-test/fixpipe.py":44:26) #loc3 = loc("/home/linxin/triton-test/fixpipe.py":44:43) #loc4 = loc("/home/linxin/triton-test/fixpipe.py":45:26) #loc5 = 
loc("/home/linxin/triton-test/fixpipe.py":47:35) #loc6 = loc("/home/linxin/triton-test/fixpipe.py":47:45) #loc7 = loc("/home/linxin/triton-test/fixpipe.py":47:21) #loc8 = loc("/home/linxin/triton-test/fixpipe.py":47:56) #loc9 = loc("/home/linxin/triton-test/fixpipe.py":47:49) #loc10 = loc("/home/linxin/triton-test/fixpipe.py":48:21) #loc11 = loc("/home/linxin/triton-test/fixpipe.py":50:38) #loc12 = loc("/home/linxin/triton-test/fixpipe.py":51:23) #loc13 = loc("/home/linxin/triton-test/fixpipe.py":51:4) |