subview
1. 硬件背景
昇腾 A5 硬件支持在已有 buffer 上定义新的视图(subview):新视图仅由偏移(offset)、大小(size)和步幅(stride)描述,不会复制底层数据。
2. 接口说明
接口一(函数形式)
| Python def subview( src: bl.buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], builder: ir.builder ) |
接口二(方法形式,以 self 代替 src 作为源 buffer)
| Python def subview( self, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], _builder=None ) |
返回值:bl.buffer
3. 入参说明
-
src: bl.buffer -> 源 buffer(仅接口一需要;接口二中由 self 充当源 buffer)
-
offsets: List[tl.tensor] -> 子视图在每个维度上的起始偏移(也可传入 tl.constexpr,见下文补充说明)
-
sizes: List[tl.constexpr] -> 输出子视图在每个维度上的大小
-
strides: List[tl.constexpr] -> 每个维度上的步长
4. 约束说明
-
size 和 stride 的每个元素都必须大于 0;offset 的每个元素必须大于或等于 0。三者均不能为负值。
-
size 在每一个维度上的值都不能大于原 buffer 对应维度的大小。
-
子视图的每一个维度的大小不能超过原buffer的大小。
-
按 stride 访问的范围不能超出 src 的大小;stride 的所有元素必须全为 1。
-
offsets、sizes、strides 必须为每一个维度显式指定取值,其长度应与输入 buffer 的维度数保持一致。
-
offset 必须满足 32 字节对齐。
-
子视图中第二行第一个元素(按最后一个维度划分的行)的偏移必须是 32 字节对齐。
补充说明:sizes 和 strides 必须以 List[tl.constexpr] 类型传入(注意不要误传 tensor,否则会报类型不匹配错误);offsets 除 tl.constexpr 外,也支持传入 tl.tensor。
5. 用例示例
"""Compile-only examples for ``bl.subview`` on Triton-Ascend.

Each kernel is lowered to TTIR with ``ast_to_ttir`` and the resulting MLIR
text is printed; no kernel is launched on a device.
"""
import os

# Must happen before triton is imported. Presumably this keeps PyTorch from
# auto-loading out-of-tree device backends during import — TODO confirm it is
# still required for this build.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton
import triton.language as tl
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir


class Options:
    """Plain attribute container passed to ``ast_to_ttir`` as compile options."""

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False


def compile_kernel(kernel, signature, constants):
    """Lower ``kernel`` to TTIR and return the module as MLIR text.

    Args:
        kernel: A ``@triton.jit`` function.
        signature: Argument signature forwarded to ``ASTSource``
            (``{}`` in every call below).
        constants: Mapping of constexpr argument names to their values.

    Returns:
        ``str()`` of the module produced by ``ast_to_ttir``.
    """
    source = ASTSource(kernel, signature, constants)
    ctx = ir.context()
    # Register the core, buffer-extension and Ascend dialects before lowering.
    ir.load_dialects(ctx)
    buffer_ir.load_dialects(ctx)
    ascend_ir.load_dialects(ctx)
    # Presumably exposes Ascend's address-space hook to the code generator —
    # TODO confirm against ast_to_ttir's parameter list.
    extra_fns = {"create_address_space": al.semantic.create_address_space}
    ttir_module = ast_to_ttir(kernel, source, ctx, Options(), extra_fns, {})
    return str(ttir_module)


@triton.jit
def test_subview_kernel1(XBLOCK: tl.constexpr):
    # Local XBLOCK x XBLOCK float32 buffer to carve the view from.
    base = bl.alloc(tl.float32, [XBLOCK, XBLOCK])
    # Rows 1 .. XBLOCK-2 across all columns, i.e. the first and last rows
    # are left out of the view.
    view = bl.subview(
        base,
        offsets=[1, 0],
        sizes=[XBLOCK - 2, XBLOCK],
        strides=[1, 1],
    )


@triton.jit
def test_subview_kernel2(
    XBLOCK: tl.constexpr,
    offset: tl.constexpr,
    size: tl.constexpr,
    stride: tl.constexpr,
):
    # Parameterised variant: the row window (offset/size/stride) comes from
    # constexpr arguments, while the column dimension is taken whole.
    # The 2D path is used because the 1D path appears to hit a naming issue
    # in this Triton-Ascend build.
    base = bl.alloc(tl.float32, [XBLOCK, XBLOCK])
    bl.subview(
        base,
        offsets=[offset, 0],
        sizes=[size, XBLOCK],
        strides=[stride, 1],
    )


_RULE = "=" * 60


def _run_case(title, kernel, constants, *, first):
    """Print a banner for ``title``, compile ``kernel`` and dump its MLIR."""
    print(_RULE if first else "\n" + _RULE)
    print(title)
    print(_RULE)
    mlir = compile_kernel(kernel, {}, constants)
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)


# ============== Main for manual testing ==============
if __name__ == "__main__":
    cases = (
        ("Test 1: test_subview_function", test_subview_kernel1, {"XBLOCK": 8}),
        (
            "Test 2: test_subview_constructor",
            test_subview_kernel2,
            {"XBLOCK": 32, "offset": 1, "size": 24, "stride": 1},
        ),
    )
    for index, (title, kernel, constants) in enumerate(cases):
        _run_case(title, kernel, constants, first=index == 0)
输出:
| Plain Text ============================================================ Test 1: test_subview_function ============================================================ ✅ Generated MLIR (907 chars): module { tt.func public @test_subview_kernel1() attributes {noinline = false} { %alloc = memref.alloc() : memref<8x8xf32> loc(#loc1) annotation.mark %alloc {effects = ["write", "read"]} : memref<8x8xf32> loc(#loc1) %c1_i32 = arith.constant 1 : i32 loc(#loc2) %c0_i32 = arith.constant 0 : i32 loc(#loc2) %0 = arith.index_cast %c1_i32 : i32 to index loc(#loc2) %1 = arith.index_cast %c0_i32 : i32 to index loc(#loc2) %subview = memref.subview %alloc[%0, %1] [6, 8] [1, 1] : memref<8x8xf32> to memref<6x8xf32, strided<[8, 1], offset: ?>> loc(#loc2) tt.return loc(#loc3) } loc(#loc) } loc(#loc) #loc = loc("/home/ganpengfei/workspace/triton-test/subview.py":33:0) #loc1 = loc("/home/ganpengfei/workspace/triton-test/subview.py":35:38) #loc2 = loc("/home/ganpengfei/workspace/triton-test/subview.py":37:8) #loc3 = loc("/home/ganpengfei/workspace/triton-test/subview.py":36:4) ============================================================ Test 2: test_subview_constructor ============================================================ ✅ Generated MLIR (918 chars): module { tt.func public @test_subview_kernel2() attributes {noinline = false} { %alloc = memref.alloc() : memref<32x32xf32> loc(#loc1) annotation.mark %alloc {effects = ["write", "read"]} : memref<32x32xf32> loc(#loc1) %c1_i32 = arith.constant 1 : i32 loc(#loc2) %c0_i32 = arith.constant 0 : i32 loc(#loc2) %0 = arith.index_cast %c1_i32 : i32 to index loc(#loc2) %1 = arith.index_cast %c0_i32 : i32 to index loc(#loc2) %subview = memref.subview %alloc[%0, %1] [24, 32] [1, 1] : memref<32x32xf32> to memref<24x32xf32, strided<[32, 1], offset: ?>> loc(#loc2) tt.return loc(#loc3) } loc(#loc) } loc(#loc) #loc = loc("/home/ganpengfei/workspace/triton-test/subview.py":45:0) #loc1 = loc("/home/ganpengfei/workspace/triton-test/subview.py":48:38) #loc2 
= loc("/home/ganpengfei/workspace/triton-test/subview.py":50:8) #loc3 = loc("/home/ganpengfei/workspace/triton-test/subview.py":49:4) |