al.fixpipe 接口文档
1. 硬件背景
A5新增了L0C到UB的数据通路。为了使用该通路,当前的临时方案是由前端通过al.fixpipe接口显式调用。
2. 接口说明
def fixpipe(
    src: tl.tensor,
    dst: bl.buffer,
    dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND,
    dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL,
    pre_quant_mode: FixpipePreQuantMode = FixpipePreQuantMode.NO_QUANT,
    pre_relu_mode: FixpipePreReluMode = FixpipePreReluMode.NO_RELU,
    _builder=None,
) -> None:
    """Move ``src`` (in L0C) into ``dst`` (a UB buffer) through the fixpipe path.

    Signature listing only; the implementation is not shown in this document.

    Args:
        src: Source tensor; must reside in the L0C memory region.
        dst: Destination buffer; must reside in the UB memory region.
        dma_mode: HIVM data-movement mode, one of NZ2DN, NZ2ND, NZ2NZ.
        dual_dst_mode: Dual-destination mode control. Documented as usable
            only with NZ2ND / "normal mode" -- TODO confirm what "normal
            mode" refers to.
        pre_quant_mode: Quantization / type-conversion mode.
        pre_relu_mode: Activation-function mode.
        _builder: Supplied automatically by the JIT; not passed by users.

    Returns:
        None. The caller keeps using the ``dst`` argument it passed in.
    """


# Values accepted by ``dma_mode``; each member holds the ascend_ir enum value
# of the same name.
class FixpipeDMAMode(enum.Enum):
    NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN
    NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND
    NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ


# Values accepted by ``dual_dst_mode``; each member holds the ascend_ir enum
# value of the same name.
class FixpipeDualDstMode(enum.Enum):
    NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL
    COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT
    ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT


# Values accepted by ``pre_quant_mode``; each member holds the ascend_ir enum
# value of the same name.
class FixpipePreQuantMode(enum.Enum):
    NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT
    F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16
    F322F16 = ascend_ir.FixpipePreQuantMode.F322F16
    S322I8 = ascend_ir.FixpipePreQuantMode.S322I8


# Values accepted by ``pre_relu_mode``; each member holds the ascend_ir enum
# value of the same name.
class FixpipePreReluMode(enum.Enum):
    LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU
    NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU
    NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU
    P_RELU = ascend_ir.FixpipePreReluMode.P_RELU
2.1 入参
| 参数名 | 类型 | 说明 |
| src | tl.tensor | 源张量,必须位于L0C内存区域 |
| dst | bl.buffer | 目标缓冲区,必须位于UB内存区域 |
| dma_mode | al.FixpipeDMAMode | HIVM数据搬运模式,可选值:NZ2DN、NZ2ND、NZ2NZ |
| dual_dst_mode | al.FixpipeDualDstMode | 双目标模式控制,仅NZ2ND/普通模式可启用 |
| pre_quant_mode | al.FixpipePreQuantMode | 量化/类型转换模式 |
| pre_relu_mode | al.FixpipePreReluMode | 激活函数模式 |
| _builder | - | JIT自动传参 |
2.2 返回值
无返回值,搬运结果写入入参dst,调用后直接使用dst即可
3. 约束说明
- fixpipe仅支持从L0C到UB的数据搬运
- src必须是dot后的结果
- dst必须是memscope为UB的buffer
4. 用例示例
import os
# NOTE: the loc(line:col) entries in section 5 refer to this exact layout; do not shift lines above the al.fixpipe call.
import triton
import triton.language as tl
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir
from triton._C.libtriton.ascend import ir as ascend_ir

os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"  # presumably disables torch device-backend autoload -- TODO confirm


class Options:  # instantiated and passed to ast_to_ttir() in compile_kernel()
    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants) -> str:
    """Helper to compile a kernel to MLIR."""
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    ir.load_dialects(context)
    ascend_ir.load_dialects(context)  # the Ascend dialects are loaded into the same context
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)  # textual form of the lowered module


@triton.jit  # kernel under test: loads an [M, K] tile via A_ptr and hands it to al.fixpipe
def fixpipe(
    A_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
):
    row_matmul = tl.program_id(0)

    offs_i = tl.arange(0, tl.constexpr(M))[:, None]  # [M,1] (row axis)
    offs_k = tl.arange(0, K)  # [K]

    a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :]
    a_vals = tl.load(a_ptrs)  # [M, K]

    ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB)
    al.fixpipe(a_vals, ub, dual_dst_mode=al.FixpipeDualDstMode.NO_DUAL)
    # NOTE(review): src here is a tl.load result, while section 3 states that
    # src must be the result of a dot -- confirm this example is compile-only.


def test_fixpipe(M, K, N) -> None:
    """Compile the fixpipe kernel for the given sizes and print the MLIR text.

    The constants are passed to the kernel by name, so the (M, K, N) order
    here does not need to match the kernel's (M, N, K) parameter order.
    """
    mlir = compile_kernel(
        fixpipe,
        {
            "A_ptr": "*fp32",
        },
        {"M": M, "K": K, "N": N},
    )
    assert len(mlir) > 0
    print(mlir)


# ============== Main for manual testing ==============
if __name__ == "__main__":
    test_fixpipe(16, 16, 16)
5. 编译输出结果
module {
  tt.func public @fixpipe(%arg0: !tt.ptr<f32> loc("/home/linxin/triton-test/fixpipe.py":35:0)) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32 loc(#loc1)
    %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2)
    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> loc(#loc3)
    %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc4)
    %4 = tt.splat %0 : i32 -> tensor<16x1xi32> loc(#loc5)
    %5 = arith.addi %4, %2 : tensor<16x1xi32> loc(#loc5)
    %c16_i32 = arith.constant 16 : i32 loc(#loc6)
    %c16_i32_0 = arith.constant 16 : i32 loc(#loc6)
    %cst = arith.constant dense<16> : tensor<16x1xi32> loc(#loc6)
    %6 = arith.muli %5, %cst : tensor<16x1xi32> loc(#loc6)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>> loc(#loc7)
    %8 = tt.addptr %7, %6 : tensor<16x1x!tt.ptr<f32>>, tensor<16x1xi32> loc(#loc7)
    %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc8)
    %10 = tt.broadcast %8 : tensor<16x1x!tt.ptr<f32>> -> tensor<16x16x!tt.ptr<f32>> loc(#loc9)
    %11 = tt.broadcast %9 : tensor<1x16xi32> -> tensor<16x16xi32> loc(#loc9)
    %12 = tt.addptr %10, %11 : tensor<16x16x!tt.ptr<f32>>, tensor<16x16xi32> loc(#loc9)
    %13 = tt.load %12 : tensor<16x16x!tt.ptr<f32>> loc(#loc10)
    %alloc = memref.alloc() : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
    annotation.mark %alloc {effects = ["write", "read"]} : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
    hivm.hir.fixpipe {dma_mode = #hivm.dma_mode<nz2nd>} ins(%13 : tensor<16x16xf32>) outs(%alloc : memref<16x16xf32, #hivm.address_space<ub>>) dual_dst_mode = <NO_DUAL> loc(#loc12)
    tt.return loc(#loc13)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home/linxin/triton-test/fixpipe.py":42:31)
#loc2 = loc("/home/linxin/triton-test/fixpipe.py":44:26)
#loc3 = loc("/home/linxin/triton-test/fixpipe.py":44:43)
#loc4 = loc("/home/linxin/triton-test/fixpipe.py":45:26)
#loc5 = loc("/home/linxin/triton-test/fixpipe.py":47:35)
#loc6 = loc("/home/linxin/triton-test/fixpipe.py":47:45)
#loc7 = loc("/home/linxin/triton-test/fixpipe.py":47:21)
#loc8 = loc("/home/linxin/triton-test/fixpipe.py":47:56)
#loc9 = loc("/home/linxin/triton-test/fixpipe.py":47:49)
#loc10 = loc("/home/linxin/triton-test/fixpipe.py":48:21)
#loc11 = loc("/home/linxin/triton-test/fixpipe.py":50:38)
#loc12 = loc("/home/linxin/triton-test/fixpipe.py":51:23)
#loc13 = loc("/home/linxin/triton-test/fixpipe.py":51:4)