triton.language.dot_scaled

1. OP 概述

简介:计算以缩放格式表示两个矩阵块的矩阵乘积

triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format,
    acc=None, lhs_k_pack=True, rhs_k_pack=True,
    out_dtype=triton.language.float32, _semantic=None)

2. OP 规格

2.1 参数说明

参数名 类型 说明
lhs tensor 左矩阵张量的基指针(支持bf16、fp16格式)
lhs_scale tensor 左矩阵缩放张量的基指针(支持int8格式)
lhs_format string 左矩阵张量的存放格式 (支持"bf16"和"fp16")
rhs tensor 右矩阵张量的基指针 (支持bf16、fp16格式)
rhs_scale tensor 右矩阵缩放张量的基指针(支持int8格式)
rhs_format string 右矩阵张量的存放格式 (支持"bf16"和"fp16")
acc tensor 累积张量
lhs_k_pack (bool, optional) true 沿 K 维度打包
false 沿 M 维度打包
rhs_k_pack (bool, optional) true 沿 K 维度打包
false 沿 N 维度打包
_semantic - 保留参数,暂不支持外部调用

返回值: out:tensor类型,计算缩放矩阵乘后输出的值

2.2 支持规格

2.2.1 DataType 支持

fp4 fp8 bf16 fp16
GPU
Ascend A2/A3 × ×

结论: 1、Ascend 对比 GPU 缺失fp4、fp8的支持能力(硬件限制)。 2、缩放张量的值为int8,GPU上为uint8。

2.2.2 Shape 支持

支持维度范围
GPU 可支持 2~3维 tensor
Ascend 可支持 2~3维 tensor

结论:在 Shape 方面,GPU 与 Ascend 平台无差异,lhs/rhs矩阵均支持 2 至 3 维张量,但scale矩阵只支持2维。

2.3 特殊限制说明

1、由于不支持fp8,左右矩阵不支持fp4、fp8格式,Ascend 对比 GPU 缺失lhs_k_pack、rhs_k_pack的矩阵解压缩支持能力(硬件限制)。 2、输入矩阵lhs、rhs推荐输入范围为[-5, 5],超过可能会出现极值inf。 3、由于硬件存在对齐要求,需要限制scale矩阵做broadcast的倍数,至少应为16

4、当前支持的缩放矩阵格式为int8,社区为uint8

2.4 使用方法

以下示例实现了对输入张量 x 做就地绝对值计算:

@triton.jit
def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr,
                     stride_b1: tl.constexpr, b_scale, out,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                     type_b: tl.constexpr):
    PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K
    PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K
    str_a0: tl.constexpr = stride_a0
    a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0,
                                                                            str_a0)[None, :] * stride_a1
    b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0,
                                                                                     BLOCK_N)[None, :] * stride_b1

    a = tl.load(a_ptr)
    b = tl.load(b_ptr)
    SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    if a_scale is not None:
        scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0,
                                                                                           SCALE_BLOCK_K)[None, :]
        a_scale = tl.load(scale_a_ptr)
    if b_scale is not None:
        scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0,
                                                                                           SCALE_BLOCK_K)[None, :]
        b_scale = tl.load(scale_b_ptr)
    accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32)

    out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptr, accumulator.to(a.dtype))

x = torch.randn(shape, dtype=torch.bfloat16, device="npu")
y = torch.randn(shape, dtype=torch.bfloat16, device="npu")
M, K = shape[0], shape[1]
scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device="npu")
scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device="npu")
type_a, type_b = "bf16", "bf16"
pgm = dot_scale_kernel[(1,)](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b)