triton.language.dot_scaled
1. OP 概述
简介:计算以缩放格式表示两个矩阵块的矩阵乘积
triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format,
acc=None, lhs_k_pack=True, rhs_k_pack=True,
out_dtype=triton.language.float32, _semantic=None)
2. OP 规格
2.1 参数说明
| 参数名 | 类型 | 说明 |
|---|---|---|
lhs |
tensor |
左矩阵张量的基指针(支持bf16、fp16格式) |
lhs_scale |
tensor |
左矩阵缩放张量的基指针(支持int8格式) |
lhs_format |
string |
左矩阵张量的存放格式 (支持"bf16"和"fp16") |
rhs |
tensor |
右矩阵张量的基指针 (支持bf16、fp16格式) |
rhs_scale |
tensor |
右矩阵缩放张量的基指针(支持int8格式) |
rhs_format |
string |
右矩阵张量的存放格式 (支持"bf16"和"fp16") |
acc |
tensor |
累积张量 |
lhs_k_pack |
(bool, optional) |
true 沿 K 维度打包 false 沿 M 维度打包 |
rhs_k_pack |
(bool, optional) |
true 沿 K 维度打包 false 沿 N 维度打包 |
_semantic |
- | 保留参数,暂不支持外部调用 |
返回值:
out:tensor类型,计算缩放矩阵乘后输出的值
2.2 支持规格
2.2.1 DataType 支持
| fp4 | fp8 | bf16 | fp16 | |
|---|---|---|---|---|
| GPU | √ | √ | √ | √ |
| Ascend A2/A3 | × | × | √ | √ |
结论: 1、Ascend 对比 GPU 缺失fp4、fp8的支持能力(硬件限制)。 2、缩放张量的值为int8,GPU上为uint8。
2.2.2 Shape 支持
| 支持维度范围 | |
|---|---|
| GPU | 可支持 2~3维 tensor |
| Ascend | 可支持 2~3维 tensor |
结论:在 Shape 方面,GPU 与 Ascend 平台无差异,lhs/rhs矩阵均支持 2 至 3 维张量,但scale矩阵只支持2维。
2.3 特殊限制说明
1、由于不支持fp8,左右矩阵不支持fp4、fp8格式,Ascend 对比 GPU 缺失lhs_k_pack、rhs_k_pack的矩阵解压缩支持能力(硬件限制)。 2、输入矩阵lhs、rhs推荐输入范围为[-5, 5],超过可能会出现极值inf。 3、由于硬件存在对齐要求,需要限制scale矩阵做broadcast的倍数,至少应为16
4、当前支持的缩放矩阵格式为int8,社区为uint8
2.4 使用方法
以下示例实现了对输入张量 x 做就地绝对值计算:
@triton.jit
def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr,
stride_b1: tl.constexpr, b_scale, out,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
type_b: tl.constexpr):
PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K
PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K
str_a0: tl.constexpr = stride_a0
a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0,
str_a0)[None, :] * stride_a1
b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0,
BLOCK_N)[None, :] * stride_b1
a = tl.load(a_ptr)
b = tl.load(b_ptr)
SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
if a_scale is not None:
scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0,
SCALE_BLOCK_K)[None, :]
a_scale = tl.load(scale_a_ptr)
if b_scale is not None:
scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0,
SCALE_BLOCK_K)[None, :]
b_scale = tl.load(scale_b_ptr)
accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32)
out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
tl.store(out_ptr, accumulator.to(a.dtype))
x = torch.randn(shape, dtype=torch.bfloat16, device="npu")
y = torch.randn(shape, dtype=torch.bfloat16, device="npu")
M, K = shape[0], shape[1]
scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device="npu")
scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device="npu")
type_a, type_b = "bf16", "bf16"
pgm = dot_scale_kernel[(1,)](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b)