triton.language.dot

1. OP 概述

简介:对两个tensor进行矩阵乘操作。tensor需要是二维或三维并且维度需一致。对于三维块,tl.dot执行批量矩阵乘法,其中每个块的第一维代表批量维度。 原型:

triton.language.dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32, _semantic=None)

2. OP 规格

2.1 参数说明

参数名 类型 说明
input int8 fp16 bf16 fp32 第一个输入,2D or 3D 张量, 为了避免溢出 取值范围限制为-5-5
other int8 fp16 bf16 fp32 第二个输入, 2D or 3D 张量,为了避免溢出 取值范围限制为-5-5
acc int32 float32 存累加结果的张量, accumulator tensor. If not None, the result is added to this tensor, acc_dtype支持 {:code:float16, :code:float32, :code:int32}
input_precision - Available options for NVIDIA 通过选择精度模式来决定是否启用 Tensor Cores 加速
max_num_imprecise_acc int 多少次低精度的累加数(当前昇腾不支持低精度累加)
out_dtype fp32 int32 输出结果类型

返回值: tl.tensor:矩阵乘结果

2.2 支持规格

2.2.1 DataType 支持

输入类型 int8 int16 int32 uint8 uint16 uint32 uint64 int64 fp16 fp32 fp64 bf16 bool
GPU
Ascend A2/A3 × × × × × ×

结论:Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。

2.2.2 Shape 支持

支持维度范围
GPU 无限制
Ascend A2/A3 无限制

结论:在 Shape 方面,GPU 与 Ascend 平台无差异。

2.3 特殊限制说明

  • Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。

  • acc 不能支持fp16,为了精度硬件默认就是fp32

  • max_num_imprecise_acc 暂时不支持

  • out_dtype对比GPU 缺乏int8和FP16的类型支持

2.4 使用方法

以下示例实现了对输入张量 x_ptr, y_ptr 做矩阵乘计算,参考 ascend/examples/generalization_cases/test_matmul.py:

def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M: tl.constexpr,
        N: tl.constexpr,
        K: tl.constexpr,
        acc_dtype: tl.constexpr,
        stride_am: tl.constexpr,
        stride_ak: tl.constexpr,
        stride_bk: tl.constexpr,
        stride_bn: tl.constexpr,
        stride_cm: tl.constexpr,
        stride_cn: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N))
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)