al.fixpipe 接口文档

1. 硬件背景

A5 新增了 L0C 到 UB 的数据通路。为使用该通路,当前的临时方案是在前端通过 al.fixpipe 显式调用。

2. 接口说明

Python
def fixpipe(
    src: tl.tensor,
    dst: bl.buffer,
    dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND,
    dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL,
    pre_quant_mode: FixpipePreQuantMode = FixpipePreQuantMode.NO_QUANT,
    pre_relu_mode: FixpipePreReluMode = FixpipePreReluMode.NO_RELU,
    _builder=None,
) -> None:
    """Move ``src`` from L0C into the UB buffer ``dst`` (signature excerpt).

    Args:
        src: source tensor; must reside in L0C and be the result of a dot.
        dst: destination buffer; must have UB memscope.
        dma_mode: HIVM data-movement mode (NZ2DN / NZ2ND / NZ2NZ); default NZ2ND.
        dual_dst_mode: dual-destination control; per the parameter table it can
            only be enabled in NZ2ND / normal mode. Default NO_DUAL.
        pre_quant_mode: quantization / type-conversion mode; default NO_QUANT.
        pre_relu_mode: activation-function mode; default NO_RELU.
        _builder: supplied automatically by the JIT; callers do not pass it.

    Returns:
        None. The result is written into ``dst``.
    """

class FixpipeDMAMode(enum.Enum):
    """HIVM data-movement mode accepted by ``fixpipe(dma_mode=...)``.

    Each member wraps the matching ``ascend_ir.FixpipeDMAMode`` value; the
    ``fixpipe`` signature defaults to NZ2ND.
    """
    # NOTE(review): names presumably read "<source layout> to <destination
    # layout>" (NZ to DN / ND / NZ); the layouts are not defined here — confirm.
    NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN
    NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND
    NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ
class FixpipeDualDstMode(enum.Enum):
    """Dual-destination control accepted by ``fixpipe(dual_dst_mode=...)``.

    Per the parameter table it can only be enabled in NZ2ND / normal mode; the
    ``fixpipe`` signature defaults to NO_DUAL. Each member wraps the matching
    ``ascend_ir.FixpipeDualDstMode`` value.
    """
    NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL
    # NOTE(review): presumably split the output between the two destinations
    # by columns / by rows respectively — confirm against the HIVM spec.
    COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT
    ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT
class FixpipePreQuantMode(enum.Enum):
    """Quantization / type-conversion mode accepted by ``fixpipe(pre_quant_mode=...)``.

    The ``fixpipe`` signature defaults to NO_QUANT. Each member wraps the
    matching ``ascend_ir.FixpipePreQuantMode`` value.
    """
    NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT
    # NOTE(review): inferred from the names only — presumably fp32 -> bf16,
    # fp32 -> fp16 and int32 -> int8; the dst buffer dtype presumably has to
    # match the converted type. Confirm.
    F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16
    F322F16 = ascend_ir.FixpipePreQuantMode.F322F16
    S322I8 = ascend_ir.FixpipePreQuantMode.S322I8
class FixpipePreReluMode(enum.Enum):
    """Activation-function mode accepted by ``fixpipe(pre_relu_mode=...)``.

    The ``fixpipe`` signature defaults to NO_RELU. Each member wraps the
    matching ``ascend_ir.FixpipePreReluMode`` value.
    """
    # NOTE(review): LEAKY_RELU / P_RELU presumably need a slope (scalar or
    # per-channel), but the fixpipe signature exposes no parameter for it —
    # confirm how the slope is supplied.
    LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU
    NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU
    NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU
    P_RELU = ascend_ir.FixpipePreReluMode.P_RELU

2.1 入参

参数名 类型 说明
src tl.tensor 源张量,必须位于L0C内存区域
dst bl.buffer 目标缓冲区,必须位于UB内存区域
dma_mode al.FixpipeDMAMode HIVM数据搬运模式,可选值:NZ2DN、NZ2ND、NZ2NZ,默认 NZ2ND
dual_dst_mode al.FixpipeDualDstMode 双目标模式控制,可选值:NO_DUAL、COLUMN_SPLIT、ROW_SPLIT,默认 NO_DUAL;仅在 NZ2ND/普通模式下可启用
pre_quant_mode al.FixpipePreQuantMode 量化/类型转换模式,可选值:NO_QUANT、F322BF16、F322F16、S322I8,默认 NO_QUANT
pre_relu_mode al.FixpipePreReluMode 激活函数模式,可选值:NO_RELU、NORMAL_RELU、LEAKY_RELU、P_RELU,默认 NO_RELU
_builder - 由JIT自动传入,用户调用时无需指定

2.2 返回值

无返回值;结果直接写入入参 dst,调用后使用 dst 即可

3. 约束说明

  • fixpipe 仅支持从 L0C 到 UB 的数据搬运

  • src 必须是 dot 运算的结果(即位于 L0C 的张量),不能直接传入 tl.load 的结果

  • dst 必须是 memscope 为 UB 的 buffer,例如通过 bl.alloc(dtype, shape, al.ascend_address_space.UB) 分配

4. 用例示例

Python
import os
import triton
import triton.language as tl
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir
from triton._C.libtriton.ascend import ir as ascend_ir

# NOTE(review): presumably disables PyTorch's automatic loading of out-of-tree
# device backends while compiling this example. It is assigned after the
# triton imports above; if any of them already imported torch, the setting has
# no effect — consider moving it above the imports. TODO confirm.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"


class Options:
    """Minimal compile-options object handed to ``ast_to_ttir`` by ``compile_kernel``.

    All settings are plain class attributes; an instance is created per
    compilation but carries no state of its own.
    """
    # NOTE(review): GPU-style launch knobs expected by the options interface;
    # whether the Ascend backend honours each of them is not shown here.
    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    # Target chip name. NOTE(review): presumably the A5 part that provides the
    # L0C -> UB path described in the hardware background — confirm.
    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants):
    """Lower ``kernel`` to Triton IR and return the module as MLIR text.

    Args:
        kernel: the ``@triton.jit`` function to compile.
        signature: mapping of argument name to Triton type string, e.g. ``"*fp32"``.
        constants: mapping of ``tl.constexpr`` argument names to their values.

    Returns:
        The textual form of the generated module.
    """
    ast_source = ASTSource(kernel, signature, constants)
    ctx = ir.context()
    # Register both the core Triton dialects and the Ascend extension dialects
    # on the context before code generation runs.
    for load in (ir.load_dialects, ascend_ir.load_dialects):
        load(ctx)
    ttir_module = ast_to_ttir(kernel, ast_source, ctx, Options(), {}, {})
    return str(ttir_module)


@triton.jit
def fixpipe(
    A_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
):
    """Compute one [M, N] matmul tile and move it from L0C to UB with al.fixpipe.

    al.fixpipe requires ``src`` to be a dot result located in L0C, so the tile
    handed to it is produced by ``tl.dot`` rather than by a plain ``tl.load``.
    The UB destination is allocated with the dot result's shape [M, N].

    Args:
        A_ptr: pointer to fp32 input data. To keep the example single-pointer,
            both matmul operands are read from this buffer: the [M, K] left
            operand (row-major, row stride K, rows offset by program id) and the
            [K, N] right operand (row-major, row stride N, from element 0).
        M, N, K: tile sizes (compile-time constants).
    """
    pid = tl.program_id(0)

    offs_m = tl.arange(0, M)  # [M] output-tile rows
    offs_n = tl.arange(0, N)  # [N] output-tile columns
    offs_k = tl.arange(0, K)  # [K] reduction axis

    # Offset by pid * M so each program reads its own block of M rows instead
    # of a window shifted by a single row.
    rows = pid * M + offs_m
    a_ptrs = A_ptr + rows[:, None] * K + offs_k[None, :]
    a_vals = tl.load(a_ptrs)  # [M, K]

    b_ptrs = A_ptr + offs_k[:, None] * N + offs_n[None, :]
    b_vals = tl.load(b_ptrs)  # [K, N]

    # The dot result is the L0C tensor that fixpipe expects as `src`.
    acc = tl.dot(a_vals, b_vals)  # [M, N], fp32

    ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB)
    al.fixpipe(acc, ub, dual_dst_mode=al.FixpipeDualDstMode.NO_DUAL)

def test_fixpipe(M, K, N):
    """Compile the example kernel for the given tile sizes and print its MLIR.

    Verifies that lowering actually emitted a ``hivm.hir.fixpipe`` op rather
    than merely producing some non-empty text.

    Args:
        M, K, N: tile sizes forwarded to the kernel's constexpr parameters.

    Note:
        Because this function takes positional parameters, pytest would try to
        resolve M/K/N as fixtures; run it through the ``__main__`` entry point.
    """
    mlir = compile_kernel(
        fixpipe,
        {"A_ptr": "*fp32"},
        {"M": M, "K": K, "N": N},
    )
    # Include the IR in the failure message so a missing op is easy to debug.
    assert "hivm.hir.fixpipe" in mlir, f"hivm.hir.fixpipe not found in IR:\n{mlir}"
    print(mlir)

# ============== Main for manual testing ==============
if __name__ == "__main__":
    # Square 16x16x16 tile; keywords make test_fixpipe's (M, K, N) order explicit.
    test_fixpipe(M=16, K=16, N=16)

5. 编译输出结果(示意输出:loc 中的文件路径与行号取决于生成时的本地文件,IR 内容会随用例代码变化)

Plain Text
module {
tt.func public @fixpipe(%arg0: !tt.ptr<f32> loc("/home/linxin/triton-test/fixpipe.py":35:0)) attributes {noinline = false} {
%0 = tt.get_program_id x : i32 loc(#loc1)
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2)
%2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> loc(#loc3)
%3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc4)
%4 = tt.splat %0 : i32 -> tensor<16x1xi32> loc(#loc5)
%5 = arith.addi %4, %2 : tensor<16x1xi32> loc(#loc5)
%c16_i32 = arith.constant 16 : i32 loc(#loc6)
%c16_i32_0 = arith.constant 16 : i32 loc(#loc6)
%cst = arith.constant dense<16> : tensor<16x1xi32> loc(#loc6)
%6 = arith.muli %5, %cst : tensor<16x1xi32> loc(#loc6)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>> loc(#loc7)
%8 = tt.addptr %7, %6 : tensor<16x1x!tt.ptr<f32>>, tensor<16x1xi32> loc(#loc7)
%9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc8)
%10 = tt.broadcast %8 : tensor<16x1x!tt.ptr<f32>> -> tensor<16x16x!tt.ptr<f32>> loc(#loc9)
%11 = tt.broadcast %9 : tensor<1x16xi32> -> tensor<16x16xi32> loc(#loc9)
%12 = tt.addptr %10, %11 : tensor<16x16x!tt.ptr<f32>>, tensor<16x16xi32> loc(#loc9)
%13 = tt.load %12 : tensor<16x16x!tt.ptr<f32>> loc(#loc10)
%alloc = memref.alloc() : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
annotation.mark %alloc {effects = ["write", "read"]} : memref<16x16xf32, #hivm.address_space<ub>> loc(#loc11)
hivm.hir.fixpipe {dma_mode = #hivm.dma_mode<nz2nd>} ins(%13 : tensor<16x16xf32>) outs(%alloc : memref<16x16xf32, #hivm.address_space<ub>>) dual_dst_mode = <NO_DUAL> loc(#loc12)
tt.return loc(#loc13)
} loc(#loc)
} loc(#loc)
#loc1 = loc("/home/linxin/triton-test/fixpipe.py":42:31)
#loc2 = loc("/home/linxin/triton-test/fixpipe.py":44:26)
#loc3 = loc("/home/linxin/triton-test/fixpipe.py":44:43)
#loc4 = loc("/home/linxin/triton-test/fixpipe.py":45:26)
#loc5 = loc("/home/linxin/triton-test/fixpipe.py":47:35)
#loc6 = loc("/home/linxin/triton-test/fixpipe.py":47:45)
#loc7 = loc("/home/linxin/triton-test/fixpipe.py":47:21)
#loc8 = loc("/home/linxin/triton-test/fixpipe.py":47:56)
#loc9 = loc("/home/linxin/triton-test/fixpipe.py":47:49)
#loc10 = loc("/home/linxin/triton-test/fixpipe.py":48:21)
#loc11 = loc("/home/linxin/triton-test/fixpipe.py":50:38)
#loc12 = loc("/home/linxin/triton-test/fixpipe.py":51:23)
#loc13 = loc("/home/linxin/triton-test/fixpipe.py":51:4)