import types
import torch._C
from torch._C import _add_docstr as add_docstr
import torch_npu
def _add_torch_npu_docstr(method, docstr):
"""Add doc to operator API.
If implementing the Python side interface with pybind11, _add_docstr is needed to add doc.
"""
func = getattr(torch_npu, method, None)
if not func:
return
if isinstance(func, types.BuiltinMethodType):
add_docstr(func, docstr)
else:
getattr(torch_npu, method).__doc__ = docstr
_add_torch_npu_docstr(
"_npu_dropout",
"""
torch_npu._npu_dropout(self, p) -> (Tensor, Tensor)
功能描述
不使用种子(seed)进行dropout结果计数。与torch.dropout相似,优化NPU设备实现。
参数说明
self (Tensor) - 输入张量。
p (Float) - 丢弃概率。
示例
>>> import torch
>>> import torch_npu
>>> input = torch.tensor([1.,2.,3.,4.]).npu()
>>> input
tensor([1., 2., 3., 4.], device='npu:0')
>>> prob = 0.3
>>> output, mask = torch_npu._npu_dropout(input, prob)
>>> output
tensor([0.0000, 2.8571, 0.0000, 0.0000], device='npu:0')
>>> mask
tensor([ 98, 255, 188, 186, 120, 157, 175, 159, 77, 223, 127, 79, 247, 151,
253, 255], device='npu:0', dtype=torch.uint8)
"""
)
_add_torch_npu_docstr(
"copy_memory_",
"""
torch_npu.copy_memory_(dst, src, non_blocking=False) -> Tensor
功能描述
从src拷贝元素到self张量,并返回self。
参数说明
dst (Tensor) - 拷贝目标张量(即接收数据的张量)。
src (Tensor) - 拷贝源张量(即提供数据的张量)。
non_blocking (Bool,默认值为False) - 如果设置为True且此拷贝位于CPU和NPU之间,则拷贝可能相对于主机异步发生。在其他情况下,此参数没有效果。
约束说明
copy_memory_仅支持NPU张量。copy_memory_的输入张量应具有相同的dtype和设备index。
示例
>>> a=torch.IntTensor([0, 0, -1]).npu()
>>> b=torch.IntTensor([1, 1, 1]).npu()
>>> torch_npu.copy_memory_(a, b)
tensor([1, 1, 1], device='npu:0', dtype=torch.int32)
"""
)
_add_torch_npu_docstr(
"empty_with_format",
"""
torch_npu.empty_with_format(size, dtype, layout, device, pin_memory, acl_format)
功能描述
返回一个填充未初始化数据的张量。
参数说明
size (ListInt) - 定义输出张量shape的整数序列。可以是参数数量(可变值),也可以是列表或元组等集合。
dtype (torch.dtype, 可选,默认值为None) - 返回张量所需数据类型。如果值为None,请使用全局默认值(请参见torch.set_default_tensor_type()).
layout (torch.layout, 可选,默认值为torch.strided) - 返回张量所需布局。
device (torch.device, 可选,默认值为None) - 返回张量的所需设备。
pin_memory (Bool, 可选,默认值为False) - 如果设置此参数,返回张量将分配在固定内存中。
acl_format (Int,默认值为2) - 返回张量所需内存格式。
示例
>>> torch_npu.empty_with_format((2, 3), dtype=torch.float32, device="npu")
tensor([[1., 1., 1.],
[1., 1., 1.]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"fast_gelu",
"""
torch_npu.fast_gelu(self) -> Tensor
功能描述
gelu的npu实现。支持FakeTensor模式。
参数说明
self (Tensor) - 数据类型:float16、float32。
示例
示例一:
>>> x = torch.rand(2).npu()
>>> x
tensor([0.5991, 0.4094], device='npu:0')
>>> torch_npu.fast_gelu(x)
tensor([0.4403, 0.2733], device='npu:0')
示例二:
//FakeTensor模式
>>> from torch._subclasses.fake_tensor import FakeTensorMode
>>> with FakeTensorMode():
... x = torch.rand(2).npu()
... torch_npu.fast_gelu(x)
FakeTensor(..., device='npu:0', size=(2,))
"""
)
_add_torch_npu_docstr(
"npu_fast_gelu",
"""
功能描述
算子功能: 快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function), 对输入的每个元素计算FastGelu的前向结果.
计算公式
公式1: fast_gelu(x)=$$\frac{x}{1+e^{-1.702\begin{vmatrix}x\end{vmatrix}}}e^{0.851x(x-\begin{vmatrix}x\end{vmatrix})
该公式支持: Atlas 训练系列产品/Atlas 推理系列产品
公式2: $$\frac{x}{1+e^{-1.702x}}
该公式支持: Atlas A2 训练系列产品/Atlas 800I A2 推理产品/Atlas A3 训练系列产品
接口原型
torch_npu.npu_fast_gelu(Tensor input) -> Tensor
参数说明
input: Tensor类型, 即公式中的x. 数据格式支持ND, 支持非连续的Tensor. 输入最大支持8维.
Atlas 训练系列产品: 数据类型支持float16、float32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float32、bfloat16.
Atlas A3 训练系列产品: 数据类型支持float16、float32、bfloat16.
Atlas 推理系列产品: 数据类型仅支持float16、float32.
输出说明
一个Tensor类型的输出, 代表fast_gelu的计算结果.
约束说明
该接口支持推理、训练场景下使用.
该接口支持图模式.
input输入不能含有None.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
支持的型号
Atlas 训练系列产品
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列产品
示例
单算子调用
import os
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
x = torch.from_numpy(data_var).to(torch.float32).npu()
y = torch_npu.npu_fast_gelu(x).cpu().numpy()
图模式调用
import os
import torch
import torch_npu
import numpy as np
import torch.nn as nn
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
os.environ["ENABLE_ACLNN"] = "false"
torch_npu.npu.set_compile_mode(jit_compile=True)
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, x):
y = torch_npu.npu_fast_gelu(x)
return y
npu_mode = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
x = torch.from_numpy(data_var).to(torch.float32)
y =npu_mode(x).cpu().numpy()
"""
)
_add_torch_npu_docstr(
"npu_alloc_float_status",
"""
torch_npu.npu_alloc_float_status(self) -> Tensor
功能描述
生成一个包含8个0的一维张量。
参数说明
self (Tensor) - 任何张量。
示例
>>> import torch
>>> import torch_npu
>>> input = torch.randn([1,2,3]).npu()
>>> output = torch_npu.npu_alloc_float_status(input)
>>> input
tensor([[[ 2.2324, 0.2478, -0.1056],
[ 1.1273, -0.2573, 1.0558]]], device='npu:0')
>>> output
tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_anchor_response_flags",
"""
torch_npu.npu_anchor_response_flags(self, featmap_size, stride, num_base_anchors) -> Tensor
功能描述
在单个特征图中生成锚点的责任标志。
参数说明
self (Tensor) - 真值框,shape为[batch, 4]的2D张量。
featmap_size (ListInt of length 2) - 特征图大小。
strides (ListInt of length 2) - 当前水平的步长。
num_base_anchors (Int) - base anchors的数量。
示例
>>> x = torch.rand(100, 4).npu()
>>> y = torch_npu.npu_anchor_response_flags(x, [60, 60], [2, 2], 9)
>>> y.shape
torch.Size([32400])
"""
)
_add_torch_npu_docstr(
"npu_apply_adam",
"""
torch_npu.npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out = (var, m, v))
功能描述
adam结果计数。
参数说明
beta1_power (Scalar) - beta1的幂。
beta2_power (Scalar) - beta2的幂。
lr (Scalar) - 学习率。
beta1 (Scalar) - 一阶矩估计值的指数衰减率。
beta2 (Scalar) - 二阶矩估计值的指数衰减率。
epsilon (Scalar) - 添加到分母中以提高数值稳定性的项数。
grad (Tensor) - 梯度。
use_locking (Bool,可选) - 设置为True时使用lock进行更新操作。
use_nesterov (Bool,可选) - 设置为True时采用nesterov更新。
var (Tensor) - 待优化变量。
m (Tensor) - 变量平均值。
v (Tensor) - 变量方差。
"""
)
_add_torch_npu_docstr(
"npu_apply_adam_w_out",
"""
torch_npu.npu_apply_adam_w_out(beta1_power, beta2_power, lr, weight_decay,
beta1, beta2, epsilon, grad, max_grad_norm, amsgrad, maximize, var, m, v) -> (Tensor(a!),Tensor(b!),Tensor(c!))
功能描述
实现adamW优化器功能
参数说明
beta1_power (Scalar) - beta1的幂,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
beta2_power (Scalar) - beta2的幂,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
lr (Scalar) - 学习率,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
weight_decay (Scalar) - 权重衰减系数,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
beta1 (Scalar) - beta1参数,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
beta2 (Scalar) - beta1参数,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
epsilon (Scalar) - 防止除数为0,shape要求为[1],数据类型支持FLOAT16、BFLOAT16、FLOAT32。
grad (Tensor) - 梯度数据,数据类型支持FLOAT16、BFLOAT16、FLOAT32。
max_grad_norm (Tensor,可选) - 保存参数v的最大值,数据类型支持FLOAT16、BFLOAT16、FLOAT32。
amsgrad (Bool,可选) - 是否使用max_grad_norm变量,数据类型支持BOOL。
maximize (Bool,可选) - 是否对梯度grad取反,应用梯度上升方向优化权重使损失函数最大化,数据类型为BOOL。
var (Tensor) - 待计算的权重输入同时也是输出,shape支持1-8维度,数据类型支持FLOAT16、BFLOAT16、FLOAT32。
m (Tensor) - adamW优化器的m参数,shape支持1-8维度,数据类型支持FLOAT16、BFLOAT16、FLOAT32。
v (Tensor) - adamW优化器的v参数,shape支持1-8维度,数据类型支持FLOAT16、BFLOAT16、FLOAT32。
"""
)
_add_torch_npu_docstr(
"npu_batch_gather_matmul",
"""
接口原型:
npu_batch_gather_matmul(Tensor input, Tensor x, Tensor weight_b, Tensor indices, Tensor? weight_a=None, int layer_idx=0, float scale=1e-3, int y_offset=0, int y_slice_size=-1) -> Tensor
功能描述:
npu_batch_gather_matmul: 对于GPU的Batched Gather Matrix-Vector Multiplication (BGMV)。将输入x根据输入索引indices,分别和对应的weight_a,weight_b相乘, 然后将结果累加到输入y并输出。
参数说明:
input:Device侧的tensor,表示待进行累加更新的张量,数据类型Float16,shape支持2维:[batch_size, y_column]。数据格式支持ND。第一维需要和x的第一维一致。支持非连续的Tensor,不支持空Tensor。
x:Device侧的tensor,表示分组前的输入张量,数据类型Float16,shape支持2维:[batch_size, H1],且H1是16的整数倍。数据格式支持ND。支持非连续的Tensor,不支持空Tensor。
weight_b:Device侧的tensor,表示进行矩阵乘的第二个权重矩阵,数据类型Float16。shape支持4维:[W, L, H2, R],第三维需要小于y的第二维(H2<y_column),且H2是16的整数倍。当weight_a为空,weight_b 的shape 是[W, L, H2, H1]。支持非连续的Tensor,不支持空Tensor。
indices:Device侧的tensor,标识输入x的分组索引,数据类型Int32。shape支持1维:[batch_size]。数据格式支持ND。第一维需要和x以及y的第一维保持一致。支持非连续的Tensor,不支持空Tensor。
weight_a:Device侧的tensor,表示进行矩阵乘的第一个权重矩阵,数据类型Float16。为空指针时会跳过第一个矩阵乘。shape支持4维:[W, L, R, H1],前两维需要和weight_b的前两维一致,用W和L表示;第三维需要和weight_b的第四维保持一致,都用R表示,R需要是16的整数倍且取值范围为[16, 128] ;第四维需要和x的第二维保持一致,都用H1表示,需要是16的整数倍。支持非连续的Tensor,不支持空Tensor。
layer_idx:Host侧的整型,表示weight的层数索引,数据类型Int,默认值为0。默认值为0。值需要小于weight_b的第二个维度L。
scale: Host侧的浮点型,表示matmul结果的缩放系数,数据类型Float,默认值为1e-3。
y_offset: Host侧的整型,表示y更新的偏移值,数据类型Int,默认值为0。值需要小于y的第二个维度y_column。
y_slice_size: Host侧的整型,表示y更新时的范围,数据类型Int,默认值为-1。当为-1时,按照y_column的值传入;当非-1 时,以传入的值做更新范围。
输出说明:
out:Device侧的Tensor类型,计算输出,复用y输入地址;数据类型和shape与y一致。
约束说明:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:仅在推理场景下使用。
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例:
单算子调用
import numpy as np
import torch
import torch_npu
x_data=torch.from_numpy(np.random.uniform(-1, 1, (4096, 16)).astype(np.float16)).npu()
y_data = torch.from_numpy(np.ones((4096, 6144)).astype(np.float16)).npu()
wa_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 16, 4096)).astype(np.float16)).npu()
wb_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 4096, 16)).astype(np.float16)).npu()
indices_data =torch.from_numpy(np.random.randint(-1, 2, size=4096).reshape(4096).astype(np.int32)).npu()
pred=torch_npu.npu_batch_gather_matmul(y_data,x_data,wb_t_all_data,indices_data,wa_t_all_data,y_slice_size=4096,scale=1e-3,y_offset=0,layer_idx=0)
torch_npu.npu_batch_gather_matmul_(y_data,x_data,wb_t_all_data,indices_data,wa_t_all_data,y_slice_size=4096,scale=1e-3,y_offset=0,layer_idx=0)
print(y_data)
图模式调用
import numpy as np
import torch
import torch_npu
import torchair
config = torchair.CompilerConfig()
npu_backend_plain = torchair.get_npu_backend(compiler_config=config)
x_data=torch.from_numpy(np.random.uniform(-1, 1, (4096, 16)).astype(np.float16)).npu()
y_data = torch.from_numpy(np.ones((4096, 6144)).astype(np.float16)).npu()
wa_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 16, 4096)).astype(np.float16)).npu()
wb_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 4096, 16)).astype(np.float16)).npu()
indices_data=torch.from_numpy(np.random.randint(-1,2,size=4096).reshape(4096).astype(np.int32)).npu()
def f(y_data, x_data, wb_t_all_data, indices_data, wa_t_all_data=None, y_slice_size=4096, scale=2, y_offset=0):
with torch.npu.amp.autocast():
pred = torch_npu.npu_batch_gather_matmul(y_data, x_data, wb_t_all_data, indices_data, wa_t_all_data, y_slice_size=y_slice_size, scale=scale, y_offset=y_offset, layer_idx=0)
return pred
opt =torch.compile(f, backend=npu_backend_plain, dynamic=True)
y2 = opt(y_data, x_data, wb_t_all_data, indices_data)
print(y2)
"""
)
_add_torch_npu_docstr(
"npu_batch_gather_matmul_",
"""
接口原型:
npu_batch_gather_matmul_(Tensor(a!) input, Tensor x, Tensor weight_b, Tensor indices, Tensor? weight_a=None, int layer_idx=0, float scale=1e-3, int y_offset=0, int y_slice_size=-1) -> Tensor(a!)
功能描述:
npu_batch_gather_matmul_: npu_batch_gather_matmul的inplace版本。将输入x根据输入索引indices,分别和对应的weight_a,weight_b 相乘,然后将结果累加到输入y并输出。
参数说明:
input :Device侧的tensor,表示待进行累加更新的张量,数据类型Float16,shape支持2维:[batch_size, y_column]。数据格式支持ND。第一维需要和x的第一维一致。支持非连续的Tensor,不支持空Tensor。
x:Device侧的tensor,表示分组前的输入张量,数据类型Float16,shape支持2维:[batch_size, H1],且H1是16的整数倍。数据格式支持ND。支持非连续的Tensor,不支持空Tensor。
weight_b:Device侧的tensor,表示进行矩阵乘的第二个权重矩阵,数据类型Float16。shape支持4维:[W, L, H2, R],第三维需要小于y的第二维(H2<y_column),且H2是16的整数倍。当weight_a为空,weight_b 的shape 是[W, L, H2, H1]。支持非连续的Tensor,不支持空Tensor。
indices:Device侧的tensor,标识输入x的分组索引,数据类型Int32。shape支持1维:[batch_size]。数据格式支持ND。第一维需要和x以及y的第一维保持一致。支持非连续的Tensor,不支持空Tensor。
weight_a :Device侧的tensor,表示进行矩阵乘的第一个权重矩阵,数据类型Float16。为空指针时会跳过第一个矩阵乘。shape支持4维:[W, L, R, H1],前两维需要和weight_b的前两维一致,用W和L表示;第三维需要和weight_b的第四维保持一致,都用R表示,R需要是16的整数倍且取值范围为[16, 128] ;第四维需要和x的第二维保持一致,都用H1表示,需要是16的整数倍。支持非连续的Tensor,不支持空Tensor。
layer_idx:Host侧的整型,表示weight的层数索引,数据类型Int,默认值为0。默认值为0。值需要小于weight_b的第二个维度L。
scale: Host侧的浮点型,表示matmul结果的缩放系数,数据类型Float,默认值为1e-3。
y_offset: Host侧的整型,表示y更新的偏移值,数据类型Int,默认值为0。值需要小于y的第二个维度y_column。
y_slice_size: Host侧的整型,表示y更新时的范围,数据类型Int,默认值为-1。当为-1时,按照y_column的值传入;当非-1 时,以传入的值做更新范围。
输出说明:
out:Device侧的Tensor类型,计算输出,复用y输入地址;数据类型和shape与y一致。
约束说明:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:仅在推理场景下使用。
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例:
单算子调用
import numpy as np
import torch
import torch_npu
x_data=torch.from_numpy(np.random.uniform(-1, 1, (4096, 16)).astype(np.float16)).npu()
y_data = torch.from_numpy(np.ones((4096, 6144)).astype(np.float16)).npu()
wa_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 16, 4096)).astype(np.float16)).npu()
wb_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 4096, 16)).astype(np.float16)).npu()
indices_data =torch.from_numpy(np.random.randint(-1, 2, size=4096).reshape(4096).astype(np.int32)).npu()
pred=torch_npu.npu_batch_gather_matmul(y_data,x_data,wb_t_all_data,indices_data,wa_t_all_data,y_slice_size=4096,scale=1e-3,y_offset=0,layer_idx=0)
torch_npu.npu_batch_gather_matmul_(y_data,x_data,wb_t_all_data,indices_data,wa_t_all_data,y_slice_size=4096,scale=1e-3,y_offset=0,layer_idx=0)
print(y_data)
图模式调用
import numpy as np
import torch
import torch_npu
import torchair
config = torchair.CompilerConfig()
npu_backend_plain = torchair.get_npu_backend(compiler_config=config)
x_data=torch.from_numpy(np.random.uniform(-1, 1, (4096, 16)).astype(np.float16)).npu()
y_data = torch.from_numpy(np.ones((4096, 6144)).astype(np.float16)).npu()
wa_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 16, 4096)).astype(np.float16)).npu()
wb_t_all_data =torch.from_numpy(np.random.uniform(-1, 1, (2, 1, 4096, 16)).astype(np.float16)).npu()
indices_data=torch.from_numpy(np.random.randint(-1,2,size=4096).reshape(4096).astype(np.int32)).npu()
def f(y_data, x_data, wb_t_all_data, indices_data, wa_t_all_data=None, y_slice_size=4096, scale=2, y_offset=0):
with torch.npu.amp.autocast():
pred = torch_npu.npu_batch_gather_matmul(y_data, x_data, wb_t_all_data, indices_data, wa_t_all_data, y_slice_size=y_slice_size, scale=scale, y_offset=y_offset, layer_idx=0)
return pred
opt =torch.compile(f, backend=npu_backend_plain, dynamic=True)
y2 = opt(y_data, x_data, wb_t_all_data, indices_data)
print(y2)
"""
)
_add_torch_npu_docstr(
"npu_batch_nms",
"""
torch_npu.npu_batch_nms(self, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size, change_coordinate_frame=False, transpose_box=False) -> (Tensor, Tensor, Tensor, Tensor)
功能描述
根据batch分类计算输入框评分,通过评分排序,删除评分高于阈值(iou_threshold)的框,支持多批多类处理。通过NonMaxSuppression(nms)操作可有效删除冗余的输入框,提高检测精度。NonMaxSuppression:抑制不是极大值的元素,搜索局部的极大值,常用于计算机视觉任务中的检测类模型。
参数说明
self (Tensor) - 必填值,输入框的tensor,包含batch大小,数据类型Float16,输入示例:[batch_size, num_anchors, q, 4],其中q=1或q=num_classes。
scores (Tensor) - 必填值,输入tensor,数据类型Float16,输入示例:[batch_size, num_anchors, num_classes]。
score_threshold (Float32) - 必填值,指定评分过滤器的iou_threshold,用于筛选框,去除得分较低的框,数据类型Float32。
iou_threshold (Float32) - 必填值,指定nms的iou_threshold,用于设定阈值,去除高于阈值的的框,数据类型Float32。
max_size_per_class (Int) - 必填值,指定每个类别的最大可选的框数,数据类型Int。
max_total_size (Int) - 必填值,指定每个batch最大可选的框数,数据类型Int。
change_coordinate_frame (Bool,默认值为False) -可选值, 是否正则化输出框坐标矩阵,数据类型Bool。
transpose_box (Bool,默认值为False) - 可选值,确定是否在此op之前插入转置,数据类型Bool。True表示boxes使用4,N排布。 False表示boxes使用过N,4排布。
输出说明
nmsed_boxes (Tensor) - shape为(batch, max_total_size, 4)的3D张量,指定每批次输出的nms框,数据类型Float16。
nmsed_scores (Tensor) - shape为(batch, max_total_size)的2D张量,指定每批次输出的nms分数,数据类型Float16。
nmsed_classes (Tensor) - shape为(batch, max_total_size)的2D张量,指定每批次输出的nms类,数据类型Float16。
nmsed_num (Tensor) - shape为(batch)的1D张量,指定nmsed_boxes的有效数量,数据类型Int32。
示例
>>> boxes = torch.randn(8, 2, 4, 4, dtype = torch.float32).to("npu")
>>> scores = torch.randn(3, 2, 4, dtype = torch.float32).to("npu")
>>> nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = torch_npu.npu_batch_nms(boxes, scores, 0.3, 0.5, 3, 4)
>>> nmsed_boxes
>>> nmsed_scores
>>> nmsed_classes
>>> nmsed_num
"""
)
_add_torch_npu_docstr(
"npu_bert_apply_adam",
"""
torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))
功能描述
adam结果计数。
参数说明
参数:
var (Tensor) - float16或float32类型张量。
m (Tensor) - 数据类型和shape与exp_avg相同。
v (Tensor) - 数据类型和shape与exp_avg相同。
lr (Scalar) - 数据类型与exp_avg相同。
beta1 (Scalar) - 数据类型与exp_avg相同。
beta2 (Scalar) - 数据类型与exp_avg相同。
epsilon (Scalar) - 数据类型与exp_avg相同。
grad (Tensor) - 数据类型和shape与exp_avg相同。
max_grad_norm (Scalar) - 数据类型与exp_avg相同。
global_grad_norm (Scalar) - 数据类型与exp_avg相同。
weight_decay (Scalar) - 数据类型与exp_avg相同。
step_size (Tensor,可选,默认值为None) - shape为(1, ),数据类型与exp_avg一致。
adam_mode (Int,默认值为0) - 选择adam模式。0表示“adam”,1表示“mbert_adam”。
关键字参数:
out (Tensor,可选) - 输出张量。
示例
>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu()
>>> m_in = torch.zeros(321538).npu()
>>> v_in = torch.zeros(321538).npu()
>>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
>>> max_grad_norm = -1.
>>> beta1 = 0.9
>>> beta2 = 0.99
>>> weight_decay = 0.
>>> lr = 0.
>>> epsilon = 1e-06
>>> global_grad_norm = 0.
>>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, out=(var_in, m_in, v_in))
>>> var_out
tensor([ 14.7733, -30.1218, -1.3647, ..., -16.6840, 7.1518, 8.4872],
device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_block_sparse_attention",
"""
功能描述:
BlockSparseAttention 稀疏注意力正向计算, 支持块级稀疏模式, 通过 block_sparse_mask 指定每个 Q 块选择的 KV 块, 实现高效稀疏注意力.
接口原型:
torch_npu.npu_block_sparse_attention(Tensor query, Tensor key, Tensor value, Tensor block_sparse_mask, int[] block_shape, *, str q_input_layout='TND', str kv_input_layout='TND', int num_key_value_heads=1, float scale_value=0.0, int inner_precise=1, int[]? actual_seq_lengths=None, int[]? actual_seq_lengths_kv=None, int? softmax_lse_flag=0) -> (Tensor, Tensor)
参数说明:
query: Tensor 类型, 公式中的 query. 数据格式 ND. TND: [totalQTokens, headNum, headDim]; BNSD: [batch, headNum, maxQSeqLength, headDim]. 数据类型:float16、bfloat16.
key: Tensor 类型, 公式中的 key. TND: [totalKTokens, numKeyValueHeads, headDim]; BNSD: [batch, numKeyValueHeads, maxKvSeqLength, headDim]. 数据类型与 query 一致.
value: Tensor 类型, 公式中的 value. shape 与 key 一致, 数据类型与 query 一致.
block_sparse_mask: Tensor 类型, 必选, 块稀疏掩码. shape 为 [batch, headNum, ceilDiv(maxQSeqLength, blockShapeX), ceilDiv(maxKvSeqLength, blockShapeY)]. 底层算子数据类型为 INT8; PyTorch 侧通常使用 int8、int32 或 bool 表示 0/1 掩码.
block_shape: list[int] 必选, 稀疏块形状 [blockShapeX, blockShapeY]. 至少两元素且大于 0; blockShapeY 必须为 128 的倍数.
* 其后为关键字参数, 须以关键字形式传入.
q_input_layout: str 类型, 可选, 默认 "TND". query 的排布, 仅支持 "TND"、"BNSD".
kv_input_layout: str 类型, 可选, 默认 "TND". key、value 的排布, 仅支持 "TND"、"BNSD", 需与 q_input_layout 一致.
num_key_value_heads: int 类型, 可选, 默认 1. key/value 的 head 数.
scale_value: float 类型, 可选, 默认 0.0. 缩放系数, 一般取 D^-0.5.
inner_precise: int 类型, 可选, 默认 1. Softmax 计算精度. 0 表示 fp32 中间结果, 1 表示 fp16 中间结果. 当 query/key/value 为 bfloat16 时仅支持 0.
actual_seq_lengths: list[int] 可选, 各 batch 的 query 实际序列长度. q_input_layout 为 "TND" 时必选.
actual_seq_lengths_kv: list[int] 可选, 各 batch 的 key/value 实际序列长度. kv_input_layout 为 "TND" 时必选.
softmax_lse_flag: int 可选, 默认 0. 0 表示不输出 softmax_lse; 1 表示输出 softmax_lse.
输出说明:
(Tensor, Tensor). 第一个为 attention_out, 与 query 的 dtype 和 layout 一致; 第二个为 softmax_lse, 当 softmax_lse_flag=1 时有效.
约束说明:
q_input_layout、kv_input_layout 仅支持 "TND"、"BNSD".
query、key、value 数据类型必须一致且为 float16 或 bfloat16.
query 的 head 数 N1 与 key/value 的 head 数 N2,需满足 N1 >= N2 且 N1 % N2 == 0.
block_sparse_mask 必传,且shape必须为[batch, headNum, ceilDiv(maxQS, blockShapeX), ceilDiv(maxKVS, blockShapeY)].
block_shape 必传,必须包含至少两个元素[blockShapeX, blockShapeY],且值必须大于0;blockShapeY 必须为 128 的倍数。
当 q_input_layout 为 "TND" 时 actual_seq_lengths 必选; 当 kv_input_layout 为 "TND" 时 actual_seq_lengths_kv 必选.
actual_seq_lengths 与 actual_seq_lengths_kv 当前必须同时配置或同时不配置,仅配置其中之一会被算子拦截.
正向路径当前支持 headDim=64 或 128; 反向路径当前支持 headDim=128.
反向路径支持 q_input_layout 和 kv_input_layout 同为 "BNSD" 或同为 "TND",并支持 MHA/GQA 场景. MHA 场景下 N1 == N2, GQA 场景下需满足 N1 > N2 且 N1 % N2 == 0, 其中 N1 为 query 的 head 数, N2 为 key/value 的 head 数.
inner_precise 仅支持 0(表示float32 softmax) 或 1(表示float16 softmax);当 query/key/value 为 bfloat16 时,仅支持 0.
支持的PyTorch版本
PyTorch 2.10
PyTorch 2.9
PyTorch 2.8
PyTorch 2.7
PyTorch 2.6
支持的型号:
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
调用示例:
BNSD 布局
import torch
import torch_npu
B, N, S, D = 2, 8, 32, 64
num_kv_heads = 8
scale_value = 1.0 / (D ** 0.5)
block_shape = [128, 128] # blockShapeY 须为 128 的倍数
ceil_q = (S + block_shape[0] - 1) // block_shape[0]
ceil_kv = (S + block_shape[1] - 1) // block_shape[1]
query = torch.randn(B, N, S, D, dtype=torch.float16).npu()
key = torch.randn(B, num_kv_heads, S, D, dtype=torch.float16).npu()
value = torch.randn(B, num_kv_heads, S, D, dtype=torch.float16).npu()
block_sparse_mask = torch.ones(B, N, ceil_q, ceil_kv, dtype=torch.int8).npu()
attention_out, _ = torch_npu.npu_block_sparse_attention(
query, key, value, block_sparse_mask, block_shape,
q_input_layout="BNSD", kv_input_layout="BNSD",
num_key_value_heads=num_kv_heads, scale_value=scale_value, inner_precise=1,
)
print(attention_out.shape) # (B, N, S, D)
TND 布局
T, N, D = 32, 8, 64
num_kv_heads = 8
scale_value = 1.0 / (D ** 0.5)
block_shape = [128, 128] # blockShapeY 须为 128 的倍数
ceil_q = (T + block_shape[0] - 1) // block_shape[0]
ceil_kv = (T + block_shape[1] - 1) // block_shape[1]
query = torch.randn(T, N, D, dtype=torch.float16).npu()
key = torch.randn(T, num_kv_heads, D, dtype=torch.float16).npu()
value = torch.randn(T, num_kv_heads, D, dtype=torch.float16).npu()
block_sparse_mask = torch.ones(1, N, ceil_q, ceil_kv, dtype=torch.int8).npu()
attention_out, softmax_lse = torch_npu.npu_block_sparse_attention(
query, key, value, block_sparse_mask, block_shape,
q_input_layout="TND", kv_input_layout="TND",
num_key_value_heads=num_kv_heads, scale_value=scale_value, inner_precise=1,
actual_seq_lengths=[T], actual_seq_lengths_kv=[T], softmax_lse_flag=1
)
print(attention_out.shape) # (T, N, D)
"""
)
_add_torch_npu_docstr(
"npu_bmmV2",
"""
torch_npu.npu_bmmV2(self, mat2, output_sizes) -> Tensor
功能描述
将矩阵“a”乘以矩阵“b”,生成“a*b”。支持FakeTensor模式。
参数说明
self (Tensor) - 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]。
mat2 (Tensor) - 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]。
output_sizes (ListInt,默认值为[]) - 输出的shape,用于matmul的反向传播。
示例
示例一:
>>> mat1 = torch.randn(10, 3, 4).npu()
>>> mat2 = torch.randn(10, 4, 5).npu()
>>> res = torch_npu.npu_bmmV2(mat1, mat2, [])
>>> res.shape
torch.Size([10, 3, 5])
示例二:
//FakeTensor模式
>>> from torch._subclasses.fake_tensor import FakeTensorMode
>>> with FakeTensorMode():
... mat1 = torch.randn(10, 3, 4).npu()
... mat2 = torch.randn(10, 4, 5).npu()
... result = torch_npu.npu_bmmV2(mat1, mat2, [])
...
>>> result
FakeTensor(..., device='npu:0', size=(10, 3, 5))
"""
)
_add_torch_npu_docstr(
"npu_bounding_box_decode",
"""
torch_npu.npu_bounding_box_decode(rois, deltas, means0, means1, means2, means3, stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip) -> Tensor
功能描述
根据rois和deltas生成标注框。自定义FasterRcnn算子。
参数说明
rois (Tensor) - 区域候选网络(RPN)生成的region of interests(ROI)。shape为(N,4)数据类型为float32或float16的2D张量。“N”表示ROI的数量, “4”表示“x0”、“x1”、“y0”和“y1”。
deltas (Tensor) - RPN生成的ROI和真值框之间的绝对变化。shape为(N,4)数据类型为float32或float16的2D张量。“N”表示错误数,“4”表示“dx”、“dy”、“dw”和“dh”。
means0 (Float) - index。
means1 (Float) - index。
means2 (Float) - index。
means3 (Float,默认值为[0,0,0,0]) - index。"deltas" = "deltas" x "stds" + "means"
stds0 (Float) - index。
stds1 (Float) - index。
stds2 (Float) - index。
stds3 (Float, 默认值:[1.0,1.0,1.0,1.0]) - index。"deltas" = "deltas" x "stds" + "means"
max_shape (ListInt of length 2) - shape[h, w],指定传输到网络的图像大小。用于确保转换后的bbox shape不超过“max_shape”。
wh_ratio_clip (Float) -“dw”和“dh”的值在(-wh_ratio_clip, wh_ratio_clip)范围内。
示例
>>> rois = torch.tensor([[1., 2., 3., 4.], [3.,4., 5., 6.]], dtype = torch.float32).to("npu")
>>> deltas = torch.tensor([[5., 6., 7., 8.], [7.,8., 9., 6.]], dtype = torch.float32).to("npu")
>>> output = torch_npu.npu_bounding_box_decode(rois, deltas, 0, 0, 0, 0, 1, 1, 1, 1, (10, 10), 0.1)
>>> output
tensor([[2.5000, 6.5000, 9.0000, 9.0000],
[9.0000, 9.0000, 9.0000, 9.0000]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_bounding_box_encode",
"""
torch_npu.npu_bounding_box_encode(anchor_box, ground_truth_box, means0, means1, means2, means3, stds0, stds1, stds2, stds3) -> Tensor
功能描述
计算标注框和ground truth真值框之间的坐标变化。自定义FasterRcnn算子。
参数说明
anchor_box (Tensor) - 输入张量。锚点框。shape为(N,4)数据类型为float32的2D张量。“N”表示标注框的数量,“4”表示“x0”、“x1”、“y0”和“y1”。
ground_truth_box (Tensor) - 输入张量。真值框。shape为(N,4)数据类型为float32的2D张量。“N”表示标注框的数量,“4”表示“x0”、“x1”、“y0”和“y1”。
means0 (Float) - index。
means1 (Float) - index。
means2 (Float) - index。
means3 (Float, 默认值为[0,0,0,0]) - index。 "deltas" = "deltas" x "stds" + "means"
stds0 (Float) - index。
stds1 (Float) - index。
stds2 (Float) - index。
stds3 (Float, 默认值:[1.0,1.0,1.0,1.0]) -index。 "deltas" = "deltas" x "stds" + "means"
示例
>>> import torch
>>> import torch_npu
>>> anchor_box = torch.tensor([[1., 2., 3., 4.], [3.,4., 5., 6.]], dtype = torch.float32).to("npu")
>>> ground_truth_box = torch.tensor([[5., 6., 7., 8.], [7.,8., 9., 6.]], dtype = torch.float32).to("npu")
>>> output = torch_npu.npu_bounding_box_encode(anchor_box, ground_truth_box, 0, 0, 0, 0, 0.1, 0.1, 0.2, 0.2)
>>> output
tensor([[13.3281, 13.3281, 0.0000, 0.0000],
[13.3281, 6.6641, 0.0000, -5.4922]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_broadcast",
"""
torch_npu.npu_broadcast(self, size) -> Tensor
功能描述
返回self张量的新视图,其单维度扩展,结果连续。
张量也可以扩展更多维度,新的维度添加在最前面。
参数说明
self (Tensor) - 输入张量。
size (ListInt) - 对应扩展尺寸。
示例
>>> x = torch.tensor([[1], [2], [3]]).npu()
>>> x.shape
torch.Size([3, 1])
>>> torch_npu.npu_broadcast(x, [3, 4])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_ciou",
"""
torch_npu.npu_ciou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=True, int mode=0, bool atan_sub_flag=False) -> Tensor
功能描述
应用基于NPU的CIoU操作。在DIoU的基础上增加了penalty item,并propose CIoU。
参数说明
boxes1 (Tensor):格式为xywh、shape为(4, n)的预测检测框。
boxes2 (Tensor):相应的gt检测框,shape为(4, n)。
trans (Bool,默认值为False):是否有偏移。
is_cross (Bool,默认值为True):box1和box2之间是否有交叉操作。
mode (Int,默认值为0):选择CIoU的计算方式。0表示IoU,1表示IoF。
atan_sub_flag (Bool,默认值为False):是否将正向的第二个值传递给反向。
输出说明
torch.Tensor:mask操作的结果。
约束说明
到目前为止,CIoU向后只支持当前版本中的trans==True、is_cross==False、mode==0('iou')。如果需要反向传播,确保参数正确。
示例
>>> box1 = torch.randn(4, 32).npu()
>>> box1.requires_grad = True
>>> box2 = torch.randn(4, 32).npu()
>>> box2.requires_grad = True
>>> ciou = torch_npu.npu_ciou(box1, box2, trans=True, is_cross=False, mode=0)
>>> l = ciou.sum()
>>> l.backward()
"""
)
_add_torch_npu_docstr(
"npu_clear_float_status",
"""
torch_npu.npu_clear_float_status(self) -> Tensor
功能描述
在每个核中设置地址0x40000的值为0。
参数说明
self (Tensor) - 数据类型为float32的张量。
示例
>>> x = torch.rand(2).npu()
>>> torch_npu.npu_clear_float_status(x)
tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_confusion_transpose",
"""
torch_npu.npu_confusion_transpose(self, perm, shape, transpose_first) -> Tensor
功能描述
混淆reshape和transpose运算。
参数说明
self (Tensor) - 数据类型:float16、float32、int8、int16、int32、int64、uint8、uint16、uint32、uint64。
perm (ListInt) - self张量的维度排列。
shape (ListInt) - 输入shape。
transpose_first (Bool) - 如果值为True,首先执行transpose,否则先执行reshape。
示例
>>> x = torch.rand(2, 3, 4, 6).npu()
>>> x.shape
torch.Size([2, 3, 4, 6])
>>> y = torch_npu.npu_confusion_transpose(x, (0, 2, 1, 3), (2, 4, 18), True)
>>> y.shape
torch.Size([2, 4, 18])
>>> y2 = torch_npu.npu_confusion_transpose(x, (0, 2, 1), (2, 12, 6), False)
>>> y2.shape
torch.Size([2, 6, 12])
"""
)
_add_torch_npu_docstr(
"npu_conv2d",
"""
torch_npu.npu_conv2d(input, weight, bias, stride, padding, dilation, groups) -> Tensor
功能描述
在由多个输入平面组成的输入图像上应用一个2D卷积。
参数说明
input (Tensor) - shape的输入张量,值为 (minibatch, in_channels, iH, iW)。
weight (Tensor) - shape过滤器,值为 (out_channels, in_channels/groups, kH, kW)。
bias (Tensor, 可选) - shape偏差 (out_channels)。
stride (ListInt) - 卷积核步长。
padding (ListInt) - 输入两侧的隐式填充。
dilation (ListInt) - 内核元素间距。
groups (Int) - 对输入进行分组。In_channels可被组数整除。
"""
)
_add_torch_npu_docstr(
"npu_conv3d",
"""
torch_npu.npu_conv3d(input, weight, bias, stride, padding, dilation, groups) -> Tensor
功能描述
在由多个输入平面组成的输入图像上应用一个3D卷积。
参数说明
input (Tensor) - shape的输入张量,值为 (minibatch, in_channels, iT, iH, iW)。
weight (Tensor) - shape过滤器,值为 (out_channels, in_channels/groups, kT, kH, kW)。
bias (Tensor, 可选) - shape偏差 (out_channels)。
stride (ListInt) - 卷积核步长。
padding (ListInt) - 输入两侧的隐式填充。
dilation (ListInt) - 内核元素间距。
groups (Int) - 对输入进行分组。In_channels可被组数整除。
"""
)
_add_torch_npu_docstr(
"npu_conv_transpose2d",
"""
torch_npu.npu_conv_transpose2d(input, weight, bias, padding, output_padding, stride, dilation, groups) -> Tensor
功能描述
在由多个输入平面组成的输入图像上应用一个2D转置卷积算子,有时这个过程也被称为“反卷积”。
参数说明
input (Tensor) - shape的输入张量,值为 (minibatch, in_channels, iH, iW)。
weight (Tensor) - shape过滤器,值为 (in_channels, out_channels/groups, kH, kW)。
bias (Tensor, 可选) - shape偏差 (out_channels)。
padding (ListInt) - (dilation * (kernel_size - 1) - padding) 用零来填充输入每个维度的两侧。
output_padding (ListInt) - 添加到输出shape每个维度一侧的附加尺寸。
stride (ListInt) - 卷积核步长。
dilation (ListInt) - 内核元素间距。
groups (Int) - 对输入进行分组。In_channels可被组数整除。
"""
)
_add_torch_npu_docstr(
"npu_convolution",
"""
torch_npu.npu_convolution(input, weight, bias, stride, padding, dilation, groups) -> Tensor
功能描述
在由多个输入平面组成的输入图像上应用一个2D或3D卷积。
参数说明
input (Tensor) - shape的输入张量,值为 (minibatch, in_channels, iH, iW) 或 (minibatch, in_channels, iT, iH, iW)。
weight (Tensor) - shape过滤器,值为 (out_channels, in_channels/groups, kH, kW) 或 (out_channels, in_channels/groups, kT, kH, kW)。
bias (Tensor, 可选) - shape偏差 (out_channels)。
stride (ListInt) - 卷积核步长。
padding (ListInt) - 输入两侧的隐式填充。
dilation (ListInt) - 内核元素间距。
groups (Int) - 对输入进行分组。In_channels可被组数整除。
"""
)
_add_torch_npu_docstr(
"npu_convolution_transpose",
"""
torch_npu.npu_convolution_transpose(input, weight, bias, padding, output_padding, stride, dilation, groups) -> Tensor
功能描述
在由多个输入平面组成的输入图像上应用一个2D或3D转置卷积算子,有时这个过程也被称为“反卷积”。
参数说明
input (Tensor) - shape的输入张量,值为 (minibatch, in_channels, iH, iW) 或 (minibatch, in_channels, iT, iH, iW)。
weight (Tensor) - shape过滤器,值为 (in_channels, out_channels/groups, kH, kW) 或 (in_channels, out_channels/groups, kT, kH, kW)。
bias (Tensor, 可选) - shape偏差 (out_channels)。
padding (ListInt) - (dilation * (kernel_size - 1) - padding) 用零来填充输入每个维度的两侧。
output_padding (ListInt) - 添加到输出shape每个维度一侧的附加尺寸。
stride (ListInt) - 卷积核步长。
dilation (ListInt) - 内核元素间距。
groups (Int) - 对输入进行分组。In_channels可被组数整除。
"""
)
_add_torch_npu_docstr(
"npu_deformable_conv2d",
"""
torch_npu.npu_deformable_conv2d(self, weight, offset, bias, kernel_size, stride, padding, dilation=[1,1,1,1], groups=1, deformable_groups=1, modulated=True) -> (Tensor, Tensor)
功能描述
使用预期输入计算变形卷积输出(deformed convolution output)。
参数说明
self (Tensor) - 输入图像的4D张量。格式为“NHWC”,数据按以下顺序存储:[batch, in_height, in_width, in_channels]。
weight (Tensor) - 可学习过滤器的4D张量。数据类型需与self相同。格式为“HWCN”,数据按以下顺序存储:[filter_height, filter_width, in_channels / groups, out_channels]。
offset (Tensor) - x-y坐标偏移和掩码的4D张量。格式为“NHWC”,数据按以下顺序存储:[batch, out_height, out_width, deformable_groups * filter_height * filter_width * 3]。
bias (Tensor,可选) - 过滤器输出附加偏置(additive bias)的1D张量,数据按[out_channels]的顺序存储。
kernel_size (ListInt of length 2) - 内核大小,2个整数的元组/列表。
stride (ListInt) - 4个整数的列表,表示每个输入维度的滑动窗口步长。维度顺序根据self的数据格式解释。N维和C维必须设置为1。
padding (ListInt) - 4个整数的列表,表示要添加到输入每侧(顶部、底部、左侧、右侧)的像素数。
dilations (ListInt,默认值为[1, 1, 1, 1]) - 4个整数的列表,表示输入每个维度的膨胀系数(dilation factor)。维度顺序根据self的数据格式解释。N维和C维必须设置为1。
groups (Int,默认值为1) - int32类型单整数,表示从输入通道到输出通道的阻塞连接数。In_channels和out_channels需都可被“groups”数整除。
deformable_groups (Int,默认值为1) - int32类型单整数,表示可变形组分区的数量。In_channels需可被“deformable_groups”数整除。
modulated (Bool,可选,默认值为True) - 指定DeformableConv2D版本。True表示v2版本, False表示v1版本,目前仅支持v2。
示例
>>> x = torch.rand(16, 32, 32, 32).npu()
>>> weight = torch.rand(32, 32, 5, 5).npu()
>>> offset = torch.rand(16, 75, 32, 32).npu()
>>> output, _ = torch_npu.npu_deformable_conv2d(x, weight, offset, None, kernel_size=[5, 5], stride = [1, 1, 1, 1], padding = [2, 2, 2, 2])
>>> output.shape
torch.Size([16, 32, 32, 32])
"""
)
_add_torch_npu_docstr(
"npu_diou",
"""
torch_npu.npu_diou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor
功能描述
应用基于NPU的DIoU操作。考虑到目标之间距离,以及距离和范围的重叠率,不同目标或边界需趋于稳定。
参数说明
boxes1 (Tensor) - 格式为xywh、shape为(4, n)的预测检测框。
boxes2 (Tensor) - 相应的gt检测框,shape为(4, n)。
trans (Bool,默认值为False) - 是否有偏移。
is_cross (Bool,默认值为False) - box1和box2之间是否有交叉操作。
mode (Int,默认值为0) - 选择DIoU的计算方式。0表示IoU,1表示IoF。
输出说明
torch.Tensor (Tensor) - mask操作的结果。
约束说明
到目前为止,DIoU向后只支持当前版本中的trans==True、is_cross==False、mode==0('iou')。如果需要反向传播,确保参数正确。
示例
>>> box1 = torch.randn(4, 32).npu()
>>> box1.requires_grad = True
>>> box2 = torch.randn(4, 32).npu()
>>> box2.requires_grad = True
>>> diou = torch_npu.contrib.function.npu_diou(box1, box2)
>>> l = diou.sum()
>>> l.backward()
"""
)
_add_torch_npu_docstr(
"npu_dropout_with_add_softmax",
"""
torch_npu.npu_dropout_with_add_softmax(Tensor self, Tensor x1, Scalar alpha, float prob, int dim) -> (Tensor, Tensor, Tensor)
功能描述
实现axpy_v2、softmax_v2、drop_out_domask_v3功能。即:
y=x1+ self *alpha
Softmax(xi)= exp(xi)/∑jexp(xj)
output = 根据mask舍弃x中的元素,留下来的元素乘(1/prob)
参数说明
Tensor self:4维张量,shape为(N, C, H, W)。
Tensor x1:4维张量,shape为(N, C, H, W)。
约束说明
self和x1的shape相同;
H和W是[128, 256, 384, 512]其中之一;
(N * C)%32结果为0;
dim为-1。
示例
self = torch.rand(16, 16, 128, 128).npu()
tensor([[[[7.2556e-02, 3.0909e-01, 7.9734e-01, ..., 6.1179e-01,
6.2624e-03, 8.5186e-01],
[8.9196e-02, 3.3319e-01, 4.0780e-01, ..., 1.9144e-01,
2.2701e-01, 6.4018e-01],
[4.7275e-01, 7.4895e-01, 4.6215e-01, ..., 9.3753e-01,
6.6048e-02, 8.1877e-02],
...,
[7.9366e-01, 5.1516e-01, 5.6594e-01, ..., 1.6457e-01,
1.0640e-01, 3.4322e-03],
[1.5743e-02, 1.2893e-01, 5.8990e-01, ..., 4.1721e-01,
8.7816e-02, 6.8886e-01],
[4.2980e-01, 5.5447e-01, 3.1894e-01, ..., 9.2638e-01,
9.9324e-01, 4.6225e-01]],
[[6.2426e-01, 4.5948e-01, 1.0837e-01, ..., 8.9386e-01,
3.6932e-01, 1.2406e-01],
[9.1823e-01, 6.2311e-01, 5.1474e-01, ..., 2.1042e-01,
6.5943e-01, 3.1797e-01],
[5.2891e-01, 2.0183e-01, 2.1452e-01, ..., 9.1638e-01,
6.4109e-01, 9.4484e-01],
...,
[3.7783e-02, 1.3218e-01, 3.1192e-01, ..., 2.4931e-01,
4.8809e-01, 9.6085e-01],
[3.3197e-01, 9.1186e-02, 2.4839e-01, ..., 2.1156e-03,
6.4952e-01, 8.5996e-01],
[1.7941e-01, 5.1532e-01, 7.8133e-01, ..., 3.5526e-01,
5.3576e-01, 6.0538e-01]],
[[2.6743e-01, 7.4942e-01, 1.9146e-01, ..., 4.9179e-01,
6.3319e-01, 9.9269e-01],
[1.5163e-01, 3.7388e-01, 8.0604e-02, ..., 8.1193e-01,
1.7922e-01, 8.6578e-01],
[8.2558e-01, 9.5139e-01, 2.1313e-01, ..., 2.1722e-01,
2.8402e-01, 8.8888e-01],
...,
[1.8222e-01, 2.7645e-01, 6.7305e-01, ..., 6.8003e-01,
4.0917e-01, 7.6655e-01],
[3.1234e-01, 7.8519e-01, 8.8509e-01, ..., 7.2574e-01,
9.6134e-01, 2.2267e-01],
[4.9233e-01, 8.8407e-01, 7.4390e-01, ..., 5.2253e-02,
5.5150e-02, 4.4108e-02]],
...,
[[4.3370e-01, 2.1176e-01, 4.7512e-01, ..., 5.7611e-01,
3.2619e-01, 1.1523e-01],
[6.1469e-01, 7.4528e-01, 7.9559e-02, ..., 9.7112e-01,
1.8391e-01, 8.9883e-01],
[8.6677e-02, 3.5051e-02, 1.6875e-01, ..., 3.9833e-01,
6.7967e-01, 4.7062e-01],
...,
[7.1648e-01, 1.8378e-01, 5.3054e-01, ..., 8.4282e-01,
9.1972e-01, 7.0031e-01],
[5.9876e-01, 6.7868e-01, 6.4128e-01, ..., 4.9516e-02,
7.2571e-01, 5.8792e-01],
[7.6723e-01, 6.9527e-01, 9.3573e-01, ..., 6.3490e-02,
6.6129e-01, 2.4517e-01]],
[[5.0158e-01, 8.2565e-01, 7.5532e-01, ..., 6.9342e-01,
3.3244e-01, 5.3913e-01],
[2.3347e-01, 9.7822e-02, 1.5009e-01, ..., 5.5090e-01,
9.1813e-01, 7.9857e-01],
[7.2416e-02, 5.9086e-01, 1.2243e-01, ..., 7.8511e-01,
2.4803e-01, 5.3717e-01],
...,
[7.4899e-01, 1.5467e-02, 4.9711e-01, ..., 2.2938e-02,
1.6099e-01, 3.1928e-01],
[3.9111e-01, 1.2422e-01, 6.1795e-02, ..., 8.4212e-01,
6.1346e-01, 1.0957e-01],
[3.6311e-02, 8.9652e-01, 7.7428e-01, ..., 9.2212e-01,
4.9290e-01, 4.5609e-01]],
[[2.2052e-01, 4.4260e-01, 8.8627e-01, ..., 9.2381e-01,
7.7046e-01, 9.2057e-01],
[5.5775e-01, 8.8951e-01, 7.9238e-01, ..., 3.9209e-01,
9.6636e-01, 8.1876e-01],
[3.4709e-01, 7.8678e-01, 1.4396e-01, ..., 7.9073e-01,
3.9021e-01, 8.5285e-01],
...,
[1.4238e-01, 9.8432e-01, 2.7802e-01, ..., 5.1720e-01,
1.6290e-01, 8.2036e-01],
[2.0184e-01, 1.0635e-01, 1.9612e-01, ..., 9.7101e-01,
9.6679e-01, 7.0811e-01],
[5.8240e-01, 3.1642e-01, 9.6549e-01, ..., 5.1130e-02,
5.6725e-01, 3.5238e-01]]]], device='npu:0')
x1 = torch.rand(16, 16, 128, 128).npu()
tensor([[[[2.4353e-01, 8.5665e-01, 5.3571e-01, ..., 5.9101e-01,
4.0872e-01, 6.3873e-01],
[1.4489e-01, 8.7982e-01, 3.3114e-01, ..., 2.5155e-01,
8.4987e-01, 8.7096e-01],
[6.5837e-02, 2.2677e-02, 7.2063e-01, ..., 2.3542e-01,
9.3041e-01, 8.9596e-01],
...,
[5.1450e-01, 7.9412e-01, 8.9288e-01, ..., 3.3639e-01,
5.6086e-01, 4.8770e-02],
[4.7557e-01, 1.4793e-01, 4.9800e-01, ..., 3.9479e-01,
5.6052e-01, 9.8271e-01],
[7.4438e-01, 7.5646e-01, 2.7942e-02, ..., 3.0381e-01,
4.3703e-01, 1.4037e-02]],
[[4.0232e-01, 9.4407e-01, 6.4969e-01, ..., 3.4524e-01,
8.2647e-01, 5.4792e-01],
[1.1801e-01, 1.8281e-01, 6.1723e-01, ..., 1.9393e-01,
4.5877e-01, 8.9990e-01],
[2.6244e-01, 6.9614e-01, 3.6008e-01, ..., 5.0258e-01,
8.1919e-01, 4.6943e-01],
...,
[7.4710e-01, 5.8911e-01, 1.5292e-01, ..., 6.6590e-01,
4.0754e-01, 3.6944e-01],
[9.0501e-01, 2.7943e-01, 3.7068e-01, ..., 1.5053e-01,
7.3413e-01, 7.9626e-01],
[9.5200e-01, 7.8327e-01, 3.4033e-01, ..., 8.0892e-01,
4.0480e-01, 3.8717e-01]],
[[7.5938e-01, 2.9089e-01, 5.9916e-01, ..., 6.2526e-01,
3.9670e-01, 3.3548e-01],
[7.0733e-01, 8.1400e-01, 4.9259e-01, ..., 1.6607e-02,
6.5331e-01, 7.3150e-02],
[5.2770e-01, 7.8141e-01, 4.1904e-01, ..., 3.8917e-01,
4.1405e-01, 9.9596e-01],
...,
[4.8669e-01, 9.9948e-01, 1.2023e-01, ..., 7.0420e-01,
2.8522e-01, 6.6192e-01],
[4.9718e-01, 7.5792e-01, 6.6748e-01, ..., 9.7302e-01,
3.3443e-01, 3.6536e-01],
[7.7033e-01, 6.0550e-01, 8.2024e-01, ..., 2.9711e-01,
1.9410e-01, 6.6304e-01]],
...,
[[1.0284e-01, 6.5712e-01, 6.0831e-01, ..., 6.2622e-01,
2.0355e-01, 9.4250e-01],
[4.9053e-01, 2.0148e-01, 2.4974e-01, ..., 9.2521e-01,
1.9919e-01, 4.4700e-01],
[7.6515e-01, 8.7755e-01, 1.3500e-01, ..., 8.2136e-01,
2.0848e-01, 5.6432e-01],
...,
[3.3618e-01, 1.8585e-01, 5.3475e-01, ..., 4.9333e-01,
9.1018e-01, 9.5052e-01],
[2.1400e-01, 1.7407e-01, 5.8925e-01, ..., 7.5722e-01,
2.9850e-01, 3.9298e-01],
[6.3625e-01, 1.7168e-01, 2.9183e-01, ..., 9.9674e-01,
2.1718e-01, 5.2626e-01]],
[[1.8651e-01, 2.5385e-01, 2.0384e-01, ..., 3.4462e-01,
8.4150e-01, 4.7431e-01],
[2.4992e-01, 1.1788e-01, 1.9730e-01, ..., 4.3722e-02,
7.8943e-01, 9.9097e-01],
[1.4493e-02, 6.4856e-01, 8.3344e-01, ..., 8.6623e-01,
1.5456e-01, 7.8423e-01],
...,
[6.1458e-01, 4.4260e-01, 7.4133e-01, ..., 2.5126e-01,
2.7251e-01, 6.9784e-01],
[2.2419e-01, 3.4159e-01, 2.3232e-01, ..., 8.2850e-01,
8.2644e-02, 4.8390e-01],
[1.0171e-01, 8.7662e-01, 2.0457e-01, ..., 7.6868e-01,
7.6592e-01, 3.1254e-01]],
[[1.8866e-01, 1.5755e-01, 3.1025e-02, ..., 6.5044e-01,
7.8293e-01, 9.8030e-01],
[3.7703e-01, 5.3198e-01, 1.8633e-01, ..., 4.7398e-01,
8.3618e-01, 8.7283e-01],
[5.7119e-01, 4.3620e-01, 8.2536e-01, ..., 2.5390e-01,
5.6144e-01, 4.4044e-01],
...,
[1.3243e-01, 6.2002e-02, 7.5278e-01, ..., 7.5907e-01,
4.2472e-01, 1.7624e-01],
[4.7985e-01, 7.9769e-01, 8.1433e-01, ..., 7.3780e-01,
2.2877e-02, 4.8816e-01],
[4.5100e-01, 9.9698e-02, 7.0776e-01, ..., 9.8046e-01,
2.2372e-01, 8.6304e-01]]]], device='npu:0')
_, _, out = torch_npu.npu_dropout_with_add_softmax(self, x1, 2, 0.9, -1)
tensor([[[[0.0000, 0.0639, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0632, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0794, ..., 0.0000, 0.0000, 0.1571],
[0.0000, 0.0000, 0.0000, ..., 0.1270, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.1030, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.2134, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0342, 0.0000, 0.0633, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.1578, 0.1334, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.2316, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0237, 0.0000, ..., 0.0000, 0.2128, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.1421, 0.0000, 0.0000, ..., 0.0499, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0218, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1461, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.1130, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.1976, ..., 0.0000, 0.0000, 0.0000]]]],
device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_dtype_cast",
"""
torch_npu.npu_dtype_cast(input, dtype, input_dtype=None) -> Tensor
功能描述
执行张量数据类型(dtype)转换。支持FakeTensor模式。
参数说明
input (Tensor) - 输入张量。
dtype (int) - 返回张量的目标数据类型。
input_dtype (int) - 输入张量的数据类型, 默认值为None。为None时,采用输入张量原本的数据类型。
示例
示例一:
>>> torch_npu.npu_dtype_cast(torch.tensor([0, 0.5, -1.]).npu(), dtype=torch.int)
tensor([ 0, 0, -1], device='npu:0', dtype=torch.int32)
示例二:
//FakeTensor模式
>>> from torch._subclasses.fake_tensor import FakeTensorMode
>>> with FakeTensorMode():
... x = torch.rand(2, dtype=torch.float32).npu()
... res = torch_npu.npu_dtype_cast(x, torch.float16)
...
>>> res
FakeTensor(..., device='npu:0', size=(2,), dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_format_cast",
"""
torch_npu.npu_format_cast(self, acl_format) -> Tensor
功能描述
修改NPU张量的格式。
参数说明
self (Tensor) - 输入张量。
acl_format (Int) - 目标格式。
示例
>>> x = torch.rand(2, 3, 4, 5).npu()
>>> torch_npu.get_npu_format(x)
0
>>> x1 = torch_npu.npu_format_cast(x, 29)
>>> torch_npu.get_npu_format(x1)
29
"""
)
_add_torch_npu_docstr(
"npu_format_cast_",
"""
torch_npu.npu_format_cast_(self, src) -> Tensor
功能描述
原地修改self张量格式,与src格式保持一致。
参数说明
self (Tensor) - 输入张量。
src (Tensor,int) - 目标格式。
示例
>>> x = torch.rand(2, 3, 4, 5).npu()
>>> torch_npu.get_npu_format(x)
0
>>> torch_npu.get_npu_format(torch_npu.npu_format_cast_(x, 29))
29
"""
)
_add_torch_npu_docstr(
"npu_fused_attention_score",
"""
torch_npu.npu_fused_attention_score(Tensor query_layer, Tensor key_layer, Tensor value_layer, Tensor attention_mask, Scalar scale, float keep_prob, bool query_transpose=False, bool key_transpose=False, bool bmm_score_transpose_a=False, bool bmm_score_transpose_b=False, bool value_transpose=False, bool dx_transpose=False) -> Tensor
功能描述
实现“Transformer attention score”的融合计算逻辑,主要将matmul、transpose、add、softmax、dropout、batchmatmul、permute等计算进行了融合。
参数说明
query_layer:Tensor类型,仅支持float16。
key_layer:Tensor类型,仅支持float16。
value_layer:Tensor类型,仅支持float16 。
attention_mask:Tensor类型,仅支持float16 。
scale:缩放系数,浮点数标量 。
keep_prob:不做dropout的概率,0-1之间,浮点数。
query_transpose:query是否做转置,bool类型,默认为False 。
key_transpose:key是否做转置,bool类型,默认为False 。
bmm_score_transpose_a:bmm计算中左矩阵是否做转置,bool类型,默认为False。
bmm_score_transpose_b:bmm计算中右矩阵是否做转置,bool类型,默认为False。
value_transpose:value是否做转置,bool类型,默认为False。
dx_transpose:反向计算时dx是否做转置,bool类型,默认为False。
约束说明
输入tensor的格式编号必须均为29,数据类型为FP16。
支持的型号:
Atlas 训练系列产品
示例
>>> import torch
>>> import torch_npu
>>> query_layer = torch_npu.npu_format_cast(torch.rand(24, 16, 512, 64).npu(), 29).half()
>>> key_layer = torch_npu.npu_format_cast(torch.rand(24, 16, 512, 64).npu(), 29).half()
>>> value_layer = torch_npu.npu_format_cast(torch.rand(24, 16, 512, 64).npu(), 29).half()
>>> attention_mask = torch_npu.npu_format_cast(torch.rand(24, 16, 512, 512).npu(), 29).half()
>>> scale = 0.125
>>> keep_prob = 0.5
>>> context_layer = torch_npu.npu_fused_attention_score(query_layer, key_layer, value_layer, attention_mask, scale, keep_prob)
>>> print(context_layer)
tensor([[0.5010, 0.4709, 0.4841, ..., 0.4321, 0.4448, 0.4834],
[0.5107, 0.5049, 0.5239, ..., 0.4436, 0.4375, 0.4651],
[0.5308, 0.4944, 0.5005, ..., 0.5010, 0.5103, 0.5303],
...,
[0.5142, 0.5068, 0.5176, ..., 0.5498, 0.4868, 0.4805],
[0.4941, 0.4731, 0.4863, ..., 0.5161, 0.5239, 0.5190],
[0.5459, 0.5107, 0.5415, ..., 0.4641, 0.4688, 0.4531]],
device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_lightning_indexer",
"""
功能实现描述
LightningIndexer基于一系列操作得到每一个 token 对应的 Top-k 个位置。
函数原型
custom.npu_lightning_indexer(query, key, weights, *, actual_seq_lengths_query=None, actual_seq_lengths_key=None, block_table=None, layout_query='BSND', layout_key='BSND', sparse_count=2048, sparse_mode=3) -> Tensor
参数说明
key(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16,layout_key为PA_BSND时shape为[block_count, block_size, N2, D],其中block_count为PageAttention时block总数,block_size为一个block的token数。
weights(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16,支持输入shape[B,S1,N1]、[T,N1]。
*:代表其之前的参数是位置相关的,必须按照顺序输入,属于必选参数;其之后的参数是键值对赋值,与位置无关,属于可选参数(不传入会使用默认值)。
actual_seq_lengths_query(Tensor):可选参数,表示不同Batch中query的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和query的shape的S长度相同。
该入参中每个Batch的有效token数不超过query中的维度S大小。支持长度为B的一维tensor。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须>=前一个元素的值。不能出现负值。
actual_seq_lengths_key(Tensor):可选参数,表示不同Batch中key的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和key的shape的S长度相同。支持长度为B的一维tensor。
block_table(Tensor):可选参数,表示PageAttention中KV存储使用的block映射表,数据格式支持ND,数据类型支持int32。
PageAttention场景下,block_table必须为二维,第一维长度需要等于B,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actual_seq_lengths_key对应的block数量)
layout_query(str):可选参数,用于标识输入query的数据排布格式,当前支持BSND、TND,默认值"BSND"。
layout_key(str):可选参数,用于标识输入key的数据排布格式,当前支持PA_BSND、BSND、TND,默认值"BSND",在非PageAttention场景下,该参数值应与layout_query值保持一致。
sparse_count(int):可选参数,代表topK阶段需要保留的block数量,支持1-2048以及3072、4096、5120、6144、7168、8192,数据类型支持int32。
sparse_mode(int):可选参数,表示sparse的模式,支持0/3,数据类型支持int32。
sparse_mode为0时,代表defaultMask模式。
sparse_mode为3时,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。
out(Tensor):公式中的输出,数据类型支持int32。数据格式支持ND。
说明:
query、key、weights参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
S1表示query shape中的S,S2表示key shape中的S,N1表示query shape中的N,N2表示key shape中的N。
query(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16。
"""
)
_add_torch_npu_docstr(
"npu_quant_lightning_indexer",
"""
功能实现描述
QuantLightningIndexer在LightningIndexer的基础上支持了Per-Token-Head量化输入。
接口原型
custom.npu_quant_lightning_indexer(query, key, weights, query_dequant_scale, key_dequant_scale, query_quant_mode, key_quant_mode, *, actual_seq_lengths_query=None, actual_seq_lengths_key=None, block_table=None, layout_query='BSND', layout_key='BSND', sparse_count=2048, sparse_mode=3, pre_tokens=2^63-1, next_tokens=2^63-1) -> Tensor
参数说明
query(Tensor):必选参数,不支持非连续,数据格式支持ND,Atlas A3 推理系列产品数据类型支持int8,Ascend 950PR/Ascend 950DT数据类型支持float8_e4m3fn, hifloat8。layout_query为BSND时shape为[B,S1,N1,D],当layout_query为TND时shape为[T1,N1,D],N1支持小于等于64。
key(Tensor):必选参数,不支持非连续,数据格式支持ND,Atlas A3 推理系列产品数据类型支持int8,Ascend 950PR/Ascend 950DT数据类型支持float8_e4m3fn,hifloat8。layout_key为PA_BSND时shape为[block_count, block_size, N2, D],其中block_count为PageAttention时block总数,block_size为一个block的token数。
weights(Tensor):必选参数,不支持非连续,数据格式支持ND,Atlas A3 推理系列产品数据类型支持float16,Ascend 950PR/Ascend 950DT数据类型支持float16,bfloat16。支持输入shape[B,S1,N1]、[T,N1]。
query_dequant_scale(Tensor):必选参数,不支持非连续,数据格式支持ND,Atlas A3 推理系列产品数据类型支持float16,Ascend 950PR/Ascend 950DT数据类型支持float16,float32,支持输入shape[B,S1,N1]、[T,N1]。
key_dequant_scale(Tensor):必选参数,不支持非连续,数据格式支持ND,Atlas A3 推理系列产品数据类型支持float16,Ascend 950PR/Ascend 950DT数据类型支持float16,float32,layout_key为PA_BSND时shape为[block_count, block_size, N2],其中block_count为PageAttention时block总数,block_size为一个block的token数。
*:代表其之前的参数是位置相关的,必须按照顺序输入,属于必选参数;其之后的参数是键值对赋值,与位置无关,属于可选参数(不传入会使用默认值)。
actual_seq_lengths_query(Tensor):可选参数,表示不同Batch中query的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和query的shape的S长度相同。
该入参中每个Batch的有效token数不超过query中的维度S大小。支持长度为B的一维tensor。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须>=前一个元素的值。不能出现负值。
actual_seq_lengths_key(Tensor):可选参数,表示不同Batch中key的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和key的shape的S长度相同。支持长度为B的一维tensor。
block_table(Tensor):可选参数,表示PageAttention中KV存储使用的block映射表,数据格式支持ND,数据类型支持int32。
PageAttention场景下,block_table必须为二维,第一维长度需要等于B,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actual_seq_lengths_key对应的block数量)
query_quant_mode(int):可选参数,用于标识输入query的量化模式,当前支持Per-Token-Head量化模式,默认值0。
key_quant_mode(int):可选参数,用于标识输入key的量化模式,当前支持Per-Token-Head量化模式,默认值0。
layout_query(str):可选参数,用于标识输入query的数据排布格式,当前支持BSND、TND,默认值"BSND"。
layout_key(str):可选参数,用于标识输入key的数据排布格式,当前支持PA_BSND、BSND、TND,默认值"BSND"。在非PageAttention场景下,layout_key应与layout_query保持一致。
sparse_count(int):可选参数,代表topK阶段需要保留的block数量,支持1-2048,数据类型支持int32。
sparse_mode(int):可选参数,表示sparse的模式,支持0/3,数据类型支持int32。
sparse_mode为0时,代表defaultMask模式。
sparse_mode为3时,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。
out(Tensor):公式中的输出,数据类型支持int32。数据格式支持ND。
说明:
query、key、weights、query_dequant_scale、key_dequant_scale参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
S1表示query shape中的S,S2表示key shape中的S,N1表示query shape中的N,N2表示key shape中的N。
"""
)
_add_torch_npu_docstr(
"npu_sparse_flash_attention",
"""
功能实现描述
随着大模型上下文长度的增加,Sparse Attention的重要性与日俱增,这一技术通过“只计算关键部分”大幅减少计算量,然而会引入大量的离散访存,造成数据搬运时间增加,进而影响整体性能。
接口原型
custom.npu_sparse_flash_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_indices, double scale_value, int sparse_block_size, *, Tensor? block_table=None, Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_kv=None, Tensor? query_rope=None, Tensor? key_rope=None, str layout_query='BSND', str layout_kv='BSND', int sparse_mode=3) -> Tensor
参数说明
key(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16,layout_kv为PA_BSND时shape为[block_num, block_size, KV_N, D],其中block_num为PageAttention时block总数,block_size为一个block的token数。
value(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16。
sparse_indices(Tensor):必选参数,代表离散取kvCache的索引,不支持非连续,数据格式支持ND,数据类型支持int32。当query的layout为BSND时,shape需要传入[B, Q_S, KV_N, sparse_size],当query的layout为TND时,shape需要传入[Q_T, KV_N, sparse_size],其中sparse_size为一次离散选取的token数,需要保证每行有效值均在前半部分,无效值均在后半部分。
scale_value(double):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持double。
sparse_block_size(int):必选参数,代表sparse阶段的block大小,在计算importance score时使用,数据类型支持int64。
*:代表其之前的参数是位置相关的,必须按照顺序输入,属于必选参数;其之后的参数是键值对赋值,与位置无关,属于可选参数(不传入会使用默认值)。
block_table(Tensor):可选参数,表示PageAttention中kvCache存储使用的block映射表。数据格式支持ND,数据类型支持int32,shape为2维,其中第一维长度为B,第二维长度不小于所有batch中最大的s2对应的block数量,即s2_max / block_size向上取整。
actual_seq_lengths_query(Tensor):可选参数,表示不同Batch中query的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和query的shape的S长度相同。
该入参中每个Batch的有效token数不超过query中的维度S大小。支持长度为B的一维tensor。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须>=前一个元素的值。不能出现负值。
actual_seq_lengths_kv(Tensor):可选参数,表示不同Batch中key和value的有效token数,数据类型支持int32。如果不指定None,表示和key的shape的S长度相同。
该入参中每个Batch的有效token数不超过key/value中的维度S大小且不小于0。支持长度为B的一维tensor。
query_rope(Tensor):可选参数,表示MLA结构中的query的rope信息,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16。
key_rope(Tensor):可选参数,表示MLA结构中的key的rope信息,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16。
layout_query(str):可选参数,用于标识输入query的数据排布格式,用户不特意指定时可传入默认值"BSND",支持传入BSND和TND。
说明: 1、query数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示hidden层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
layout_kv(str):可选参数,用于标识输入key的数据排布格式,用户不特意指定时可传入默认值"BSND",支持传入TND、BSND和PA_BSND,其中PA_BSND在使能PageAttention时使用。
sparse_mode(int):可选参数,表示sparse的模式。数据类型支持int64。
sparse_mode为0时,代表全部计算。
sparse_mode为3时,代表rightDownCausal模式的mask,对应以右下顶点往左上为划分线的下三角场景。
说明:
query、key、value参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
Q_S和S1表示query shape中的S,KV_S和S2表示key shape中的S,Q_N表示num_query_heads,KV_N表示num_key_value_heads。
query(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16和float16。
"""
)
_add_torch_npu_docstr(
"npu_kv_quant_sparse_flash_attention",
"""
功能实现描述
QuantSparseFlashAttentionAnti在SparseFlashAttention的基础上支持了Per-Token-Head-Tile-128量化输入。
接口原型
custom.npu_kv_quant_sparse_flash_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_indices, double scale_value, int sparse_block_size, int key_quant_mode, int value_quant_mode, *, Tensor? key_dequant_scale=None, Tensor? value_dequant_scale=None, Tensor? block_table=None, Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_kv=None, str layout_query='BSND', str layout_kv='BSND', int sparse_mode=3, int attention_mode=0, int quant_scale_repo_mode=0, int tile_size=0, int rope_head_dim=0) -> Tensor
query(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持bfloat16,query相同dtype的q_nope和q_rope按D维度拼接得到。
key(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持int8,int8的k_nope、query相同dtype的k_rope和float32的量化参数按D维度拼接得到,layout_kv为PA_BSND时shape为[block_num, block_size, KV_N, D],其中block_num为PageAttention时block总数,block_size为一个block的token数。
value(Tensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持int8。
sparse_indices(Tensor):必选参数,代表离散取kvCache的索引,不支持非连续,数据格式支持ND,数据类型支持int32,shape需要传入[B, Q_S, KV_N, sparse_size],其中sparse_size为一次离散选取的token数,需要保证每行有效值均在前半部分,无效值均在后半部分。
scale_value(double):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持double。
sparse_block_size(int):必选参数,代表sparse阶段的block大小,在计算importance score时使用,数据类型支持int64。
key_quant_mode(int):必选参数,代表key的量化模式,数据类型支持int64,支持传入2,代表per_tile量化模式。
value_quant_mode(int):必选参数,代表value的量化模式,数据类型支持int64,支持传入2,代表per_tile量化模式。
*:代表其之前的参数是位置相关的,必须按照顺序输入,属于必选参数;其之后的参数是键值对赋值,与位置无关,属于可选参数(不传入会使用默认值)。
key_dequant_scale(Tensor):可选参数,预留参数,当前不支持。
value_dequant_scale(Tensor):可选参数,预留参数,当前不支持。
block_table(Tensor):可选参数,表示PageAttention中kvCache存储使用的block映射表。数据格式支持ND,数据类型支持int32,shape为2维,其中第一维长度为B,第二维长度不小于所有batch中最大的s2对应的block数量,即s2_max / block_size向上取整。
actual_seq_lengths_query(Tensor):可选参数,表示不同Batch中query的有效token数,数据类型支持int32。如果不指定seqlen可传入None,表示和query的shape的S长度相同。
该入参中每个Batch的有效token数不超过query中的维度S大小。支持长度为B的一维tensor。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须>=前一个元素的值。不能出现负值。
actual_seq_lengths_kv(Tensor):可选参数,表示不同Batch中key和value的有效token数,数据类型支持int32。如果不指定None,表示和key的shape的S长度相同。
该入参中每个Batch的有效token数不超过key/value中的维度S大小且不小于0。支持长度为B的一维tensor。
layout_query(str):可选参数,用于标识输入query的数据排布格式,用户不特意指定时可传入默认值"BSND",支持传入BSND和TND。
说明: 1、query数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示hidden层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
layout_kv(str):可选参数,用于标识输入key的数据排布格式,用户不特意指定时可传入默认值"BSND",支持传入PA_BSND,PA_BSND在使能PageAttention时使用。
sparse_mode(int):可选参数,表示sparse的模式。数据类型支持int64。
sparse_mode为0时,代表全部计算。
sparse_mode为3时,代表rightDownCausal模式的mask,对应以右下顶点往左上为划分线的下三角场景。
attention_mode(int):可选参数,表示attention的模式。数据类型支持int64,支持传入2,表示MLA-absorb模式,即QK的D包含rope和nope两部分,且KV是同一份,默认值为0。
quant_scale_repo_mode(int):可选参数,表示量化参数的存放模式。数据类型支持int64,支持传入1,表示combine模式,即量化参数和数据混合存放,默认值为0。
tile_size(int):可选参数,表示per_tile时每个参数对应的数据块大小,仅在per_tile时有效。数据类型支持int64,默认值为0。
rope_head_dim(int):可选参数,表示MLA架构下的rope head dim大小,仅在attention_mode为2时有效。数据类型支持int64,默认值为0。
out(Tensor):公式中的输出。数据格式支持ND,数据类型支持bfloat16。
说明:
query、key、value参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
Q_S和S1表示query shape中的S,KV_S和S2表示key shape中的S,Q_N表示num_query_heads,KV_N表示num_key_value_heads。
"""
)
_add_torch_npu_docstr(
"npu_fusion_attention",
"""
功能描述实现
“Transformer Attention Score”的融合计算, 实现的计算公式如下:
$y=Softmax(Mask(scale*(pse+query*key^{T}),atten_mask),keep_prob)$
$attention=Dropout(y)*value$
接口原型
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False, str softmax_layout="") -> (Tensor, Tensor, Tensor, Tensor, int, int, int)
参数说明
query: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
key: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
value: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
head_num: int类型, 代表head个数, 数据类型支持int64. 综合约束请见约束说明.
input_layout: string类型, 代表输入query、key、value的数据排布格式, 支持BSH、SBH、BSND、BNSD、TND(actual_seq_qlen/actual_seq_kvlen需传值); 后续章节如无特殊说明, S表示query或key、value的sequence length, Sq表示query的sequence length, Skv表示key、value的sequence length, SS表示Sq*Skv.
pse: Tensor类型, 可选参数, 表示位置编码. 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 非varlen场景支持四维输入, 包含BNSS格式、BN1Skv格式、1NSS格式. 如果非varlen场景Sq大于1024或varlen场景、每个batch的Sq与Skv等长且是sparse_mode为0、2、3的下三角掩码场景, 可使能alibi位置编码压缩, 此时只需要输入原始PSE最后1024行进行内存优化, 即alibi_compress = ori_pse[:, :, -1024:, :], 参数每个batch不相同时, 输入BNHSkv(H=1024), 每个batch相同时, 输入1NHSkv(H=1024).
padding_mask: Tensor类型, 暂不支持该传参.
atten_mask: Tensor类型, 可选参数, 取值为1代表该位不参与计算(不生效), 为0代表该位参与计算, 数据类型支持bool、uint8, 数据格式支持ND, 输入shape类型支持BNSS格式、B1SS格式、11SS格式、SS格式. varlen场景只支持SS格式, SS分别是maxSq和maxSkv. 综合约束请见约束说明.
scale: 浮点型, 可选参数, 代表缩放系数, 作为计算流中Muls的scalar值, 数据类型支持float, 默认值为1.
keep_prob: 浮点型, 可选参数, 代表Dropout中1的比例, 取值范围为(0, 1]. 数据类型支持float, 默认值为1, 表示全部保留.
pre_tockens: 整型, 用于稀疏计算的参数, 可选参数, 数据类型支持int64, 默认值为2147483647. 综合约束请见约束说明.
next_tockens: 整型, 用于稀疏计算的参数, 可选参数, 数据类型支持int64, 默认值为2147483647. next_tockens和pre_tockens取值与atten_mask的关系请参见sparse_mode参数, 参数取值与atten_mask分布不一致会导致精度问题. 综合约束请见约束说明.
inner_precise: 整型, 用于提升精度, 数据类型支持int64, 默认值为0.
当前0、1为保留配置值, 2为使能无效行计算, 其功能是避免在计算过程中存在整行mask进而导致精度有损失, 但是该配置会导致性能下降.
如果算子可判断出存在无效行场景, 会自动使能无效行计算, 例如sparse_mode为3, Sq > Skv场景.
prefix: int类型数组, 可选参数, 代表prefix稀疏计算场景每个Batch的N值. 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
actual_seq_qlen: int类型数组, 可选参数, varlen场景时需要传入此参数. 表示query每个S的累加和长度, 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
比如真正的S长度列表为: 2 2 2 2 2, 则actual_seq_qlen传: 2 4 6 8 10.
actual_seq_kvlen: int类型数组, 可选参数, varlen场景时需要传入此参数. 表示key/value每个S的累加和长度. 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
比如真正的S长度列表为: 2 2 2 2 2, 则actual_seq_kvlen传: 2 4 6 8 10.
sparse_mode: 整型, 表示sparse的模式, 可选参数. 数据类型支持int64, 默认值为0, 支持配置值为0、1、2、3、4、5、6、7、8. 当整网的atten_mask都相同且shape小于2048*2048时, 建议使用defaultMask模式, 来减少内存使用量. 综合约束请见约束说明.
softmax_layout: string类型,可选参数,用于控制TND场景下softmax的输出(softmax_max和softmax_sum)的数据排布方式。当前仅在input\_layout=“TND”时进行配置,仅支持传入“TND”。默认情况下,softmax的输出排布为NTD排布;传入TND时,softmax的输出排布为TND排布。
表1 sparse_mode不同取值场景说明
sparse_mode
0: defaultMask模式.
1: allMask模式.
2: leftUpCausal模式.
3: rightDownCausal模式.
4: band模式.
5: prefix非压缩模式. varlen场景不支持.
6: prefix压缩模式.
7: varlen外切场景, rightDownCausal模式. 仅varlen场景支持.
8: varlen外切场景, leftUpCausal模式. 仅varlen场景支持.
atten_mask的工作原理为, 在Mask为True的位置遮蔽query(Q)与key(K)的转置矩阵乘积的值. QKT矩阵在atten_mask为True的位置会被遮蔽
说明: 保留该值, atten_mask中, 应该配置为False; 遮蔽该值, atten_mask中应配置为True. sparse_mode为0时, 代表defaultMask模式. 不传mask: 如果atten_mask未传入则不做mask操作, atten_mask取值为None, 忽略pre_tockens和next_tockens取值.
next_tockens取值为0, pre_tockens大于等于Sq, 表示causal场景sparse, atten_mask应传入下三角矩阵, 此时pre_tockens和next_tockens之间的部分需要计算,atten_mask应传入下三角矩阵
pre_tockens小于Sq, next_tockens小于Skv, 且都大于等于0, 表示band场景, 此时pre_tockens和next_tockens之间的部分需要计算. atten_mask应传入band形状矩阵
next_tockens为负数, 以pre_tockens=9, next_tockens=-3为例, pre_tockens和next_tockens之间的部分需要计算. 说明: next_tockens为负数时, pre_tockens取值必须大于等于next_tockens的绝对值, 且next_tockens的绝对值小于Skv.
pre_tockens为负数, 以next_tockens=7, pre_tockens=-3为例, pre_tockens和next_tockens之间的部分需要计算. 说明: pre_tockens为负数时, next_tockens取值必须大于等于pre_tockens的绝对值, 且pre_tockens的绝对值小于Sq.
sparse_mode为1时, 代表allMask, 即传入完整的atten_mask矩阵. 该场景下忽略next_tockens、pre_tockens取值
sparse_mode为2时, 代表leftUpCausal模式的mask, 对应以左上顶点划分的下三角场景(参数起点为左上角). 该场景下忽略pre_tockens、next_tockens取值.传入的atten_mask为优化后的压缩下三角矩阵(2048*2048)
sparse_mode为3时, 代表rightDownCausal模式的mask, 对应以右下顶点划分的下三角场景(参数起点为右下角). 该场景下忽略pre_tockens、next_tockens取值. atten_mask为优化后的压缩下三角矩阵(2048*2048)
sparse_mode为4时, 代表band场景, 即计算pre_tockens和next_tockens之间的部分, 参数起点为右下角, pre_tockens和next_tockens之间需要有交集. atten_mask为优化后的压缩下三角矩阵(2048*2048).
sparse_mode为5时, 代表prefix非压缩场景, 即在rightDownCasual的基础上, 左侧加上一个长为Sq, 宽为N的矩阵, N的值由可选输入prefix获取, 例如下图中表示batch=2场景下prefix传入数组[4,5], 每个batch轴的N值可以不一样, 参数起点为左上角. 该场景下忽略pre_tockens、next_tockens取值, atten_mask矩阵数据格式须为BNSS或B1SS
sparse_mode为6时, 代表prefix压缩场景, 即prefix场景时, attenMask为优化后的压缩下三角+矩形的矩阵(3072*2048): 其中上半部分[2048, 2048]的下三角矩阵, 下半部分为[1024,2048]的矩形矩阵, 矩形矩阵左半部分全0, 右半部分全1. 该场景下忽略pre_tockens、next_tockens取值.
sparse_mode为7时, 表示varlen且为长序列外切场景(即长序列在模型脚本中进行多卡切query的sequence length); 用户需要确保外切前为使用sparse_mode 3的场景; 当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点), 且需要保证参数正确, 否则会存在精度问题. Masked QKT矩阵示意如下, 在第二个batch对query进行切分, key和value不切分, 4x6的mask矩阵被切分成2x6和2x6的mask, 分别在卡1和卡2上计算: 卡1的最后一块mask为band类型的mask, 配置pre_tockens=6(保证大于等于最后一个Skv), next_tockens=-2, actual_seq_qlen应传入{3,5}, actual_seq_kvlen应传入{3,9}. 卡2的mask类型切分后不变, sparse_mode为3, actual_seq_qlen应传入{2,7,11}, actual_seq_kvlen应传入{6,11,15}.
如果配置sparse_mode=7, 但实际只存在一个batch, 用户需按照band模式的要求来配置参数; sparse_mode=7时, 用户需要输入2048x2048的下三角mask作为该融合算子的输入.
基于sparse_mode=3进行外切产生的band模式的sparse的参数应符合以下条件:
pre_tockens >= last_Skv.
next_tockens <= 0.
当前模式下不支持可选输入pse.
sparse_mode为8时, 表示varlen且为长序列外切场景; 用户需要确保外切前为使用sparse_mode 2的场景; 当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点), 且需要保证参数正确, 否则会存在精度问题. Masked QKT矩阵示意如下, 在第二个batch对query进行切分, key和value不切分, 5x4的mask矩阵被切分成2x4和3x4的mask, 分别在卡1和卡2上计算: 卡1的mask类型切分后不变, sparse_mode为2, actual_seq_qlen应传入{3,5}, actual_seq_kvlen应传入{3,7}. 卡2的第一块mask为band类型的mask, 配置pre_tockens=4(保证大于等于第一个Skv), next_tockens=1, actual_seq_qlen应传入{3,8,12}, actual_seq_kvlen应传入{4,9,13}.
如果配置sparse_mode=8, 但实际只存在一个batch, 用户需按照band模式的要求来配置参数; sparse_mode=8时, 用户需要输入2048x2048的下三角mask作为该融合算子的输入.
基于sparse_mode=2进行外切产生的band模式的sparse的参数应符合以下条件:
pre_tockens >= first_Skv.
next_tockens范围无约束, 根据实际情况进行配置.
当前模式下不支持可选输入pse.
gen_mask_parallel: 布尔型, DSA生成dropout随机数向量mask的控制开关. 默认值为True: 同AI Core计算并行; 设为False: 同AI Core计算串行.
sync: 布尔型, DSA生成dropout随机数向量mask的控制开关. 默认值为False: dropout mask异步生成; 设为True: dropout mask同步生成.
输出说明
共7个输出, 类型依次为Tensor、Tensor、Tensor、Tensor、int、int、int.
第1个输出为Tensor, 计算公式的最终输出attention_out, 数据类型支持float16、bfloat16、float32.
第2个输出为Tensor, Softmax计算的Max中间结果, 用于反向计算, 数据类型支持float.
第3个输出为Tensor, Softmax计算的Sum中间结果, 用于反向计算, 数据类型支持float.
第4个输出为Tensor, 预留参数, 暂未使用.
第5个输出为int, DSA生成dropoutmask中, Philox算法的seed.
第6个输出为int, DSA生成dropoutmask中, Philox算法的offset.
第7个输出为int, DSA生成dropoutmask的长度.
约束说明
该接口仅在训练场景下使用.
该接口暂不支持图模式, 不支持aclgraph.
输入query、key、value、pse的数据类型必须一致.
输入query、key、value的input_layout必须一致.
输入query、key、value的shape说明:
1. 输入key和value的shape必须一致.
2. B: batchsize必须相等; 非varlen场景B取值范围1~2M; varlen场景B取值范围1~2K.
3. D: Head Dim必须满足Dq=Dk和Dk≥Dv,取值范围1~768.
4. S: sequence length, 取值范围1~1M.
varlen场景下:
1. 要求T(B*S)取值范围1~1M.
2. atten_mask输入不支持补pad,即atten_mask中不能存在某一行全1的场景.
支持输入query的N和key/value的N不相等, 但必须成比例关系, 即Nq/Nkv必须是非0整数, Nq取值范围1~256. 当Nq/Nkv > 1时, 即为GQA\(grouped-query attention); 当Nq/Nkv=1时,即为MHA(multi-head attention). 本文如无特殊说明, N表示的是Nq.
输入key/value的shape必须一致.
sparse_mode取值说明:
1. sparse_mode为1、2、3、4、5、6、7、8时, 应传入对应正确的atten_mask, 否则将导致计算结果错误. 当atten_mask输入为None时, sparse_mode, pre_tockens, next_tockens参数不生效, 固定为全计算.
2. sparse_mode配置为1、2、3、5、6时, 用户配置的pre_tockens、next_tockens不会生效.
3. sparse_mode配置为0、4时, 需保证atten_mask与pre_tockens、next_tockens的范围一致.
4. sparse_mode配置为7、8时,不支持可选参数pse.
prefix稀疏计算场景B不大于32, varlen场景不支持非压缩prefix, 即不支持sparse_mode=5; 当Sq>Skv时, prefix的N值取值范围[0, Skv], 当Sq<=Skv时, prefix的N值取值范围[Skv-Sq, Skv].
支持actual_seq_qlen中某个Batch上的S长度为0; 如果存在S为0的情况, 不支持pse输入, 假设真实的S长度为[2, 2, 0, 2, 2], 则传入的actual_seq_qlen为[2, 4, 4, 6, 8]. actual_seq_qlen的长度取值范围为1~2K, varlen场景下长度最大支持1K.
TND格式下, 支持尾部部分Batch不参与计算, 此时actual_seq_qlen和actual_seq_kvlen尾部传入对应个数个0即可. 假设真实的S长度为[2, 3, 4, 5, 6], 此时后两个Batch不参与计算, 则传入的actual_seq_qlen为[2, 5, 9, 0, 0].
部分场景下, 如果计算量过大可能会导致算子执行超时(aicore error类型报错, errorStr为: timeout or trap error), 此时建议做轴切分处理, 注: 这里的计算量会受B、S、N、D等参数的影响, 值越大计算量越大.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品
调用示例
单算子模式调用:
import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUFlashAttention(TestCase):
def supported_op_exec(self, query, key, value, atten_mask):
scale = 0.08838
qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
qk = qk + atten_mask * (-10000.0)
softmax_res = torch.nn.functional.softmax(qk, dim=-1)
attention_out = torch.matmul(softmax_res, value)
return attention_out
def custom_op_exec(self, query, key, value, sparse_params):
scale = 0.08838
atten_mask = None
if sparse_params[0] == 0:
shape = [1, 8, 256, 256]
atten_mask_u = np.triu(np.ones(shape), k=sparse_params[1] + 1)
atten_mask_l = np.tril(np.ones(shape), k=-sparse_params[2] - 1)
atten_masks = atten_mask_u + atten_mask_l
atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
if sparse_params[0] == 2 or sparse_params[0] == 3 or sparse_params[0] == 4:
atten_masks = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1))
atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
return torch_npu.npu_fusion_attention(
query, key, value, head_num=8, input_layout="BNSD", scale=scale, sparse_mode=sparse_params[0],
atten_mask=atten_mask, pre_tockens=sparse_params[1], next_tockens=sparse_params[2])
def get_atten_mask(self, sparse_mode=0, pre_tokens=65536, next_tokens=65536):
atten_masks = []
shape = [1, 8, 256, 256]
if sparse_mode == 0:
atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
atten_masks = atten_mask_u + atten_mask_l
elif sparse_mode == 1:
atten_masks = np.zeros(shape)
pre_tokens = 65536
next_tokens = 65536
elif sparse_mode == 2:
atten_masks = np.triu(np.ones(shape), k=1)
elif sparse_mode == 3:
atten_masks = np.triu(np.ones(shape), k=1)
elif sparse_mode == 4:
atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
atten_masks = atten_mask_u + atten_mask_l
atten_mask = torch.tensor(atten_masks).to(torch.float16)
return atten_mask
# sparse_params = [sparse_mode, pre_tokens, next_tokens]
# Prec and prec16 indicate the precision comparison standards for float32 and float16 respectively.
# In this example, 0.01 is used as the standard. You can change the value as required.
def check_result(self, query, key, value, sparse_params):
atten_mask = self.get_atten_mask(sparse_params[0], sparse_params[1], sparse_params[2])
output = self.supported_op_exec(query.float(), key.float(), value.float(), atten_mask)
fa_result = self.custom_op_exec(query.npu(), key.npu(), value.npu(), sparse_params)
self.assertRtolEqual(output.half(), fa_result[0], prec=0.01, prec16=0.01)
def test_npu_flash_attention(self, device="npu"):
query = torch.randn(1, 8, 256, 256, dtype=torch.float16)
key = torch.randn(1, 8, 256, 256, dtype=torch.float16)
value = torch.randn(1, 8, 256, 256, dtype=torch.float16)
# sparse_params: [sparse_mode, pre_tokens, next_tokens]
sparse_params_list = [
[0, 128, 128],
[1, 65536, 65536],
[2, 65536, 0],
[3, 65536, 0],
[4, 128, 128]
]
for sparse_params in sparse_params_list:
self.check_result(query, key, value, sparse_params)
if __name__ == "__main__":
run_tests()
使用pse位置编码的示例:
import math
import torch
import torch_npu
def example_with_pse():
# Set random seeds so the example is reproducible.
torch.manual_seed(0)
torch.npu.manual_seed(0)
# B: batch size, N: head num, S: sequence length, D: head dim.
B, N, S, D = 1, 4, 16, 64
scale = 1.0 / math.sqrt(D)
query = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
key = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
value = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
# Build pse in BNSS format; it is added directly to the attention scores.
pse = torch.randn(B, N, S, S, dtype=torch.float16, device="npu") * 0.01 # BNSS
attention_score = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=N,
input_layout="BNSD",
pse=pse,
scale=scale,
keep_prob=1.0,
)[0]
# Reference implementation: softmax(scale * QK^T + pse) * V.
ref_scores = torch.matmul(query.float().cpu(), key.float().cpu().transpose(2, 3)) * scale
ref_out = torch.matmul(torch.softmax(ref_scores + pse.float().cpu(), dim=-1), value.float().cpu())
max_diff = (attention_score.float().cpu() - ref_out).abs().max().item()
print(f"attention_score shape: {attention_score.shape}, max_diff: {max_diff:.6f}")
if __name__ == "__main__":
if torch.npu.is_available():
example_with_pse()
使用sink注意力头偏置的示例:
import math
import torch
import torch_npu
def example_with_sink():
# Set random seeds so the example is reproducible.
torch.manual_seed(0)
torch.npu.manual_seed(0)
# B: batch size, N: head num, S: sequence length, D: head dim.
B, N, S, D = 1, 4, 16, 64
scale = 1.0 / math.sqrt(D)
query = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
key = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
value = torch.randn(B, N, S, D, dtype=torch.float16, device="npu")
# sink provides one extra bias value for each attention head, with shape [head_num].
sink = torch.linspace(-0.2, 0.2, steps=N, dtype=torch.float32, device="npu") # [head_num]
attention_score = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=N,
input_layout="BNSD",
sink=sink,
scale=scale,
keep_prob=1.0,
)[0]
# Reference implementation: append one sink logit, apply softmax, then drop that column before multiplying by V.
ref_scores = torch.matmul(query.float().cpu(), key.float().cpu().transpose(2, 3)) * scale
sink_scores = sink.cpu().view(1, N, 1, 1).expand(B, N, S, 1)
ref_probs = torch.softmax(torch.cat([ref_scores, sink_scores], dim=-1), dim=-1)[..., :-1]
ref_out = torch.matmul(ref_probs, value.float().cpu())
max_diff = (attention_score.float().cpu() - ref_out).abs().max().item()
print(f"attention_score shape: {attention_score.shape}, max_diff: {max_diff:.6f}")
if __name__ == "__main__":
if torch.npu.is_available():
example_with_sink()
"""
)
_add_torch_npu_docstr(
"npu_fusion_attention_v3",
"""
功能描述实现
“Transformer Attention Score”的融合计算, 实现的计算公式如下:
$y=Softmax(Mask(scale*(pse+query*key^{T}),atten_mask),keep_prob)$
$attention=Dropout(y)*value$
该接口为torch_npu.npu_fusion_attention支持图模式的版本,aclgraph支持input_layout为BNSD的场景。
接口原型
torch_npu.npu_fusion_attention_v3(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, SymInt[]? prefix=None, Tensor? actual_seq_qlen=None, Tensor? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False, str softmax_layout="") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明
query: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
key: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
value: Tensor类型, 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 综合约束请见约束说明.
head_num: int类型, 代表head个数, 数据类型支持int64. 综合约束请见约束说明.
input_layout: string类型, 代表输入query、key、value的数据排布格式, 支持BSH、SBH、BSND、BNSD、TND(actual_seq_qlen/actual_seq_kvlen需传值); 后续章节如无特殊说明, S表示query或key、value的sequence length, Sq表示query的sequence length, Skv表示key、value的sequence length, SS表示Sq*Skv.
pse: Tensor类型, 可选参数, 表示位置编码. 数据类型支持float16、bfloat16、float32, 数据格式支持ND. 非varlen场景支持四维输入, 包含BNSS格式、BN1Skv格式、1NSS格式. 如果非varlen场景Sq大于1024或varlen场景、每个batch的Sq与Skv等长且是sparse_mode为0、2、3的下三角掩码场景, 可使能alibi位置编码压缩, 此时只需要输入原始PSE最后1024行进行内存优化, 即alibi_compress = ori_pse[:, :, -1024:, :], 参数每个batch不相同时, 输入BNHSkv(H=1024), 每个batch相同时, 输入1NHSkv(H=1024).
padding_mask: Tensor类型, 暂不支持该传参.
atten_mask: Tensor类型, 可选参数, 取值为1代表该位不参与计算(不生效), 为0代表该位参与计算, 数据类型支持bool、uint8, 数据格式支持ND, 输入shape类型支持BNSS格式、B1SS格式、11SS格式、SS格式. varlen场景只支持SS格式, SS分别是maxSq和maxSkv. 综合约束请见约束说明.
scale: 浮点型, 可选参数, 代表缩放系数, 作为计算流中Muls的scalar值, 数据类型支持float, 默认值为1.
keep_prob: 浮点型, 可选参数, 代表Dropout中1的比例, 取值范围为(0, 1]. 数据类型支持float, 默认值为1, 表示全部保留.
pre_tockens: 整型, 用于稀疏计算的参数, 可选参数, 数据类型支持int64, 默认值为2147483647. 综合约束请见约束说明.
next_tockens: 整型, 用于稀疏计算的参数, 可选参数, 数据类型支持int64, 默认值为2147483647. next_tockens和pre_tockens取值与atten_mask的关系请参见sparse_mode参数, 参数取值与atten_mask分布不一致会导致精度问题. 综合约束请见约束说明.
inner_precise: 整型, 用于提升精度, 数据类型支持int64, 默认值为0.
当前0、1为保留配置值, 2为使能无效行计算, 其功能是避免在计算过程中存在整行mask进而导致精度有损失, 但是该配置会导致性能下降.
如果算子可判断出存在无效行场景, 会自动使能无效行计算, 例如sparse_mode为3, Sq > Skv场景.
prefix: int类型数组, 可选参数, 代表prefix稀疏计算场景每个Batch的N值. 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
actual_seq_qlen: Tensor类型, 可选参数, varlen场景时需要传入此参数. 表示query每个S的累加和长度, 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
比如真正的S长度列表为: 2 2 2 2 2, 则actual_seq_qlen传: 2 4 6 8 10.
actual_seq_kvlen: Tensor类型, 可选参数, varlen场景时需要传入此参数. 表示key/value每个S的累加和长度. 数据类型支持int64, 数据格式支持ND. 综合约束请见约束说明.
比如真正的S长度列表为: 2 2 2 2 2, 则actual_seq_kvlen传: 2 4 6 8 10.
sparse_mode: 整型, 表示sparse的模式, 可选参数. 数据类型支持int64, 默认值为0, 支持配置值为0、1、2、3、4、5、6、7、8. 当整网的atten_mask都相同且shape小于2048*2048时, 建议使用defaultMask模式, 来减少内存使用量. 综合约束请见约束说明.
softmax_layout: string类型,可选参数,用于控制TND场景下softmax的输出(softmax_max和softmax_sum)的数据排布方式。当前仅在input\_layout=“TND”时进行配置,仅支持传入“TND”。默认情况下,softmax的输出排布为NTD排布;传入TND时,softmax的输出排布为TND排布。
表1 sparse_mode不同取值场景说明
sparse_mode
0: defaultMask模式.
1: allMask模式.
2: leftUpCausal模式.
3: rightDownCausal模式.
4: band模式.
5: prefix非压缩模式. varlen场景不支持.
6: prefix压缩模式.
7: varlen外切场景, rightDownCausal模式. 仅varlen场景支持.
8: varlen外切场景, leftUpCausal模式. 仅varlen场景支持.
atten_mask的工作原理为, 在Mask为True的位置遮蔽query(Q)与key(K)的转置矩阵乘积的值. QKT矩阵在atten_mask为True的位置会被遮蔽
说明: 保留该值, atten_mask中, 应该配置为False; 遮蔽该值, atten_mask中应配置为True. sparse_mode为0时, 代表defaultMask模式. 不传mask: 如果atten_mask未传入则不做mask操作, atten_mask取值为None, 忽略pre_tockens和next_tockens取值.
next_tockens取值为0, pre_tockens大于等于Sq, 表示causal场景sparse, atten_mask应传入下三角矩阵, 此时pre_tockens和next_tockens之间的部分需要计算,atten_mask应传入下三角矩阵
pre_tockens小于Sq, next_tockens小于Skv, 且都大于等于0, 表示band场景, 此时pre_tockens和next_tockens之间的部分需要计算. atten_mask应传入band形状矩阵
next_tockens为负数, 以pre_tockens=9, next_tockens=-3为例, pre_tockens和next_tockens之间的部分需要计算. 说明: next_tockens为负数时, pre_tockens取值必须大于等于next_tockens的绝对值, 且next_tockens的绝对值小于Skv.
pre_tockens为负数, 以next_tockens=7, pre_tockens=-3为例, pre_tockens和next_tockens之间的部分需要计算. 说明: pre_tockens为负数时, next_tockens取值必须大于等于pre_tockens的绝对值, 且pre_tockens的绝对值小于Sq.
sparse_mode为1时, 代表allMask, 即传入完整的atten_mask矩阵. 该场景下忽略next_tockens、pre_tockens取值
sparse_mode为2时, 代表leftUpCausal模式的mask, 对应以左上顶点划分的下三角场景(参数起点为左上角). 该场景下忽略pre_tockens、next_tockens取值.传入的atten_mask为优化后的压缩下三角矩阵(2048*2048)
sparse_mode为3时, 代表rightDownCausal模式的mask, 对应以右下顶点划分的下三角场景(参数起点为右下角). 该场景下忽略pre_tockens、next_tockens取值. atten_mask为优化后的压缩下三角矩阵(2048*2048)
sparse_mode为4时, 代表band场景, 即计算pre_tockens和next_tockens之间的部分, 参数起点为右下角, pre_tockens和next_tockens之间需要有交集. atten_mask为优化后的压缩下三角矩阵(2048*2048).
sparse_mode为5时, 代表prefix非压缩场景, 即在rightDownCasual的基础上, 左侧加上一个长为Sq, 宽为N的矩阵, N的值由可选输入prefix获取, 例如下图中表示batch=2场景下prefix传入数组[4,5], 每个batch轴的N值可以不一样, 参数起点为左上角. 该场景下忽略pre_tockens、next_tockens取值, atten_mask矩阵数据格式须为BNSS或B1SS
sparse_mode为6时, 代表prefix压缩场景, 即prefix场景时, attenMask为优化后的压缩下三角+矩形的矩阵(3072*2048): 其中上半部分[2048, 2048]的下三角矩阵, 下半部分为[1024,2048]的矩形矩阵, 矩形矩阵左半部分全0, 右半部分全1. 该场景下忽略pre_tockens、next_tockens取值.
sparse_mode为7时, 表示varlen且为长序列外切场景(即长序列在模型脚本中进行多卡切query的sequence length); 用户需要确保外切前为使用sparse_mode 3的场景; 当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点), 且需要保证参数正确, 否则会存在精度问题. Masked QKT矩阵示意如下, 在第二个batch对query进行切分, key和value不切分, 4x6的mask矩阵被切分成2x6和2x6的mask, 分别在卡1和卡2上计算: 卡1的最后一块mask为band类型的mask, 配置pre_tockens=6(保证大于等于最后一个Skv), next_tockens=-2, actual_seq_qlen应传入{3,5}, actual_seq_kvlen应传入{3,9}. 卡2的mask类型切分后不变, sparse_mode为3, actual_seq_qlen应传入{2,7,11}, actual_seq_kvlen应传入{6,11,15}.
如果配置sparse_mode=7, 但实际只存在一个batch, 用户需按照band模式的要求来配置参数; sparse_mode=7时, 用户需要输入2048x2048的下三角mask作为该融合算子的输入.
基于sparse_mode=3进行外切产生的band模式的sparse的参数应符合以下条件:
pre_tockens >= last_Skv.
next_tockens <= 0.
当前模式下不支持可选输入pse.
sparse_mode为8时, 表示varlen且为长序列外切场景; 用户需要确保外切前为使用sparse_mode 2的场景; 当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点), 且需要保证参数正确, 否则会存在精度问题. Masked QKT矩阵示意如下, 在第二个batch对query进行切分, key和value不切分, 5x4的mask矩阵被切分成2x4和3x4的mask, 分别在卡1和卡2上计算: 卡1的mask类型切分后不变, sparse_mode为2, actual_seq_qlen应传入{3,5}, actual_seq_kvlen应传入{3,7}. 卡2的第一块mask为band类型的mask, 配置pre_tockens=4(保证大于等于第一个Skv), next_tockens=1, actual_seq_qlen应传入{3,8,12}, actual_seq_kvlen应传入{4,9,13}.
如果配置sparse_mode=8, 但实际只存在一个batch, 用户需按照band模式的要求来配置参数; sparse_mode=8时, 用户需要输入2048x2048的下三角mask作为该融合算子的输入.
基于sparse_mode=2进行外切产生的band模式的sparse的参数应符合以下条件:
pre_tockens >= first_Skv.
next_tockens范围无约束, 根据实际情况进行配置.
当前模式下不支持可选输入pse.
gen_mask_parallel: 布尔型, DSA生成dropout随机数向量mask的控制开关. 默认值为True: 同AI Core计算并行; 设为False: 同AI Core计算串行.
sync: 布尔型, DSA生成dropout随机数向量mask的控制开关. 默认值为False: dropout mask异步生成; 设为True: dropout mask同步生成.
输出说明
共6个输出, 类型依次为Tensor、Tensor、Tensor、Tensor、Tensor、Tensor.
第1个输出为Tensor, 计算公式的最终输出attention_out, 数据类型支持float16、bfloat16、float32.
第2个输出为Tensor, Softmax计算的Max中间结果, 用于反向计算, 数据类型支持float.
第3个输出为Tensor, Softmax计算的Sum中间结果, 用于反向计算, 数据类型支持float.
第4个输出为Tensor, 预留参数, 暂未使用.
第5个输出为Tensor, DSA生成dropoutmask中, Philox算法的seed. 在aclgraph场景下,返回的是npu Tensor,在非aclgraph场景下,返回的是cpu Tensor.
第6个输出为Tensor, DSA生成dropoutmask中, Philox算法的offset. 在aclgraph场景下,返回的是npu Tensor,在非aclgraph场景下,返回的是cpu Tensor.
约束说明
该接口仅在训练场景下使用.
输入query、key、value、pse的数据类型必须一致.
输入query、key、value的input_layout必须一致.
输入query、key、value的shape说明:
1. 输入key和value的shape必须一致.
2. B: batchsize必须相等; 非varlen场景B取值范围1~2M; varlen场景B取值范围1~2K.
3. D: Head Dim必须满足Dq=Dk和Dk≥Dv,取值范围1~768.
4. S: sequence length, 取值范围1~1M.
varlen场景下:
1. 要求T(B*S)取值范围1~1M.
2. atten_mask输入不支持补pad,即atten_mask中不能存在某一行全1的场景.
支持输入query的N和key/value的N不相等, 但必须成比例关系, 即Nq/Nkv必须是非0整数, Nq取值范围1~256. 当Nq/Nkv > 1时, 即为GQA\(grouped-query attention); 当Nq/Nkv=1时,即为MHA(multi-head attention). 本文如无特殊说明, N表示的是Nq.
输入key/value的shape必须一致.
sparse_mode取值说明:
1. sparse_mode为1、2、3、4、5、6、7、8时, 应传入对应正确的atten_mask, 否则将导致计算结果错误. 当atten_mask输入为None时, sparse_mode, pre_tockens, next_tockens参数不生效, 固定为全计算.
2. sparse_mode配置为1、2、3、5、6时, 用户配置的pre_tockens、next_tockens不会生效.
3. sparse_mode配置为0、4时, 需保证atten_mask与pre_tockens、next_tockens的范围一致.
4. sparse_mode配置为7、8时,不支持可选参数pse.
prefix稀疏计算场景B不大于32, varlen场景不支持非压缩prefix, 即不支持sparse_mode=5; 当Sq>Skv时, prefix的N值取值范围[0, Skv], 当Sq<=Skv时, prefix的N值取值范围[Skv-Sq, Skv].
支持actual_seq_qlen中某个Batch上的S长度为0; 如果存在S为0的情况, 不支持pse输入, 假设真实的S长度为[2, 2, 0, 2, 2], 则传入的actual_seq_qlen为[2, 4, 4, 6, 8]. actual_seq_qlen的长度取值范围为1~2K, varlen场景下长度最大支持1K.
TND格式下, 支持尾部部分Batch不参与计算, 此时actual_seq_qlen和actual_seq_kvlen尾部传入对应个数个0即可. 假设真实的S长度为[2, 3, 4, 5, 6], 此时后两个Batch不参与计算, 则传入的actual_seq_qlen为[2, 5, 9, 0, 0].
部分场景下, 如果计算量过大可能会导致算子执行超时(aicore error类型报错, errorStr为: timeout or trap error), 此时建议做轴切分处理, 注: 这里的计算量会受B、S、N、D等参数的影响, 值越大计算量越大.
支持的PyTorch版本
PyTorch 2.6+
支持的型号
Atlas A2 训练系列产品
调用示例
单算子模式调用:
import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUFlashAttention(TestCase):
def supported_op_exec(self, query, key, value, atten_mask):
scale = 0.08838
qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
qk = qk + atten_mask * (-10000.0)
softmax_res = torch.nn.functional.softmax(qk, dim=-1)
attention_out = torch.matmul(softmax_res, value)
return attention_out
def custom_op_exec(self, query, key, value, sparse_params):
scale = 0.08838
atten_mask = None
if sparse_params[0] == 0:
shape = [1, 8, 256, 256]
atten_mask_u = np.triu(np.ones(shape), k=sparse_params[1] + 1)
atten_mask_l = np.tril(np.ones(shape), k=-sparse_params[2] - 1)
atten_masks = atten_mask_u + atten_mask_l
atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
if sparse_params[0] == 2 or sparse_params[0] == 3 or sparse_params[0] == 4:
atten_masks = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1))
atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
return torch_npu.npu_fusion_attention_v3(
query, key, value, head_num=8, input_layout="BNSD", scale=scale, sparse_mode=sparse_params[0],
atten_mask=atten_mask, pre_tockens=sparse_params[1], next_tockens=sparse_params[2])
def get_atten_mask(self, sparse_mode=0, pre_tokens=65536, next_tokens=65536):
atten_masks = []
shape = [1, 8, 256, 256]
if sparse_mode == 0:
atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
atten_masks = atten_mask_u + atten_mask_l
elif sparse_mode == 1:
atten_masks = np.zeros(shape)
pre_tokens = 65536
next_tokens = 65536
elif sparse_mode == 2:
atten_masks = np.triu(np.ones(shape), k=1)
elif sparse_mode == 3:
atten_masks = np.triu(np.ones(shape), k=1)
elif sparse_mode == 4:
atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
atten_masks = atten_mask_u + atten_mask_l
atten_mask = torch.tensor(atten_masks).to(torch.float16)
return atten_mask
# sparse_params = [sparse_mode, pre_tokens, next_tokens]
# Prec and prec16 indicate the precision comparison standards for float32 and float16 respectively.
# In this example, 0.01 is used as the standard. You can change the value as required.
def check_result(self, query, key, value, sparse_params):
atten_mask = self.get_atten_mask(sparse_params[0], sparse_params[1], sparse_params[2])
output = self.supported_op_exec(query.float(), key.float(), value.float(), atten_mask)
fa_result = self.custom_op_exec(query.npu(), key.npu(), value.npu(), sparse_params)
self.assertRtolEqual(output.half(), fa_result[0], prec=0.01, prec16=0.01)
def test_npu_flash_attention(self, device="npu"):
query = torch.randn(1, 8, 256, 256, dtype=torch.float16)
key = torch.randn(1, 8, 256, 256, dtype=torch.float16)
value = torch.randn(1, 8, 256, 256, dtype=torch.float16)
# sparse_params: [sparse_mode, pre_tokens, next_tokens]
sparse_params_list = [
[0, 128, 128],
[1, 65536, 65536],
[2, 65536, 0],
[3, 65536, 0],
[4, 128, 128]
]
for sparse_params in sparse_params_list:
self.check_result(query, key, value, sparse_params)
if __name__ == "__main__":
run_tests()
"""
)
_add_torch_npu_docstr(
"npu_geglu",
"""
torch_npu. npu_geglu(Tensor self, int dim=-1, int approximate=1) -> (Tensor, Tensor)
功能描述
对输入Tensor完成GeGlu运算。
参数说明
Tensor self:待进行GeGlu计算的入参,npu device侧的aclTensor,数据类型支持FLOAT32、FLOAT16、BFLOAT16(Atlas A2 训练系列产品支持),支持非连续的Tensor,数据格式支持ND。
int dim:可选入参,设定的slice轴,数据类型支持INT64。
int approximate:可选入参,GeGlu计算使用的激活函数索引,0表示使用none,1表示使用tanh,数据类型支持INT64。
out:GeGlu计算的出参,npu device侧的aclTensor,数据类型必须和self一致,支持非连续的Tensor,数据格式支持ND。
outGelu:GeGlu计算的出参,npu device侧的aclTensor,数据类型必须和self一致,支持非连续的Tensor,数据格式支持ND。
约束说明
out、outGelu在dim维的size等于self在dim维size的1/2。
当self.dim()==0时,dim的取值在[-1, 0]范围内;当self.dim()>0时,取值在[-self.dim(), self.dim()-1]范围内。
示例
data_x = np.random.uniform(-2, 2, [24,9216,2560]).astype(np.float16)
x_npu = torch.from_numpy(data_x).npu()
x_npu:
tensor([[[ 0.8750, 0.4766, -0.3535, ..., -1.4619, 0.3542, -1.8389],
[ 0.9424, -0.0291, 0.9482, ..., 0.5640, -1.2959, 1.7666],
[-0.4958, -0.6787, 0.0179, ..., 0.4365, -0.8311, -1.7676],
...,
[-1.1611, 1.4766, -1.1934, ..., -0.5913, 1.1553, -0.4626],
[ 0.4873, -1.8105, 0.5723, ..., 1.3193, -0.1558, -1.6191],
[ 1.6816, -1.2080, -1.6953, ..., -1.3096, 0.4158, -1.2168]],
[[ 1.4287, -1.9863, 1.4053, ..., -1.7676, -1.6709, -1.1582],
[-1.3281, -1.9043, 0.7725, ..., -1.5596, 0.1632, -1.0732],
[ 1.0254, -1.6650, 0.1318, ..., -0.8159, -0.7134, -0.4536],
...,
[ 0.0327, -0.6206, -0.1492, ..., -1.2559, 0.3777, -1.2822],
[-1.1904, 1.1260, -1.3369, ..., -1.4814, 0.4463, 1.0205],
[-0.1192, 1.7783, 0.1040, ..., 1.0010, 1.5342, -0.5728]],
[[-0.3296, 0.5703, 0.6338, ..., 0.2131, 1.1113, 0.9854],
[ 1.4336, -1.7568, 1.8164, ..., -1.2012, -1.8721, 0.6904],
[ 0.6934, 0.3743, -0.9448, ..., -0.9946, -1.6494, -1.3564],
...,
[ 1.1855, -0.9663, -0.8252, ..., 0.2285, -1.5684, -0.4277],
[ 1.1260, 1.2871, 1.2754, ..., -0.5171, -1.1064, 0.9624],
[-1.4639, -0.0661, -1.7178, ..., 1.2656, -1.9023, -1.1641]],
...,
[[-1.8350, 1.0625, 1.6172, ..., 1.4160, 1.2490, 1.9775],
[-0.5615, -1.9990, -0.5996, ..., -1.9404, 0.5068, -0.9829],
[-1.0771, -1.5537, -1.5654, ..., 0.4678, -1.5215, -1.7920],
...,
[-1.3389, -0.3228, -1.1514, ..., 0.8882, -1.9971, 1.2432],
[-1.5439, -1.8154, -1.9238, ..., 0.2556, 0.2131, -1.7471],
[-1.1074, 1.0391, 0.1556, ..., 1.1689, 0.6470, 0.2463]],
[[ 1.2617, -0.8911, 1.9160, ..., -0.3027, 1.7764, 0.3381],
[-1.4160, 1.6201, -0.5396, ..., 1.8271, 1.3086, -1.8770],
[ 1.8252, 1.3779, -0.3535, ..., -1.5215, -1.4727, -1.0420],
...,
[-1.4600, -1.7617, -0.7754, ..., 0.4697, -0.4734, -0.3838],
[ 1.8506, -0.3945, -0.0142, ..., -1.3447, -0.6587, 0.5728],
[ 1.1523, -1.8027, 0.4731, ..., 0.5464, 1.4014, -1.8594]],
[[-0.1467, -0.5752, 0.3298, ..., -1.9902, -1.8281, 1.8506],
[ 0.2473, 1.0693, -1.8184, ..., 1.9277, 1.6543, 1.0088],
[ 0.0804, -0.7939, 1.3486, ..., -1.1543, -0.4053, -0.0055],
...,
[ 0.3672, 0.3274, -0.3369, ..., 1.4951, -1.9580, -0.7847],
[ 1.3525, -0.4780, -0.5000, ..., -0.1610, -1.9209, 1.5498],
[ 0.4905, -1.7832, 0.4243, ..., 0.9492, 0.3335, 0.9565]]],
device='npu:0', dtype=torch.float16)
y_npu, y_gelu_npu = torch_npu.npu_geglu(x_npu, dim=-1, approximate=1)
y_npu:
tensor([[[-9.2590e-02, -1.2054e-01, 1.6980e-01, ..., -6.8542e-02,
-2.5254e+00, -6.9519e-02],
[ 1.2405e-02, -1.4902e+00, 8.0750e-02, ..., 3.4570e-01,
-1.5029e+00, 2.8442e-01],
[-9.0271e-02, 4.3335e-01, -1.7402e+00, ..., 1.3574e-01,
-5.5762e-01, -1.3123e-01],
...,
[ 1.0004e-01, 1.5312e+00, 1.4189e+00, ..., -2.6172e-01,
1.6113e-01, -1.1887e-02],
[-5.9845e-02, 2.0911e-01, -6.4735e-03, ..., 5.1422e-02,
2.6289e+00, 2.5977e-01],
[ 1.3649e-02, -1.3329e-02, -6.9031e-02, ..., 3.5977e+00,
-1.2178e+00, -2.3242e+00]],
[[-3.1816e+00, -2.6719e+00, 1.4038e-01, ..., 2.6660e+00,
7.7820e-02, 2.3999e-01],
[ 2.9297e+00, -1.7754e+00, 2.6703e-02, ..., -1.3318e-01,
-6.2109e-01, -1.9072e+00],
[ 1.1316e-01, 5.8887e-01, 8.2959e-01, ..., 1.1273e-01,
1.1481e-01, 4.2419e-02],
...,
[-2.6831e-01, -1.7288e-02, 2.6343e-01, ..., 9.3750e-02,
-2.2324e+00, 1.2894e-02],
[-2.0630e-01, 5.9619e-01, -1.4210e-03, ..., -1.2598e-01,
-6.5552e-02, 1.1115e-01],
[-1.6143e+00, -1.6150e-01, -4.9774e-02, ..., 8.6426e-02,
1.1879e-02, -1.9795e+00]],
[[ 4.3152e-02, 1.9250e-01, -4.7485e-02, ..., -5.8632e-03,
1.4551e-01, -2.1289e+00],
[ 4.7951e-03, 2.0691e-01, 4.4458e-01, ..., 4.7485e-02,
-4.8889e-02, 1.5684e+00],
[-8.9404e-01, -8.0420e-01, -2.9248e-01, ..., 1.6205e-02,
3.5449e+00, 8.2397e-02],
...,
[-1.9385e+00, -1.8838e+00, 6.0010e-01, ..., -8.5059e-01,
6.1829e-02, 1.0547e-01],
[-5.1086e-02, -1.0760e-01, -7.1228e-02, ..., -9.2468e-02,
4.7900e-01, -3.5278e-01],
[ 1.7078e-01, 1.6846e-01, 2.5528e-02, ..., 1.3708e-01,
1.4954e-01, -2.8418e-01]],
...,
[[-6.3574e-01, -2.0156e+00, 9.3994e-02, ..., 2.2402e+00,
-6.2218e-03, 8.7402e-01],
[ 1.5010e+00, -1.8518e-01, -3.0930e-02, ..., 1.1511e-01,
-3.8300e-02, -1.6150e-01],
[-2.8442e-01, 4.4373e-02, -1.0022e-01, ..., 9.2468e-02,
-1.2524e-01, -1.2115e-01],
...,
[ 3.4760e-02, 1.9812e-01, -9.1431e-02, ..., -1.1650e+00,
2.4011e-01, -1.0919e-01],
[-1.5283e-01, 1.8535e+00, 4.4360e-01, ..., 6.4844e-01,
-2.8784e-01, -2.5938e+00],
[-9.9915e-02, 4.6436e-01, 6.6528e-02, ..., -1.2817e-01,
-1.5686e-01, -5.4962e-02]],
[[-2.3279e-01, 4.5630e-01, -5.4834e-01, ..., 5.9013e-03,
-4.7974e-02, -2.7617e+00],
[-1.0760e-01, -2.0371e+00, 3.7915e-01, ..., 6.4551e-01,
2.6953e-01, -1.0910e-03],
[ 4.9683e-01, 1.2402e+00, -1.0429e-02, ..., 3.4294e-03,
-8.2959e-01, 1.2012e-01],
...,
[ 1.6956e-01, 5.3027e-01, -1.6418e-01, ..., -2.1094e-01,
-9.8267e-02, 2.3364e-01],
[ 4.1687e-02, -1.1365e-01, 1.2598e+00, ..., -5.6299e-01,
1.5967e+00, 9.3445e-02],
[ 9.7656e-02, -4.5410e-01, -2.9395e-01, ..., -1.6565e-01,
-8.2153e-02, -7.0068e-01]],
[[ 1.6345e-01, 2.5806e-01, -6.1951e-02, ..., -6.5857e-02,
-6.0303e-02, -1.9080e-01],
[ 1.9666e-01, 1.8262e+00, -1.1951e-01, ..., 1.0138e-01,
-2.0911e-01, -6.0638e-02],
[-6.9141e-01, -2.5234e+00, -1.2734e+00, ..., 1.0510e-01,
-1.6504e+00, -9.7070e-01],
...,
[-2.5406e-03, -3.1342e-02, -7.0862e-02, ..., 9.2041e-02,
7.7271e-02, 8.0518e-01],
[-1.5161e-01, -6.8848e-02, 7.0801e-01, ..., 7.0166e-01,
-3.3661e-02, -1.4319e-01],
[-3.0899e-02, 1.4490e-01, 1.9763e-01, ..., -8.1116e-02,
7.8955e-01, 1.8347e-01]]], device='npu:0', dtype=torch.float16)
y_gelu_npu:
tensor([[[-1.5771e-01, -1.4331e-01, -1.0846e-01, ..., -1.1133e-01,
1.3818e+00, -1.5076e-01],
[-1.8600e-02, 1.6904e+00, -6.9336e-02, ..., 3.6890e-01,
1.6768e+00, 2.5146e-01],
[ 7.5342e-01, 6.0742e-01, 1.0820e+00, ..., 1.5063e-01,
1.1572e+00, -9.4482e-02],
...,
[-1.5796e-01, 8.4082e-01, 9.2627e-01, ..., -1.6064e-01,
-1.1096e-01, -1.6370e-01],
[ 3.4814e-01, -1.6418e-01, -3.1982e-02, ..., -1.5186e-01,
1.3330e+00, -1.4111e-01],
[-8.4778e-02, -1.1023e-01, -1.0669e-01, ..., 1.9521e+00,
9.5654e-01, 1.5635e+00]],
[[ 1.7881e+00, 1.8359e+00, -1.6663e-01, ..., 1.4609e+00,
-1.6760e-01, -1.6528e-01],
[ 1.9434e+00, 1.7168e+00, -1.1615e-01, ..., -9.8816e-02,
9.4043e-01, 1.2344e+00],
[-1.6064e-01, 5.7031e-01, 1.6475e+00, ..., -1.0809e-01,
-1.6785e-01, -1.6345e-01],
...,
[-1.6797e-01, -4.6326e-02, 2.6904e-01, ..., 6.9458e-02,
1.3174e+00, 1.3486e+00],
[-1.0645e-01, 3.0249e-01, -9.9411e-03, ..., -1.3928e-01,
-1.0974e-01, -7.1533e-02],
[ 1.7012e+00, -1.0254e-01, -8.2825e-02, ..., -4.8492e-02,
-1.1926e-01, 1.7490e+00]],
[[-6.6650e-02, -1.0370e-01, -2.3788e-02, ..., -1.0706e-01,
-1.6980e-01, 1.4209e+00],
[-5.2986e-03, -1.1133e-01, 2.5439e-01, ..., -3.9459e-02,
-6.8909e-02, 1.2119e+00],
[ 6.1035e-01, 6.8506e-01, -1.5039e-01, ..., 5.8136e-02,
1.8232e+00, -6.7383e-02],
...,
[ 1.4434e+00, 1.6787e+00, 1.2422e+00, ..., 7.5488e-01,
-5.0720e-02, -6.8787e-02],
[-1.4600e-01, -1.2213e-01, -1.6711e-01, ..., 3.7280e-01,
1.3125e+00, 2.2375e-01],
[ 3.4985e-01, -1.2659e-01, -4.6722e-02, ..., -1.4685e-01,
1.4856e-01, -1.6406e-01]],
...,
[[ 4.8730e-01, 1.6680e+00, -5.7098e-02, ..., 1.4189e+00,
7.1983e-03, 7.8857e-01],
[ 1.1328e+00, -1.6931e-01, -1.1163e-01, ..., -1.6467e-01,
3.5309e-02, -1.5173e-01],
[-1.6858e-01, -8.9111e-02, -1.4709e-01, ..., -8.1970e-02,
5.4248e-01, 5.0830e-01],
...,
[ 2.1936e-01, 7.7197e-01, 4.8737e-02, ..., 8.7842e-01,
-1.6406e-01, -7.1716e-02],
[-1.2720e-01, 1.9404e+00, 1.0391e+00, ..., 7.3877e-01,
-1.6199e-01, 1.5781e+00],
[-1.6968e-01, 1.0664e+00, -1.6431e-01, ..., -7.5439e-02,
-1.5332e-01, 2.1790e-01]],
[[ 3.0981e-01, 6.0010e-01, 7.9346e-01, ..., 4.0169e-03,
5.8447e-01, 1.7109e+00],
[-1.6699e-01, 1.7646e+00, 5.9326e-01, ..., 3.3813e-01,
-1.5845e-01, -4.7699e-02],
[ 3.7573e-01, 9.4580e-01, -9.5276e-02, ..., 2.4805e-01,
8.3350e-01, 1.2573e-01],
...,
[-1.5369e-01, 1.2021e+00, -1.6626e-01, ..., -1.1108e-01,
1.6084e+00, -1.4807e-01],
[-4.6234e-02, -6.4331e-02, 8.9844e-01, ..., 9.2871e-01,
7.9834e-01, -1.6992e-01],
[-6.4941e-02, 1.1465e+00, -1.5161e-01, ..., -1.5076e-01,
-8.6487e-02, 1.0137e+00]],
[[-1.1731e-01, -1.4404e-01, -8.9050e-02, ..., -1.2128e-01,
-1.0919e-01, -1.6943e-01],
[ 1.5186e-01, 1.1396e+00, -6.5735e-02, ..., -7.4829e-02,
-1.6455e-01, -8.9355e-02],
[ 6.4404e-01, 1.5625e+00, 1.7725e+00, ..., -5.5176e-02,
1.7920e+00, 6.6504e-01],
...,
[ 1.9083e-03, 3.8452e-01, -4.9011e-02, ..., -1.5405e-01,
-1.6003e-01, 1.3975e+00],
[ 1.0437e-01, -8.6182e-02, 5.5713e-01, ..., 1.0645e+00,
-1.3818e-01, 5.1562e-01],
[-1.0229e-01, -1.0529e-01, 2.6562e-01, ..., -5.6702e-02,
1.0830e+00, -1.6833e-01]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_get_float_status",
"""
torch_npu.npu_get_float_status(self) -> Tensor
功能描述
计算npu_get_float_status算子函数。
参数说明
self (Tensor) - 数据内存地址张量,数据类型为float32。
示例
>>> x = torch.rand(2).npu()
>>> torch_npu.npu_get_float_status(x)
tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_giou",
"""
torch_npu.npu_giou(self, gtboxes, trans=False, is_cross=False, mode=0) -> Tensor
功能描述
首先计算两个框的最小封闭面积和IoU,然后计算封闭区域中不属于两个框的封闭面积的比例,最后从IoU中减去这个比例,得到GIoU。
参数说明
self (Tensor) - 标注框,shape为(N, 4) 数据类型为float16或float32的2D张量。“N”表示标注框的数量,值“4”表示[x1, y1, x2, y2]或[x, y, w, h]。
gtboxes (Tensor) - 真值框,shape为(M, 4) 数据类型为float16或float32的2D张量。“M”表示真值框的数量,值“4”表示[x1, y1, x2, y2]或[x, y, w, h]。
trans (Bool,默认值为False) - 值为True代表“xywh”,值为False代表“xyxy”。
is_cross (Bool,默认值为False) - 控制输出shape是[M, N]还是[1,N]。如果值为True,则输出shape为[M,N]。如果为False,则输出shape为[1,N]。
mode (Int,默认值为0) - 计算模式,取值为0或1。0表示IoU,1表示IoF。
示例
>>> a=np.random.uniform(0,1,(4,10)).astype(np.float16)
>>> b=np.random.uniform(0,1,(4,10)).astype(np.float16)
>>> box1=torch.from_numpy(a).to("npu")
>>> box2=torch.from_numpy(a).to("npu")
>>> output = torch_npu.npu_giou(box1, box2, trans=True, is_cross=False, mode=0)
>>> output
tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_grid_assign_positive",
"""
torch_npu.npu_grid_assign_positive(self, overlaps, box_responsible_flags, max_overlaps, argmax_overlaps, gt_max_overlaps, gt_argmax_overlaps, num_gts, pos_iou_thr, min_pos_iou, gt_max_assign_all) -> Tensor
功能描述
执行position-sensitive的候选区域池化梯度计算。
参数说明
self (Tensor) - float16或float32类型的张量, shape为(n, )。
overlaps (Tensor) - 数据类型与assigned_gt_inds相同,表示gt_bboxes和bboxes之间的IoU,shape为(k,n)。
box_responsible_flags (Tensor) - 支持uint8数据类型。表示框是否responsible的标志。
max_overlaps (Tensor) - 数据类型与assigned_gt_inds. overlaps.max(axis=0)相同。
argmax_overlaps (Tensor) - 支持uint32数据类型,overlaps.argmax(axis=0)。
gt_max_overlaps (Tensor) - 数据类型与assigned_gt_inds. overlaps.max(axis=1)相同。
gt_argmax_overlaps (Tensor) - 支持uint32数据类型, overlaps.argmax(axis=1)。
num_gts (Tensor) - 支持uint32数据类型,real k ,shape为 (1, )。
pos_iou_thr (Float) - 正检测框的IoU阈值。
min_pos_iou (Float) - 检测框被视为正检测框的最小IoU
gt_max_assign_all (Bool) - 是否将与某个gt有相同最高重叠的所有检测框分配给该gt。
示例
>>> assigned_gt_inds = torch.rand(4).npu()
>>> overlaps = torch.rand(2,4).npu()
>>> box_responsible_flags = torch.tensor([1, 1, 1, 0], dtype=torch.uint8).npu()
>>> max_overlap = torch.rand(4).npu()
>>> argmax_overlap = torch.tensor([1, 0, 1, 0], dtype=torch.int32).npu()
>>> gt_max_overlaps = torch.rand(2).npu()
>>> gt_argmax_overlaps = torch.tensor([1, 0],dtype=torch.int32).npu()
>>> output = torch_npu.npu_grid_assign_positive(assigned_gt_inds, overlaps, box_responsible_flags, max_overlap, argmax_overlap, gt_max_overlaps, gt_argmax_overlaps, 128, 0.5, 0., True)
>>> output.shape
torch.Size([4])
"""
)
_add_torch_npu_docstr(
"npu_gru",
"""
torch_npu.npu_gru(input, hx, weight_input, weight_hidden, bias_input, bias_hidden, seq_length, has_biases, num_layers, dropout, train, bidirectional, batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述
计算DynamicGRUV2。
参数说明
input (Tensor) - 数据类型:float16;格式:FRACTAL_NZ。
hx (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
weight_input (Tensor) - 数据类型:float16;格式:FRACTAL_Z。
weight_hidden (Tensor) - 数据类型:float16;格式:FRACTAL_Z。
bias_input (Tensor) - 数据类型:float16, float32;格式:ND。
bias_hidden (Tensor) - 数据类型:float16, float32;格式:ND。
seq_length (Tensor) - 数据类型:int32;格式:ND。
has_biases (Bool,默认值为True)
num_layers (Int)
dropout (Float)
train (Bool,默认值为True) - 标识训练是否在op进行的bool参数。
bidirectional (Bool,默认值为True)
batch_first (Bool,默认值为True)
输出说明
y (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
output_h (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
update (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
reset (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
new (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
hidden_new (Tensor) - 数据类型:float16, float32;格式:FRACTAL_NZ。
"""
)
_add_torch_npu_docstr(
"npu_hans_encode",
"""
torch_npu.npu_hans_encode(input, statistic, reshuff, out=(pdf, mantissa, fixed, var))
功能描述
对输入张量基于概率密度分布(PDF)进行无损压缩
参数说明
input: Device侧的Tensor类型,表示输入的待压缩张量;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;输入Shape无限制,数据元素大小仅支持64的倍数且大于等于32768。
statistic: bool类型,控制是否重新统计pdf(概率密度分布);设置为True时会重新统计输入input指数位字节的概率密度分布并覆盖pdf,设置为False时会使用输入的pdf进行压缩;默认值为False;
reshuff: bool类型,控制是否将fixed中多核压缩的结果连续化;限制为fixed大小大于等于压缩上界时候才能使用,详细见约束。设置为True则将多核压缩的结果连续化,设置为False时则不做处理;设置为True时var参数失效;该参数需同步传入解码;默认值为False;
输出说明
pdf:Device侧的Tensor类型,表示input指数位字节的概率密度分布,数据类型为INT32,shape为[1, 256],其中每一个元素的值表示其对应索引,在input中出现的次数;当statistic设置为True时会统计输入input指数位的pdf并覆盖原有pdf,设置为False时会使用当前输入的pdf进行压缩;
mantissa:可为Device侧的Tensor类型、或Host侧内存通过虚拟内存映射至Deive,表示input输入的尾数部分;数据类型与input保持一致;输入Shape无限制,输入大小见约束。
fixed:Device侧的Tensor类型,表示input指数位字节压缩的定长部分,一般由上层应用设定固定容量的空间来存储压缩结果;数据类型与input保持一致;输入Shape无限制,输入大小见约束。
var:可为Device侧的Tensor类型、或Host侧内存通过虚拟内存映射至Deive,表示input指数位字节压缩的变长部分,一般由上层应用设定容量大小;数据类型与input保持一致;输入Shape无限制,输入大小见约束。
约束说明
输入input的元素数量为64的倍数且大于等于32768。
pdf的shape为[1, 256],数据类型为INT32。
mantissa.numel() * mantissa.element_size() = input.numel() * (input.element_size() – 1),尾数的大小可根据input输入的类型和大小严格计算。
fixed.numel() * fixed.element_size() >= 512,即fixed的大小必须大于512Byte来存储压缩的元信息。
fixed.numel() * fixed.element_size() + var.numel() * var.element_size() >= input.numel() + input.numel() / 64 + 8448 * 当前硬件Vector核数,即fixed与var的空间大小总和必须大于压缩上界。
如果reshuff为True,则fixed.numel() * fixed.element_size() 需要大于input.numel() + input.numel() / 64 + 8448 * 硬件vector核数,即保证压缩结果同时存在于fixed上,fixed的大小需大于等于压缩上界。
支持的型号
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例
import torch
import torch_npu
data_shape = (4096, 512)
statistic = True
reshuff = False
input_tensor = torch.randn(data_shape, dtype=dtype).npu()
pdf = torch.zeros(256, dtype=torch.int32).npu()
mantissa_numel = input_tensor.numel() * (input_tensor.element_size() - 1)
mantissa = torch.zeros(mantissa_numel // input_tensor.element_size(), dtype=input_tensor.dtype).npu()
fixed = torch.zeros(input_tensor.numel(), dtype=input_tensor.dtype).npu()
var = torch.zeros(input_tensor.numel(), dtype=input_tensor.dtype).npu()
pdf, mantissa, fixed, var = torch_npu.npu_hans_encode(input_tensor, statistic, reshuff, out=(pdf, mantissa, fixed, var))
"""
)
_add_torch_npu_docstr(
"npu_hans_decode",
"""
torch_npu.npu_hans_decode( mantissa, fixed, var, pdf, reshuff, out=out)
功能描述
基于概率密度分布(PDF)对压缩后的结果进行无损解压缩
参数说明(包括 类型、默认值、含义、参数使用限制)
mantissa:可为Device侧的Tensor类型、或Host侧内存通过虚拟内存映射至Deive,表示压缩前张量的尾数部分。数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;输入Shape无限制,为npu_hans_encode的输出。
fixed:Device侧的Tensor类型,表示压缩前张量的指数位字节压缩的定长部分;数据类型与input保持一致;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;输入Shape无限制,为npu_hans_encode的输出。
var:可为Device侧的Tensor类型、或Host侧内存通过虚拟内存映射至Deive,表示压缩前张量的指数位字节压缩的变长部分。数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;输入Shape无限制,为npu_hans_encode的输出。
pdf:Device侧的Tensor类型,表示压缩时采用的概率密度分布,数据类型为INT32,shape为[1, 256]。
reshuff: bool类型,表示在压缩时是否将fixed中多核压缩的结果进行了连续化,设置为True则表示已将多核压缩的结果连续化,设置为False时则表示没有将fixed压缩的结果连续化;默认值为False。
输出说明
out:Device侧的Tensor类型,表示解压缩后的张量,数据类型与mantissa等输入一致,Shape无限制,大小详见约束;
约束说明
输出out的元素数量为64的倍数且大于等于32768。
pdf的shape为[1, 256],数据类型为INT32。
mantissa.numel() * mantissa.element_size() = out.numel() * (out.element_size() – 1)。
支持的型号
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例
import torch
import torch_npu
data_shape = (4096, 512)
statistic = True
reshuff = False
input_tensor = torch.randn(data_shape, dtype=dtype).npu()
recover = torch.zeros(data_shape, dtype=dtype).npu()
pdf = torch.zeros(256, dtype=torch.int32).npu()
mantissa_numel = input_tensor.numel() * (input_tensor.element_size() - 1)
mantissa = torch.zeros(mantissa_numel // input_tensor.element_size(), dtype=input_tensor.dtype).npu()
fixed = torch.zeros(input_tensor.numel(), dtype=input_tensor.dtype).npu()
var = torch.zeros(input_tensor.numel(), dtype=input_tensor.dtype).npu()
pdf, mantissa, fixed, var = torch_npu.npu_hans_encode(input_tensor, statistic, reshuff, out=(pdf, mantissa, fixed, var))
recover = torch_npu.npu_hans_decode(mantissa, fixed, var, pdf, reshuff, out=recover)
"""
)
_add_torch_npu_docstr(
"npu_ifmr",
"""
torch_npu.npu_ifmr(Tensor data, Tensor data_min, Tensor data_max, Tensor cumsum, float min_percentile, float max_percentile, float search_start, float search_end, float search_step, bool with_offset) -> (Tensor, Tensor)
功能描述
使用“begin,end,strides”数组对ifmr结果进行计数。
参数说明
data (Tensor) - 特征图张量。
data_min (Tensor) - 特征图最小值的张量。
data_max (Tensor) - 特征图最大值的张量。
cumsum (Tensor) - cumsum bin数据张量。
min_percentile (Float) - 最小初始化百分位数。
max_percentile (Float) - 最大初始化百分位数。
search_start (Float) - 搜索起点。
search_end (Float) - 搜索终点。
search_step (Float) - 搜索步长。
with_offset (Bool) - 是否使用offset。
输出说明
scale (Tensor) - 最优尺度。
offset (Tensor) - 最优offset。
示例
>>> import torch
>>> import torch_npu
>>> torch.npu.set_compile_mode(jit_compile=True)
>>> input = torch.rand((2,2,3,4),dtype=torch.float32).npu()
>>> input
tensor([[[[0.4508, 0.6513, 0.4734, 0.1924],
[0.0402, 0.5502, 0.0694, 0.9032],
[0.4844, 0.5361, 0.9369, 0.7874]],
[[0.5157, 0.1863, 0.4574, 0.8033],
[0.5986, 0.8090, 0.7605, 0.8252],
[0.4264, 0.8952, 0.2279, 0.9746]]],
[[[0.0803, 0.7114, 0.8773, 0.2341],
[0.6497, 0.0423, 0.8407, 0.9515],
[0.1821, 0.5931, 0.7160, 0.4968]],
[[0.7977, 0.0899, 0.9572, 0.0146],
[0.2804, 0.8569, 0.2292, 0.1118],
[0.5747, 0.4064, 0.8370, 0.1611]]]], device='npu:0')
>>> min_value = torch.min(input)
>>> min_value
tensor(0.0146, device='npu:0')
>>> max_value = torch.max(input)
>>> max_value
tensor(0.9746, device='npu:0')
>>> hist = torch.histc(input.to('cpu'), bins=128, min=min_value.to('cpu'), max=max_value.to('cpu'))
>>> hist
tensor([1., 0., 0., 2., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 2., 1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1.,
0., 0., 1., 0., 0., 2., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.,
0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 2., 0., 0.,
1., 1., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1.,
0., 1.])
>>> cdf = torch.cumsum(hist,dim=0).int().npu()
>>> cdf
tensor([ 1, 1, 1, 3, 3, 3, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7,
7, 8, 8, 8, 10, 11, 11, 11, 11, 11, 13, 14, 14, 14, 14, 14, 14, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16,
17, 17, 17, 17, 18, 19, 19, 20, 21, 21, 22, 22, 23, 23, 23, 24, 24, 25,
25, 25, 26, 26, 26, 28, 28, 28, 28, 28, 28, 28, 30, 30, 30, 30, 30, 30,
30, 30, 31, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 34, 35, 37, 37, 37,
38, 39, 40, 40, 41, 41, 41, 42, 42, 43, 44, 44, 44, 44, 45, 45, 46, 47,
47, 48], device='npu:0', dtype=torch.int32)
>>> scale, offset = torch_npu.npu_ifmr(input, min_value, max_value, cdf, min_percentile=0.999999, max_percentile=0.999999, search_start=0.7, search_end=1.3, search_step=0.01, with_offset=False)
>>> scale
tensor(0.0080, device='npu:0')
>>> offset
tensor(0., device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_indexing",
"""
torch_npu.npu_indexing(self, begin, end, strides, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0) -> Tensor
功能描述
使用“begin,end,strides”数组对index结果进行计数。
参数说明
self (Tensor) - 输入张量。
begin (ListInt) - 待选择的第一个值的index。
end (ListInt) - 待选择的最后一个值的index。
strides (ListInt) - index增量。
begin_mask (Int,默认值为0) - 位掩码(bitmask),其中位“i”为“1”意味着忽略开始值,尽可能使用最大间隔。
end_mask (Int,默认值为0) - 类似于“begin_mask”。
ellipsis_mask (Int,默认值为0) - 位掩码,其中位“i”为“1”意味着第“i”个位置实际上是省略号。
new_axis_mask (Int,默认值为0) - 位掩码,其中位“i”为“1”意味着在第“i”位创建新的1D shape。
shrink_axis_mask (Int,默认值为0) - 位掩码,其中位“i”意味着第“i”位应缩小维数。
示例
>>> input = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]], dtype=torch.int32).to("npu")
>>> input
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]], device='npu:0', dtype=torch.int32)
>>> output = torch_npu.npu_indexing(input, [0, 0], [2, 2], [1, 1])
>>> output
tensor([[1, 2],
[5, 6]], device='npu:0', dtype=torch.int32)
"""
)
_add_torch_npu_docstr(
"npu_iou",
"""
torch_npu.npu_iou(bboxes, gtboxes, mode=0) -> Tensor
torch_npu.npu_ptiou(bboxes, gtboxes, mode=0) -> Tensor
功能描述
根据ground-truth和预测区域计算交并比(IoU)或前景交叉比(IoF)。
参数说明
bboxes (Tensor) - 输入张量。
gtboxes (Tensor) - 输入张量。
mode (Int,默认值为0) - 0为IoU模式,1为IoF模式。
示例
>>> bboxes = torch.tensor([[0, 0, 10, 10],[10, 10, 20, 20],[32, 32, 38, 42]], dtype=torch.float16).to("npu")
>>> gtboxes = torch.tensor([[0, 0, 10, 20],[0, 10, 10, 10],[10, 10, 20, 20]], dtype=torch.float16).to("npu")
>>> output_iou = torch_npu.npu_iou(bboxes, gtboxes, 0)
>>> output_iou
tensor([[0.4985, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.9961, 0.0000]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_layer_norm_eval",
"""
torch_npu.npu_layer_norm_eval(input, normalized_shape, weight=None, bias=None, eps=1e-05) -> Tensor
功能描述
对层归一化结果进行计算。与torch.nn.functional.layer_norm相同, 优化NPU设备实现。
参数说明
input (Tensor) - 输入张量。
normalized_shape (ListInt) - size为预期输入的输入shape。
weight (Tensor, 可选,默认值为None) - gamma张量。
bias (Tensor, 可选默认值为None) - beta张量。
eps (Float,默认值为1e-5) - 为保证数值稳定性添加到分母中的ε值。
示例
>>> input = torch.rand((6, 4), dtype=torch.float32).npu()
>>> input
tensor([[0.1863, 0.3755, 0.1115, 0.7308],
[0.6004, 0.6832, 0.8951, 0.2087],
[0.8548, 0.0176, 0.8498, 0.3703],
[0.5609, 0.0114, 0.5021, 0.1242],
[0.3966, 0.3022, 0.2323, 0.3914],
[0.1554, 0.0149, 0.1718, 0.4972]], device='npu:0')
>>> normalized_shape = input.size()[1:]
>>> normalized_shape
torch.Size([4])
>>> weight = torch.Tensor(*normalized_shape).npu()
>>> weight
tensor([ nan, 6.1223e-41, -8.3159e-20, 9.1834e-41], device='npu:0')
>>> bias = torch.Tensor(*normalized_shape).npu()
>>> bias
tensor([5.6033e-39, 6.1224e-41, 6.1757e-39, 6.1224e-41], device='npu:0')
>>> output = torch_npu.npu_layer_norm_eval(input, normalized_shape, weight, bias, 1e-5)
>>> output
tensor([[ nan, 6.7474e-41, 8.3182e-20, 2.0687e-40],
[ nan, 8.2494e-41, -9.9784e-20, -8.2186e-41],
[ nan, -2.6695e-41, -7.7173e-20, 2.1353e-41],
[ nan, -1.3497e-41, -7.1281e-20, -6.9827e-42],
[ nan, 3.5663e-41, 1.2002e-19, 1.4314e-40],
[ nan, -6.2792e-42, 1.7902e-20, 2.1050e-40]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_attention_update",
"""
接口原型:
npu_attention_update(Tensor[] lse, Tensor[] local_out, int update_type) -> (Tensor, Tensor)
功能描述
将各SP域PA算子的输出的中间结果lse, attention out 两个局部结果更新成全局结果。
计算公式:
lsemax = max(lsei)
lse = Σexp(lsei - lsemax)
lsem = lsemax + log(lse)
O = ΣOi * exp(lsei - lsem)
参数说明
lse:Device侧的TensorList,每个Tensor形状为 (batch * seqLen * headNum),各sp域计算的lse。数据类型支持 FLOAT,格式支持 ND。
local_out:Device侧的TensorList,TensorList长度为sp,每个Tensor形状为 (batch * seqLen * headNum, head_dim),各sp域计算的output。数据类型支持 FLOAT,FLOAT16,BFLOAT16,格式支持 ND。
update_type:int64_t 类型,指定执行的操作类型。
输出说明
out:Device侧的Tensor,形状为 (batch * seqLen * headNum, head_dim),数据类型为 FLOAT,FLOAT16,BFLOAT16,格式为 ND。
lse_out: Device侧的Tensor,形状为 (batch * seqLen * headNum),数据类型为 FLOAT格式为 ND。update_type=1时有效。
支持的型号
----------------
Atlas A3训练系列产品/Atlas A3推理系列产品
Atlas A2训练系列产品/Atlas 800I A2推理产品/A200I A2 Box异构组件
调用示例
----------------
import torch
import torch_npu
dtype = torch.float32
N = 4
head_dim = 32
lse = [
torch.randn(N, dtype=dtype, device='npu'), # 1D: [N]
torch.randn(N, dtype=dtype, device='npu'), # 1D: [N]
]
local_out = [
torch.randn(N, head_dim, dtype=dtype, device='npu'), # 2D: [N, head_dim]
torch.randn(N, head_dim, dtype=dtype, device='npu'), # 2D: [N, head_dim]
]
update_type = 0
out, lse_out = torch_npu.npu_attention_update(lse, local_out, update_type)
"""
)
_add_torch_npu_docstr(
"npu_ring_attention_update",
"""
接口原型:
npu_ring_attention_update(Tensor prev_attn_out, Tensor prev_softmax_max, Tensor prev_softmax_sum, Tensor cur_attn_out, Tensor cur_softmax_max, Tensor cur_softmax_sum, *, Tensor? actual_seq_qlen=None, str input_layout="SBH", str input_softmax_layout="") -> (Tensor, Tensor, Tensor)
功能描述
将两次FlashAttention的输出按照softmax的max和sum进行更新,输出新的attention结果、softmax_max和softmax_sum。
计算公式:
softmax_max = max(prev_softmax_max, cur_softmax_max)
softmax_sum = prev_softmax_sum * exp(prev_softmax_max - softmax_max) + cur_softmax_sum * exp(cur_softmax_max - softmax_max)
attn_out = prev_attn_out * exp(prev_softmax_max - softmax_max) * prev_softmax_sum / softmax_sum
+ cur_attn_out * exp(cur_softmax_max - softmax_max) * cur_softmax_sum / softmax_sum
参数说明
prev_attn_out: Tensor类型, 第一次FlashAttention的输出。数据类型支持FLOAT16、FLOAT、BFLOAT16, 数据格式支持ND。
prev_softmax_max: Tensor类型, 第一次FlashAttention的softmax max结果。数据类型支持FLOAT, 数据格式支持ND。
prev_softmax_sum: Tensor类型, 第一次FlashAttention的softmax sum结果。数据类型支持FLOAT, 数据格式支持ND。
cur_attn_out: Tensor类型, 第二次FlashAttention的输出。数据类型和shape需与prev_attn_out一致。
cur_softmax_max: Tensor类型, 第二次FlashAttention的softmax max结果。数据类型和shape需与prev_softmax_max一致。
cur_softmax_sum: Tensor类型, 第二次FlashAttention的softmax sum结果。数据类型和shape需与prev_softmax_sum一致。
actual_seq_qlen: Tensor类型, 可选参数, TND场景下必选, 表示从0开始累计的query序列长度前缀和。数据类型支持int64, 数据格式支持ND。
input_layout: string类型, 可选参数, attention输入输出排布。支持"SBH"和"TND", 默认值为"SBH"。
input_softmax_layout: string类型, 可选参数, softmax相关输入排布。支持""、"SBH"、"TND", 默认值为""。仅在input_layout为"TND"时生效。
输出说明
attn_out: Tensor类型, 更新后的attention输出。shape和数据类型与prev_attn_out一致。
softmax_max: Tensor类型, 更新后的softmax max。shape与prev_softmax_max一致, 数据类型为FLOAT。
softmax_sum: Tensor类型, 更新后的softmax sum。shape与prev_softmax_sum一致, 数据类型为FLOAT。
支持的型号
----------------
Ascend 950PR/Ascend 950DT
Atlas A3训练系列产品/Atlas A3推理系列产品
Atlas A2训练系列产品/Atlas A2推理系列产品
调用示例
----------------
import torch
import torch_npu
prev_attn_out = torch.randn((4, 2, 32), dtype=torch.float16, device="npu")
cur_attn_out = torch.randn((4, 2, 32), dtype=torch.float16, device="npu")
prev_softmax_max = torch.rand((2, 2, 4, 1), dtype=torch.float32, device="npu").repeat(1, 1, 1, 8)
prev_softmax_sum = torch.rand((2, 2, 4, 1), dtype=torch.float32, device="npu").repeat(1, 1, 1, 8)
cur_softmax_max = torch.rand((2, 2, 4, 1), dtype=torch.float32, device="npu").repeat(1, 1, 1, 8)
cur_softmax_sum = torch.rand((2, 2, 4, 1), dtype=torch.float32, device="npu").repeat(1, 1, 1, 8)
attn_out, softmax_max, softmax_sum = torch_npu.npu_ring_attention_update(
prev_attn_out, prev_softmax_max, prev_softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum)
"""
)
_add_torch_npu_docstr(
"npu_linear",
"""
torch_npu.npu_linear(input, weight, bias=None) -> Tensor
功能描述
将矩阵“a”乘以矩阵“b”,生成“a*b”。
参数说明
input (Tensor) - 2D矩阵张量。数据类型:float32、float16、int32、int8。格式:[ND, NHWC, FRACTAL_NZ]。
weight (Tensor) - 2D矩阵张量。数据类型:float32、float16、int32、int8。格式:[ND, NHWC, FRACTAL_NZ]。
bias (Tensor,可选,默认值为None) - 1D张量。数据类型:float32、float16、int32。格式:[ND, NHWC]。
示例
>>> x=torch.rand(2,16).npu()
>>> w=torch.rand(4,16).npu()
>>> b=torch.rand(4).npu()
>>> output = torch_npu.npu_linear(x, w, b)
>>> output
tensor([[3.6335, 4.3713, 2.4440, 2.0081],
[5.3273, 6.3089, 3.9601, 3.2410]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_lstm",
"""
torch_npu.npu_lstm(x, weight, bias, seqMask, h, c, has_biases, num_layers, dropout, train, bidirectional, batch_first, flag_seq, direction)
功能描述
计算DynamicRNN。
参数说明
x (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
weight (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_ZN_LSTM。
bias (Tensor) - 1D张量。数据类型:float16, float32;格式:ND。
seqMask (Tensor) - 张量。仅支持为FRACTAL_NZ格式的float16和ND格式的int32类型。
h (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
c (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
has_biases (Bool) - 如果值为True,则存在偏差。
num_layers (Int) - 循环层数,目前只支持单层。
dropout (Float) - 如果值为非零,则在除最后一层外的每个LSTM层的输出上引入一个dropout层,丢弃概率等于dropout参数值。目前不支持。
train (Bool,默认值为True) - 标识训练是否在op进行的bool参数。
bidirectional (Bool) - 如果值为True,LSTM为双向。当前不支持。
batch_first (Bool) - 如果值为True,则输入和输出张量将表示为(batch, seq, feature)。当前不支持。
flag_seq (Bool) - 如果值为True,输入为PackSequnce。当前不支持。
direction (Bool) - 如果值为True,则方向为“REDIRECTIONAL”,否则为“UNIDIRECTIONAL”。
输出说明
y (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
output_h (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
output_c (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
i (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
j (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
f (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
o (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
tanhct (Tensor) - 4D张量。数据类型:float16, float32;格式:FRACTAL_NZ。
"""
)
_add_torch_npu_docstr(
"npu_masked_fill_range",
"""
torch_npu.npu_masked_fill_range(self, start, end, value, axis=-1) -> Tensor
功能描述
同轴上被range.boxes屏蔽(masked)的填充张量。自定义屏蔽填充范围算子。
参数说明
self (Tensor) - shape为1D (D,)、2D (N,D)或3D (N,D)的float32/float16/int32/int8 ND张量。
start (Tensor) - 屏蔽填充开始位置。shape为(num,N)的int32 3D张量。
end (Tensor) - 屏蔽填充结束位置。shape为(num,N)的int32 3D张量。
value (Tensor) - 屏蔽填充值。shape为(num,)的float32/float16/int32/int8 2D张量。
axis (Int,默认值为-1) - 带有int32屏蔽填充的轴。
示例
>>> a=torch.rand(4,4).npu()
>>> a
tensor([[0.9419, 0.4919, 0.2874, 0.6560],
[0.6691, 0.6668, 0.0330, 0.1006],
[0.3888, 0.7011, 0.7141, 0.7878],
[0.0366, 0.9738, 0.4689, 0.0979]], device='npu:0')
>>> start = torch.tensor([[0,1,2]], dtype=torch.int32).npu()
>>> end = torch.tensor([[1,2,3]], dtype=torch.int32).npu()
>>> value = torch.tensor([1], dtype=torch.float).npu()
>>> out = torch_npu.npu_masked_fill_range(a, start, end, value, 1)
>>> out
tensor([[1.0000, 0.4919, 0.2874, 0.6560],
[0.6691, 1.0000, 0.0330, 0.1006],
[0.3888, 0.7011, 1.0000, 0.7878],
[0.0366, 0.9738, 0.4689, 0.0979]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_max",
"""
torch_npu.npu_max(self, dim, keepdim=False) -> (Tensor, Tensor)
功能描述
使用dim对最大结果进行计算。类似于torch.max, 优化NPU设备实现。
参数说明
self (Tensor) - 输入张量。
dim (Int) - 待降低维度。
keepdim (Bool,默认值为False) - 输出张量是否保留dim。
输出说明
values (Tensor) - 输入张量中的最大值。
indices (Tensor) - 输入张量中最大值的indices。
示例
>>> input = torch.randn(2, 2, 2, 2, dtype = torch.float32).npu()
>>> input
tensor([[[[-1.8135, 0.2078],
[-0.6678, 0.7846]],
[[ 0.6458, -0.0923],
[-0.2124, -1.9112]]],
[[[-0.5800, -0.4979],
[ 0.2580, 1.1335]],
[[ 0.6669, 0.1876],
[ 0.1160, -0.1061]]]], device='npu:0')
>>> outputs, indices = torch_npu.npu_max(input, 2)
>>> outputs
tensor([[[-0.6678, 0.7846],
[ 0.6458, -0.0923]],
[[ 0.2580, 1.1335],
[ 0.6669, 0.1876]]], device='npu:0')
>>> indices
tensor([[[1, 1],
[0, 0]],
[[1, 1],
[0, 0]]], device='npu:0', dtype=torch.int32)
"""
)
_add_torch_npu_docstr(
"npu_min",
"""
torch_npu.npu_min(self, dim, keepdim=False) -> (Tensor, Tensor)
功能描述
使用dim对最小结果进行计算。类似于torch.min, 优化NPU设备实现。
参数说明
self (Tensor) - 输入张量。
dim (Int) - 待降低维度。
keepdim (Bool) - 输出张量是否保留dim。
输出说明
values (Tensor) - 输入张量中的最小值。
indices (Tensor) - 输入张量中最小值的indices。
示例
>>> import torch
>>> import torch_npu
>>> input = torch.randn(2, 2, 2, 2, dtype = torch.float32).npu()
>>> input
tensor([[[[-0.9909, -0.2369],
[-0.9569, -0.6223]],
[[ 0.1157, -0.3147],
[-0.7761, 0.1344]]],
[[[ 1.6292, 0.5953],
[ 0.6940, -0.6367]],
[[-1.2335, 0.2131],
[ 1.0748, -0.7046]]]], device='npu:0')
>>> outputs, indices = torch_npu.npu_min(input, 2)
>>> outputs
tensor([[[-0.9909, -0.6223],
[-0.7761, -0.3147]],
[[ 0.6940, -0.6367],
[-1.2335, -0.7046]]], device='npu:0')
>>> indices
tensor([[[0, 1],
[1, 0]],
[[1, 1],
[0, 1]]], device='npu:0', dtype=torch.int32)
"""
)
_add_torch_npu_docstr(
"npu_mish",
"""
按元素计算Mish激活函数结果。Mish激活函数定义:mish(x) = x * tanh(softplus(x)),其中softplus(x) = ln(1 + e^x)。
参数解释:
self (Tensor) - 数据类型:float16、float32。
约束条件:
无
示例:
>>> x = torch.rand(10, 30, 10).npu()
>>> y = torch_npu.npu_mish(x)
>>> y.shape
torch.Size([10, 30, 10])
"""
)
_add_torch_npu_docstr(
"npu_multi_head_attention",
"""
torch_npu.npu_multi_head_attention(Tensor query, Tensor key, Tensor value, Tensor query_weight, Tensor key_weight, Tensor value_weight, Tensor attn_mask, Tensor out_proj_weight, Tensor query_bias, Tensor key_bia, Tensor value_bias, Tensor out_proj_bias, Tensor dropout_mask, int attn_head_num, int attn_dim_per_head, int src_len, int tgt_len, float dropout_prob, bool softmax_use_float) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述
实现Transformer模块中的MultiHeadAttention计算逻辑。
参数说明
query: Tensor类型,仅支持float16
key: Tensor类型,仅支持float16
value: Tensor类型,仅支持float16
query_weight: Tensor类型,仅支持float16
key_weight: Tensor类型,仅支持float16
value_weight: Tensor类型,仅支持float16
attn_mask: Tensor类型,仅支持float16
out_proj_weight: Tensor类型,仅支持float16
query_bias: Tensor类型,仅支持float16
key_bias: Tensor类型,仅支持float16
value_bias: Tensor类型,仅支持float16
out_proj _bias: Tensor类型,仅支持float16
dropout_mask: Tensor类型,仅支持float16
attn_head_num: Attention Head numbers, Int型
attn_dim_per_head:Attention dim of a Head , Int型
src_len:source length, Int型
tgt_len:target length, Int型
dropout_prob:dropout keep probability, Float型
softmax_use_float:SoftMax Use Float32 to keep precision, Bool型
输出说明
y: Tensor类型,仅支持float16
dropout_mask: Tensor类型,仅支持float16
query_res: Tensor类型,仅支持float16
key_res: Tensor类型,仅支持float16
value_res: Tensor类型,仅支持float16
attn_scores: Tensor类型,仅支持float16
attn_res: Tensor类型,仅支持float16
context: Tensor类型,仅支持float16
约束说明
Attr attn_head_num:需16对齐
Attr attn_dim_per_head:需16对齐
Attr src_len:需16对齐
tgt_len:需16对齐
示例
import torch
import torch_npu
import numpy as np
batch = 8
attn_head_num = 16
attn_dim_per_head = 64
src_len = 64
tgt_len = 64
dropout_prob = 0.0
softmax_use_float = True
weight_col = attn_head_num * attn_dim_per_head
query = torch.from_numpy(np.random.uniform(-1, 1, (batch * tgt_len, weight_col)).astype("float16")).npu()
key = torch.from_numpy(np.random.uniform(-1, 1, (batch * src_len, weight_col)).astype("float16")).npu()
value = torch.from_numpy(np.random.uniform(-1, 1, (batch * tgt_len, weight_col)).astype("float16")).npu()
query_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
key_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
value_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
out_proj_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
attn_mask = torch.from_numpy(np.random.uniform(-1, 1, (batch, attn_head_num, tgt_len, src_len)).astype("float16")).npu()
query_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
key_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
value_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
out_proj_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
dropout_mask = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
npu_result, npu_dropout_mask, npu_query_res, npu_key_res, npu_value_res, npu_attn_scores, npu_attn_res, npu_context = torch_npu.npu_multi_head_attention (query, key, value, query_weight, key_weight, value_weight, attn_mask, out_proj_weight, query_bias, key_bias, value_bias, out_proj_bias, dropout_mask, attn_head_num, attn_dim_per_head, src_len, tgt_len, dropout_prob, softmax_use_float)
print(npu_result)
tensor([[ 623.5000, 75.5000, 307.0000, ..., 25.3125, -418.7500,
35.9688],
[-254.2500, -165.6250, 176.2500, ..., 87.3750, 78.0000,
65.2500],
[ 233.2500, 207.3750, 324.7500, ..., 38.6250, -264.2500,
153.7500],
...,
[-110.2500, -92.5000, -74.0625, ..., -68.0625, 195.6250,
-157.6250],
[ 300.0000, -184.6250, -6.0039, ..., -15.7969, -299.0000,
-93.1875],
[ -2.5996, 36.8750, 100.0625, ..., 112.7500, 202.0000,
-166.3750]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_nms_rotated",
"""
torch_npu.npu_nms_rotated(dets, scores, iou_threshold, scores_threshold=0, max_output_size=-1, mode=0) -> (Tensor, Tensor)
功能描述
按分数降序选择旋转标注框的子集。
参数说明
dets (Tensor) - shape为[num_boxes, 5]的2D浮点张量
scores (Tensor) - shape为[num_boxes]的1D浮点张量,表示每个框(每行框)对应的一个分数。
iou_threshold (Float) - 表示框与IoU重叠上限阈值的标量。
scores_threshold (Float,默认值为0) - 表示决定何时删除框的分数阈值的标量。
max_output_size (Int,默认值为-1) - 标量整数张量,表示非最大抑制下要选择的最大框数。为-1时即不施加任何约束。
mode (Int,默认值为0) - 指定dets布局类型。如果mode设置为0,则dets的输入值为x、y、w、h和角度。如果mode设置为1,则dets的输入值为x1、y1、x2、y2和角度。
输出说明
selected_index (Tensor) - shape为[M]的1D整数张量,表示从dets张量中选定的index,其中M <= max_output_size。
selected_num (Tensor) - 0D整数张量,表示selected_indices中有效元素的数量。
约束说明
目前不支持mode=1的场景。
示例
>>> dets=torch.randn(100,5).npu()
>>> scores=torch.randn(100).npu()
>>> dets.uniform_(0,100)
>>> scores.uniform_(0,1)
>>> output1, output2 = torch_npu.npu_nms_rotated(dets, scores, 0.2, 0, -1, 1)
>>> output1
tensor([76, 48, 15, 65, 91, 82, 21, 96, 62, 90, 13, 59, 0, 18, 47, 23, 8, 56,
55, 63, 72, 39, 97, 81, 16, 38, 17, 25, 74, 33, 79, 44, 36, 88, 83, 37,
64, 45, 54, 41, 22, 28, 98, 40, 30, 20, 1, 86, 69, 57, 43, 9, 42, 27,
71, 46, 19, 26, 78, 66, 3, 52], device='npu:0', dtype=torch.int32)
>>> output2
tensor([62], device='npu:0', dtype=torch.int32)
"""
)
_add_torch_npu_docstr(
"npu_nms_v4",
"""
torch_npu.npu_nms_v4(boxes, scores, max_output_size, iou_threshold, scores_threshold, pad_to_max_output_size=False) -> (Tensor, Tensor)
功能描述
按分数降序选择标注框的子集。
参数说明
boxes (Tensor) - shape为[num_boxes, 4]的2D浮点张量。
scores (Tensor) - shape为[num_boxes]的1D浮点张量,表示每个框(每行框)对应的一个分数。
max_output_size (Scalar) - 表示non-max suppression下要选择的最大框数的标量。
iou_threshold (Tensor) - 0D浮点张量,表示框与IoU重叠上限的阈值。
scores_threshold (Tensor) - 0D浮点张量,表示决定何时删除框的分数阈值。
pad_to_max_output_size (Bool,默认值为False) - 如果为True,则输出的selected_indices将填充为max_output_size长度。
输出说明
selected_indices (Tensor) - shape为[M]的1D整数张量,表示从boxes张量中选定的index,其中M <= max_output_size。
valid_outputs (Tensor) - 0D整数张量,表示selected_indices中有效元素的数量,有效元素首先出现。
示例
>>> import torch
>>> import torch_npu
>>> boxes=torch.randn(100,4).npu()
>>> scores=torch.randn(100).npu()
>>> boxes.uniform_(0,100)
>>> scores.uniform_(0,1)
>>> max_output_size = 20
>>> iou_threshold = torch.tensor(0.5).npu()
>>> scores_threshold = torch.tensor(0.3).npu()
>>> npu_output = torch_npu.npu_nms_v4(boxes, scores, max_output_size, iou_threshold, scores_threshold)
>>> npu_output
(tensor([57, 65, 25, 45, 43, 12, 52, 91, 23, 78, 53, 11, 24, 62, 22, 67, 9, 94,
54, 92], device='npu:0', dtype=torch.int32), tensor(20, device='npu:0', dtype=torch.int32))
"""
)
_add_torch_npu_docstr(
"npu_nms_with_mask",
"""
torch_npu.npu_nms_with_mask(input, iou_threshold) -> (Tensor, Tensor, Tensor)
功能描述
生成值0或1,用于nms算子确定有效位。
参数说明
input (Tensor) - 输入张量
iou_threshold (Scalar) - 阈值。如果超过此阈值,则值为1,否则值为0。
输出说明
selected_boxes (Tensor) - shape为[N,5]的2D张量,表示filtered box,包括proposal box和相应的置信度分数。
selected_idx (Tensor) - shape为[N]的1D张量,表示输入建议框的index。
selected_mask (Tensor) - shape为[N]的1D张量,判断输出建议框是否有效。
约束说明
输入input的2nd-dim必须等于8。
示例
>>> import torch
>>> import torch_npu
>>> input = torch.tensor([[0.0, 1.0, 2.0, 3.0, 0.6, 0.5, 0.4, 0.3], [6.0, 7.0, 8.0, 9.0, 0.4, 0.5, 0.6, 0.7]], dtype=torch.float16).to("npu")
>>> iou_threshold = 0.5
>>> output1, output2, output3, = torch_npu.npu_nms_with_mask(input, iou_threshold)
>>> output1
tensor([[0.0000, 1.0000, 2.0000, 3.0000, 0.6001],
[6.0000, 7.0000, 8.0000, 9.0000, 0.3999]], device='npu:0', dtype=torch.float16)
>>> output2
tensor([0, 1], device='npu:0', dtype=torch.int32)
>>> output3
tensor([1, 1], device='npu:0', dtype=torch.uint8)
"""
)
_add_torch_npu_docstr(
"npu_normalize_batch",
"""
torch_npu.npu_normalize_batch(self, seq_len, normalize_type=0) -> Tensor
功能描述
执行批量归一化。
参数说明
self (Tensor) - 支持float32数据类型,shape为(n, c, d)。
seq_len (Tensor) - 支持Int32数据类型,shape为(n, ), 表示每批次标准化数据量 。
normalize_type (Int,默认值为0) - 支持 "per_feature"或"all_features"。值为0表示 "per_feature",值为1表示"all_features"。
示例
>>> a=np.random.uniform(1,10,(2,3,6)).astype(np.float32)
>>> b=np.random.uniform(3,6,(2)).astype(np.int32)
>>> x=torch.from_numpy(a).to("npu")
>>> seqlen=torch.from_numpy(b).to("npu")
>>> out = torch_npu.npu_normalize_batch(x, seqlen, 0)
>>> out
tensor([[[ 1.1496, -0.6685, -0.4812, 1.7611, -0.5187, 0.7571],
[ 1.1445, -0.4393, -0.7051, 1.0474, -0.2646, -0.1582],
[ 0.1477, 0.9179, -1.0656, -6.8692, -6.7437, 2.8621]],
[[-0.6880, 0.1337, 1.3623, -0.8081, -1.2291, -0.9410],
[ 0.3070, 0.5489, -1.4858, 0.6300, 0.6428, 0.0433],
[-0.5387, 0.8204, -1.1401, 0.8584, -0.3686, 0.8444]]],
device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_one_hot",
"""
torch_npu.npu_one_hot(input, num_classes=-1, depth=1, on_value=1, off_value=0) -> Tensor
功能描述
返回一个one-hot张量。input中index表示的位置采用on_value值,而其他所有位置采用off_value的值。
参数说明
input (Tensor) - 任何shape的class值。
num_classes (Int,默认值为-1) - 待填充的轴。
depth (Int,默认值为1) - one_hot维度的深度。
on_value (Scalar,默认值为1) - 当indices[j] == i时输出中的填充值。
off_value (Scalar,默认值为0) - 当indices[j] != i时输出中的填充值。
示例
>>> a=torch.IntTensor([5, 3, 2, 1]).npu()
>>> b=torch_npu.npu_one_hot(a, depth=5)
>>> b
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 1., 0., 0.],
[0., 1., 0., 0., 0.]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_pad",
"""
torch_npu.npu_pad(input, paddings) -> Tensor
功能描述
填充张量。
参数说明
input (Tensor) - 输入张量。
paddings (ListInt) - 数据类型:int32、int64。
示例
>>> input = torch.tensor([[20, 20, 10, 10]], dtype=torch.float16).to("npu")
>>> paddings = [1, 1, 1, 1]
>>> output = torch_npu.npu_pad(input, paddings)
>>> output
tensor([[ 0., 0., 0., 0., 0., 0.],
[ 0., 20., 20., 10., 10., 0.],
[ 0., 0., 0., 0., 0., 0.]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_ps_roi_pooling",
"""
torch_npu.npu_ps_roi_pooling(x, rois, spatial_scale, group_size, output_dim) -> Tensor
功能描述
执行Position Sensitive ROI Pooling。
参数说明
x (Tensor) - 描述特征图的NC1HWC0张量。维度C1必须等于(int(output_dim+15)/C0)) group_size。
rois (Tensor) - shape为[batch, 5, rois_num]的张量,用于描述ROI。每个ROI由五个元素组成:“batch_id”、“x1”、“y1”、“x2”和“y2”,其中“batch_id”表示输入特征图的index,“x1”、“y1”、“x2”,和“y2”必须大于或等于“0.0”。
spatial_scale (Float32) - 将输入坐标映射到ROI坐标的缩放系数。
group_size (Int32) - 指定用于编码position-sensitive评分图的组数。该值必须在(0,128)范围内。
output_dim (Int32) - 指定输出通道数。必须大于0。
示例
>>> roi = torch.tensor([[[1], [2], [3], [4], [5]],[[6], [7], [8], [9], [10]]], dtype = torch.float16).npu()
>>> x = torch.tensor([[[[ 1]], [[ 2]], [[ 3]], [[ 4]],[[ 5]], [[ 6]], [[ 7]], [[ 8]]],[[[ 9]], [[10]], [[11]], [[12]],[[13]], [[14]], [[15]], [[16]]]], dtype = torch.float16).npu()
>>> out = torch_npu.npu_ps_roi_pooling(x, roi, 0.5, 2, 2)
>>> out
tensor([[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]],
[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_ptiou",
"""
torch_npu.npu_ptiou(bboxes, gtboxes, mode=0) -> Tensor
功能描述
根据ground-truth和预测区域计算交并比(IoU)或前景交叉比(IoF)。
参数说明
bboxes (Tensor) - 输入张量。
gtboxes (Tensor) - 输入张量。
mode (Int,默认值为0) - 0为IoU模式,1为IoF模式。
示例
>>> bboxes = torch.tensor([[0, 0, 10, 10],
>>> [10, 10, 20, 20],
>>> [32, 32, 38, 42]], dtype=torch.float16).to("npu")
>>> gtboxes = torch.tensor([[0, 0, 10, 20],
>>> [0, 10, 10, 10],
>>> [10, 10, 20, 20]], dtype=torch.float16).to("npu")
>>> output_iou = torch_npu.npu_iou(bboxes, gtboxes, 0)
>>> output_iou
tensor([[0.4985, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.9961, 0.0000]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_random_choice_with_mask",
"""
torch_npu.npu_random_choice_with_mask(x, count=256, seed=0, seed2=0) -> (Tensor, Tensor)
功能描述
混洗非零元素的index。
参数说明
x (Tensor) - 输入张量。
count (Int,默认值为256) - 输出计数。如果值为0,则输出所有非零元素。
seed (Int,默认值为0) - 数据类型:int32,int64。
seed2 (Int,默认值为0) - 数据类型:int32,int64。
输出说明
y (Tensor) - 2D张量, 非零元素的index。
mask (Tensor) - 1D张量, 确定对应index是否有效。
示例
>>> x = torch.tensor([1, 0, 1, 0], dtype=torch.bool).to("npu")
>>> result, mask = torch_npu.npu_random_choice_with_mask(x, 2, 1, 0)
>>> result
tensor([[0], [2]], device='npu:0', dtype=torch.int32)
>>> mask
tensor([True, True], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_reshape",
"""
torch_npu.npu_reshape(self, shape, bool can_refresh=False) -> Tensor
功能描述
reshape张量。仅更改张量shape,其数据不变。
参数说明
self (Tensor) - 输入张量。
shape (ListInt) - 定义输出张量的shape。
can_refresh (Bool,默认值为False) - 是否就地刷新reshape。
约束说明
该运算符不能被aclopExecute API直接调用。
示例
>>> a=torch.rand(2,8).npu()
>>> out=torch_npu.npu_reshape(a,(4,4))
>>> out
tensor([[0.6657, 0.9857, 0.7614, 0.4368],
[0.3761, 0.4397, 0.8609, 0.5544],
[0.7002, 0.3063, 0.9279, 0.5085],
[0.1009, 0.7133, 0.8118, 0.6193]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_rms_norm",
"""
torch_npu.npu_rms_norm(Tensor self, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor)
功能描述
RmsNorm算子是大模型常用的归一化操作,相比LayerNorm算子,其去掉了减去均值的部分。
参数说明
self:Tensor类型,支持float16、bfloat16、float32,输入shape支持2-8维。
gamma:Tensor类型,数据类型需要和self保持一致,输入shape支持2-8维,通常为self的最后一维。
epsilon:float数据类型,用于防止除0错误。
输出说明
共两个输出,格式为: (Tensor, Tensor)
第1个输出为Tensor,计算公式的最终输出y;
第2个输出为Tensor,rms_norm的reverse rms中间结果,用于反向计算。
约束说明
输入数据类型仅支持float16、bfloat16和float32。
示例
import torch
import torch_npu
x = torch.randn(24, 1, 128).npu()
w = torch.randn(128).npu()
out1 = torch_npu.npu_rms_norm(x, w, epsilon=1e-5)[0]
print(out1)
tensor([[[-0.1123, 0.3398, 0.0986, ..., -2.1250, -0.8477, -0.3418]],
[[-0.0591, 0.3184, -0.5000, ..., 1.0312, -1.1719, -0.1621]],
[[-0.1445, 0.3828, -0.3438, ..., -0.9102, -0.5703, 0.0073]],
...,
[[-0.1631, -0.3477, 0.4297, ..., 0.9219, 0.1621, 0.3125]],
[[-0.1387, 0.0815, 0.0967, ..., 1.7109, 0.1455, -0.1406]],
[[ 0.0698, 1.3438, -0.0127, ..., -2.2656, -0.4473, 0.3281]]],
device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_roi_align",
"""
torch_npu.npu_roi_align(features, rois, spatial_scale, pooled_height, pooled_width, sample_num, roi_end_mode) -> Tensor
功能描述
从特征图中获取ROI特征矩阵。自定义FasterRcnn算子。
参数说明
features (Tensor) - 4HD张量
rois (Tensor) - ROI位置,shape为(N, 4)的2D张量。“N”表示ROI的数量,“4”表示ROI所在图像的index,分别为“x0”、“y0”、“x1”和“y1”。
spatial_scale (Float32) - 指定“features”与原始图像的缩放比率。
pooled_height (Int32) - 指定H维度。
pooled_width (Int32) - 指定W维度。
sample_num (Int32,默认值为2) - 指定每次输出的水平和垂直采样频率。若此属性设置为0,则采样频率等于“rois”的向上取整值(一个浮点数)。
roi_end_mode (Int32,默认值为1)
示例
>>> import torch
>>> import torch_npu
>>> x = torch.FloatTensor([[[[1, 2, 3 , 4, 5, 6],
>>> [7, 8, 9, 10, 11, 12],
>>> [13, 14, 15, 16, 17, 18],
>>> [19, 20, 21, 22, 23, 24],
>>> [25, 26, 27, 28, 29, 30],
>>> [31, 32, 33, 34, 35, 36]]]]).npu()
>>> rois = torch.tensor([[0, -2.0, -2.0, 22.0, 22.0]]).npu()
>>> out = torch_npu.npu_roi_align(x, rois, 0.25, 3, 3, 2, 0)
>>> out
tensor([[[[ 4.5000, 6.5000, 8.5000],
[16.5000, 18.5000, 20.5000],
[28.5000, 30.5000, 32.5000]]]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_rotary_mul",
"""
torch_npu.npu_rotary_mul(Tensor input, Tensor r1, Tensor r2, str rotary_mode='half', Tensor? rotate=None) -> Tensor
功能描述
实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。
不传入rotate参数:
half模式:
x1, x2 = torch.chunk(input, 2, -1)
x_new = torch.cat((-x2, x1), dim=-1)
output = r1 * input + r2 * x_new
interleave模式:
x1 = input[..., ::2]
x2 = input[..., 1::2]
x_new = rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ...(d two)", two=2)
output = r1 * input + r2 * x_new
传入rotate时(由开发者生成rotate矩阵,生成方式参考示例):
x_new = input @ rotate
output = r1 * input + r2 * x_new
参数说明
input:必选输入,4维Tensor,数据类型float16, bfloat16, float32
cos: 必选输入,4维Tensor,数据类型float16, bfloat16, float32
sin: 必选输入,4维Tensor,数据类型float16, bfloat16, float32
rotary_mode: 可选属性,数据类型string,用于选择计算模式,支持'half'、'interleave'两种模式。缺省为half。
rotate:可选参数,表示实现input位置变换的等价变化矩阵,输入维度支持2维,数据类型支持`float16`,`bfloat16`,`float32`,默认为空,不支持反向传播。
约束说明
jit_compile=False场景:
half模式:
input: layout支持: BNSD、BSND、SBND; D < 896,且为2的倍数; B, N < 1000; 当需要计算cos/sin的反向梯度时,B*N <= 1024
r1: 数据范围:[-1, 1]; 对应input layout的支持情况:
input为BNSD: 11SD、B1SD、BNSD;
input为BSND: 1S1D、BS1D、BSND;
input为SBND: S11D、SB1D、SBND.
r2: 同r1
half模式下,当输入layout是BNSD,且D为非32Bytes对齐时,建议不使用该融合算子(模型启动脚本中不开启--use-fused-rotary-pos-emb选项),否则可能出现性能下降。
interleave模式:
input: layout支持: BNSD、BSND、SBND; B * N < 1000; D < 896, 且D为2的倍数;
r1: 数据范围:[-1, 1]; 对应input layout的支持情况:
input为BNSD: 11SD;
input为BSND: 1S1D;
input为SBND: S11D.
r2: 同r1
支持Atlas A2训练系列产品,Atlas A3训练系列产品。
jit_compile=True场景:
仅支持rotary_mode为half模式,且r1/r2 layout一般为11SD、1S1D、S11D。
shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。
广播场景下,广播轴的总数据量不能超过1024
支持Atlas训练系列产品,Atlas A2训练系列产品, Atlas 推理系列产品。
rotate推荐使用场景
interleave模式
half模式仅在以下场景时推荐使用:输入矩阵x需要在最后一个维度切分多份时,可以通过构造旋转编码矩阵实现一次调用获得性能收益,以x的layout为BSND需要切分为3份为例:
x切分为3份,$x = [x1|x2|x3]_{(dim=4)} ∈ R^{B×S×N×D}, x1 ∈ R^{B×S×N×D1},x2 ∈ R^{B×S×N×D2},x3 ∈ R^{B×S×N×D3}, 其中D = D1 + D2 + D3$,
那么可以构造一个rotate矩阵,实现调用一次完成x的旋转位置编码计算功能,rotate矩阵构造如下:
$$rotate = diag(rotate1, rotate2, rotate3) = \begin{pmatrix}rotate1&0&0\\0&rotate2&0\\0&0&rotate3\\\end{pmatrix}$$
其中rotate1、rotate2、rotate3分别为x1、x2、x3的旋转编码矩阵,单个旋转矩阵构建参考调用示例。
示例1
>>> x = torch.rand(2, 2, 5, 128).npu()
>>> r1 = torch.rand(1, 2, 1, 128).npu()
>>> r2 = torch.rand(1, 2, 1, 128).npu()
>>> out = torch_npu.npu_rotary_mul(x, r1, r2)
示例2
>>> n = 128
>>> rotate = torch.zeros(n, n, dtype=torch.bfloat16) # interleave
>>> for i in range(0, n, 2):
... rotate[i + 0, i + 1] = 1
... rotate[i + 1, i + 0] = 1
>>> x = torch.rand(2, 2, 5, n).npu()
>>> r1 = torch.rand(1, 1, 5, n).npu()
>>> r2 = torch.rand(1, 1, 5, n).npu()
>>> out = torch_npu.npu_rotary_mul(x, r1, r2, "interleave", rotate)
示例3
>>> n = 128
>>> rotate = torch.zeros(n, n, dtype=torch.bfloat16) # half
>>> half = n // 2
>>> rotate[:half, half:] = torch.eye(half)
>>> rotate[half:, :half] = -torch.eye(half)
>>> x = torch.rand(2, 2, 5, n).npu()
>>> r1 = torch.rand(1, 1, 5, n).npu()
>>> r2 = torch.rand(1, 1, 5, n).npu()
>>> out = torch_npu.npu_rotary_mul(x, r1, r2, "half", rotate)
"""
)
_add_torch_npu_docstr(
"npu_mrope",
"""
torch_npu.npu_mrope(Tensor positions, Tensor query, Tensor key, Tensor cos_sin_cache, int head_size, *, int[]? mrope_section, str? rotary_mode='half', str? cache_mode='default') -> (Tensor, Tensor)
功能描述
实现旋转位置编码。通过传入cos和sin的cache执行旋转位置编码计算。
参数说明
positions (Tensor): 输入索引,用于选取位置编码张量。要求是一个维度为1D或2D的Tensor,shape为 (numTokens)或(3, numTokens),1D维度输入是rope模式,2D维度输入是mrope模式。numTokens表示一个序列中的token数量。支持非连续的Tensor,支持空Tensor。数据类型支持INT32、INT64,数据格式支持ND。
queryIn (Tensor): 要执行旋转位置编码的第一个张量,维度为2D的Tensor,shape为 (numTokens, numQHeads*headSize)。numQHeads表示query的注意力头数量。headSize表示每个注意力头维度大小。支持非连续的Tensor,支持空Tensor。数据类型支持BFLOAT16、FLOAT16、FLOAT32,数据格式支持ND。
keyIn (Tensor): 要执行旋转位置编码的第二个张量,维度为2D的Tensor,shape为 (numTokens, numKHeads*headSize)。numKHeads表示key的注意力头数量。headSize表示每个注意力头维度大小。支持非连续的Tensor,支持空Tensor。数据类型支持BFLOAT16、FLOAT16、FLOAT32,数据格式支持ND。
cosSinCache (Tensor): 参与计算的位置编码张量,要求shape为一个2D的(maxSeqLen, rotaryDim)。maxSeqLen表示模型处理的序列的最大长度。rotaryDim表示旋转位置嵌入的维度大小。支持非连续的Tensor,支持空Tensor。数据类型支持BFLOAT16、FLOAT16、FLOAT32,数据格式支持ND。
headSize(int): 表示每个注意力头维度大小。数据类型int64。
mropeSection(int[]): 可选参数,multimodal section配置,用于整合输入的位置编码张量信息,输入mropeSection属性表示使能mrope模式。默认值为不使能mrope模式(即rope模式)输入为[0, 0, 0]。
rotary_mode(str): 可选参数,旋转模式,'half'表示rotate_half(GPT-NeoX style)计算模式,'interleave'表示rotate_interleaved(GPT-J style)计算模式。默认值为'half'。
cache_mode(str): 可选参数,cos和sin的拼接方式,'default'表示三段式拼接,'interleave'表示交错式拼接。默认值为'default'。
约束说明
queryIn、keyIn、cosSinCache只支持2维shape输入。
当输入是BFLOAT16或FLOAT16时,rotary_dim要求是32的倍数,当输入是FLOAT32时,rotary_dim要求是16的倍数。
当输入tensor positions中值域超过cosSinCache的0维maxSeqLen,会有越界报错。
mrope模式下,mropeSection输入mropeSection[0]+mropeSection[1]+mropeSection[2]==rotary_dim/2
示例
>>> num_tokens = 8
>>> num_q_heads = 32
>>> num_kv_heads = num_q_heads
>>> head_size = 128
>>> max_seq_len = num_tokens
>>> rotary_dim = head_size
>>> positions = torch.arange(num_tokens, dtype=torch.int64).repeat(3, 1).npu()
>>> query = torch.rand(num_tokens, num_q_heads*head_size, dtype=torch.float32).npu()
>>> key = torch.rand(num_tokens, num_kv_heads*head_size, dtype=torch.float32).npu()
>>> cos_sin_cache = torch.rand(max_seq_len, rotary_dim, dtype=torch.float32).npu()
>>> rotary_mode = 'half'
>>> cache_mode = 'default'
>>> mrope_section = [16, 24, 24]
>>> query_out, key_out = torch_npu.npu_mrope(positions, query, key, cos_sin_cache, head_size, mrope_section=mrope_section, rotary_mode=rotary_mode, cache_mode=cache_mode)
"""
)
_add_torch_npu_docstr(
"npu_rotated_box_decode",
"""
torch_npu.npu_rotated_box_decode(anchor_boxes, deltas, weight) -> Tensor
功能描述
旋转标注框编码。
参数说明
anchor_box (Tensor) - shape为(B,5,N)的3D输入张量,表示锚点框。“B”表示批处理大小数量,“N”表示标注框数量,值“5”表示“x0”、“x1”、“y0”、“y1”和“angle”。
deltas (Tensor) - shape为(B,5,N)数据类型为float32 (float16)的3D张量。
weight (Tensor,默认值为[1.0, 1.0, 1.0, 1.0, 1.0]) - “x0”、“x1”、“y0”、“y1”和“angle”的浮点列表。
示例
>>> anchor_boxes = torch.tensor([[[4.137],[33.72],[29.4], [54.06], [41.28]]], dtype=torch.float16).to("npu")
>>> deltas = torch.tensor([[[0.0244], [-1.992], [0.2109], [0.315], [-37.25]]], dtype=torch.float16).to("npu")
>>> weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
>>> out = torch_npu.npu_rotated_box_decode(anchor_boxes, deltas, weight)
>>> out
tensor([[[ 1.7861],
[-10.5781],
[ 33.0000],
[ 17.2969],
[-88.4375]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_rotated_box_encode",
"""
torch_npu.npu_rotated_box_encode(anchor_box, gt_bboxes, weight) -> Tensor
功能描述
旋转标注框编码。
参数说明
anchor_box (Tensor) - shape为(B,5,N)的3D输入张量,表示锚点框。“B”表示批处理大小数量,“N”表示标注框数量,值“5”表示“x0”、“x1”、“y0”、“y1”和“angle”。
gt_bboxes (Tensor) - shape为(B,5,N)数据类型为float32 (float16)的3D张量。
weight (Tensor,默认值为[1.0, 1.0, 1.0, 1.0, 1.0]) - “x0”、“x1”、“y0”、“y1”和“angle”的浮点列表。
示例
>>> anchor_boxes = torch.tensor([[[30.69], [32.6], [45.94], [59.88], [-44.53]]], dtype=torch.float16).to("npu")
>>> gt_bboxes = torch.tensor([[[30.44], [18.72], [33.22], [45.56], [8.5]]], dtype=torch.float16).to("npu")
>>> weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
>>> out = torch_npu.npu_rotated_box_encode(anchor_boxes, gt_bboxes, weight)
>>> out
tensor([[[-0.4253],
[-0.5166],
[-1.7021],
[-0.0162],
[ 1.1328]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_rotated_iou",
"""
torch_npu.npu_rotated_iou(self, query_boxes, trans=False, mode=0, is_cross=True,v_threshold=0.0, e_threshold=0.0) -> Tensor
功能描述
计算旋转框的IoU。
参数说明
self (Tensor) - 梯度增量数据,shape为(B, 5, N)数据类型为float32的3D张量。
query_boxes (Tensor) - 标注框,shape为(B, 5, K) 数据类型为float32的3D张量。
trans (Bool,默认值为False) - 值为True表示“xyxyt”,值为False表示“xywht”。
is_cross (Bool,默认值为True) - 值为True时表示交叉计算,为False时表示一对一计算。
mode (Int,默认值为0) - 计算模式,取值为0或1。0表示IoU,1表示IoF。
v_threshold (Float,可选,默认值为0.0) - provide condition relaxation for intersection calculation.
e_threshold (Float,可选,默认值为0.0) - provide condition relaxation for intersection calculation.
示例
>>> import torch
>>> import torch_npu
>>> import numpy as np
>>> a=np.random.uniform(0,1,(2,2,5)).astype(np.float16)
>>> b=np.random.uniform(0,1,(2,3,5)).astype(np.float16)
>>> box1=torch.from_numpy(a).to("npu")
>>> box2=torch.from_numpy(a).to("npu")
>>> output = torch_npu.npu_rotated_iou(box1, box2, trans=False, mode=0, is_cross=True)
>>> output
tensor([[[3.3325e-01, 1.0162e-01],
[1.0162e-01, 1.0000e+00]],
[[0.0000e+00, 0.0000e+00],
[0.0000e+00, 5.9605e-08]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_rotated_overlaps",
"""
torch_npu.npu_rotated_overlaps(self, query_boxes, trans=False) -> Tensor
功能描述
计算旋转框的重叠面积。
参数说明
self (Tensor) -梯度增量数据,shape为(B, 5, N)数据类型为float32的3D张量。
query_boxes (Tensor) - 标注框,shape为(B, 5, K) 数据类型为float32的3D张量。
trans (Bool,默认值为False) - 值为True表示“xyxyt”,值为False表示“xywht”。
示例
>>> import torch
>>> import torch_npu
>>> import numpy as np
>>> a=np.random.uniform(0,1,(1,3,5)).astype(np.float16)
>>> b=np.random.uniform(0,1,(1,2,5)).astype(np.float16)
>>> box1=torch.from_numpy(a).to("npu")
>>> box2=torch.from_numpy(a).to("npu")
>>> output = torch_npu.npu_rotated_overlaps(box1, box2, trans=False)
>>> output
tensor([[[0.0000, 0.1562, 0.0000],
[0.1562, 0.3713, 0.0611],
[0.0000, 0.0611, 0.0000]]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_scaled_masked_softmax",
"""
torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor
功能描述
计算输入张量x缩放并按照mask遮蔽后的Softmax结果。
参数说明
x(Tensor)- 输入的logits。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
mask(Tensor)- 输入的掩码。支持数据类型:bool。支持格式:[ND,FRACTAL_NZ]。
scale(float,默认值为1.0)- x的缩放系数。
fixed_triu_mask(bool,默认值为False)- 是否使用自动生成的上三角bool掩码。
约束说明
当前输入x的shape,只支持转为[NCHW]格式后,H和W轴长度大于等于32、小于等于4096、且能被32整除的场景。
输入mask的shape,必须能被broadcast成x的shape。
示例
>>> import torch
>>> import torch_npu
>>>
>>> shape = [4, 4, 2048, 2048]
>>> x = torch.rand(shape).npu()
>>> mask = torch.zeros_like(x).bool()
>>> scale = 1.0
>>> fixed_triu_mask = False
>>>
>>> output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
>>> output.shape
torch.size([4, 4, 2048, 2048])
"""
)
_add_torch_npu_docstr(
"npu_scatter",
"""
torch_npu.npu_scatter(self, indices, updates, dim) -> Tensor
功能描述
使用dim对scatter结果进行计数。类似于torch.scatter,优化NPU设备实现。
参数说明
self (Tensor) - 输入张量。
indices (Tensor) - 待scatter的元素index,可以为空,也可以与src有相同的维数。当为空时,操作返回“self unchanged”。
updates (Tensor) - 待scatter的源元素。
dim (Int) - 要进行index的轴。
支持的型号:
Atlas 训练系列产品
示例
>>> input = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu()
>>> input
tensor([[1.6279, 0.1226],
[0.9041, 1.0980]], device='npu:0')
>>> indices = torch.tensor([0, 1],dtype=torch.int32).npu()
>>> indices
tensor([0, 1], device='npu:0', dtype=torch.int32)
>>> updates = torch.tensor([-1.1993, -1.5247]).npu()
>>> updates
tensor([-1.1993, -1.5247], device='npu:0')
>>> dim = 0
>>> output = torch_npu.npu_scatter(input, indices, updates, dim)
>>> output
tensor([[-1.1993, 0.1226],
[ 0.9041, -1.5247]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_sign_bits_pack",
"""
torch_npu.npu_sign_bits_pack(Tensor self, int size) -> Tensor
功能描述
将float类型1位Adam打包为uint8。
参数说明
x(Tensor) - 1D float张量。
size(Int) - reshape时输出张量的第一个维度。
约束说明
Size可被float打包的输出整除。如果x的size可被8整除,则输出的size为(size of x)/8;否则,输出的size为(size of x // 8) + 1。将在小端位置添加-1浮点值以填充可整除性。Atlas 训练系列产品支持float32和float16类型输入。Atlas 推理系列产品(Ascend 310P处理器)支持float32和float16类型输入。Atlas 200/300/500 推理产品仅支持float16类型输入。
示例
>>> a = torch.tensor([5,4,3,2,0,-1,-2, 4,3,2,1,0,-1,-2],dtype=torch.float32).npu()
>>> b = torch_npu.npu_sign_bits_pack(a, 2)
>>> b
tensor([[159],[15]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_sign_bits_unpack",
"""
torch_npu.npu_sign_bits_unpack(x, size, dtype) -> Tensor
功能描述
将uint8类型1位Adam拆包为float。
参数说明
x(Tensor) - 1D uint8张量。
size(Int) - reshape时输出张量的第一个维度。
dtype(torch.dtype) - 值为1设置输出类型为float16,值为0设置输出类型为float32。
约束说明
Size可被uint8s拆包的输出整除。输出大小为(size of x) * 8。
示例
>>> a = torch.tensor([159, 15], dtype=torch.uint8).npu()
>>> b = torch_npu.npu_sign_bits_unpack(a, 2, torch.float32)
>>> b
tensor([[1., 1., 1., 1., 1., -1., -1., 1.],
[1., 1., 1., 1., -1., -1., -1., -1.]], device='npu:0')
(binary form of 159 is ob00001111)
"""
)
_add_torch_npu_docstr(
"npu_silu",
"""
torch_npu.npu_silu(self) -> Tensor
功能描述
计算self的Swish。Swish是一种激活函数,计算公式为' x * sigmoid(x) '。
参数说明
self (Tensor) - 数据类型:float16、float32
示例
>>> a=torch.rand(2,8).npu()
>>> output = torch_npu.npu_silu(a)
>>> output
tensor([[0.4397, 0.7178, 0.5190, 0.2654, 0.2230, 0.2674, 0.6051, 0.3522],
[0.4679, 0.1764, 0.6650, 0.3175, 0.0530, 0.4787, 0.5621, 0.4026]],
device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_slice",
"""
torch_npu.npu_slice(self, offsets, size) -> Tensor
功能描述
从张量中提取切片。
参数说明
self (Tensor) - 输入张量。
offsets (ListInt) - 数据类型:int32,int64。
size (ListInt) - 数据类型:int32,int64。
示例
>>> input = torch.tensor([[1,2,3,4,5], [6,7,8,9,10]], dtype=torch.float16).to("npu")
>>> offsets = [0, 0]
>>> size = [2, 2]
>>> output = torch_npu.npu_slice(input, offsets, size)
>>> output
tensor([[1., 2.],
[6., 7.]], device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_softmax_cross_entropy_with_logits",
"""
torch_npu.npu_softmax_cross_entropy_with_logits(features, labels) -> Tensor
功能描述
计算softmax的交叉熵cost。
参数说明
features (Tensor) - 张量,一个“batch_size * num_classes”矩阵。
labels (Tensor) - 与“features”同类型的张量。一个“batch_size * num_classes”矩阵。
"""
)
_add_torch_npu_docstr(
"npu_sort_v2",
"""
torch_npu.npu_sort_v2(self, dim=-1, descending=False, out=None) -> Tensor
功能描述
沿给定维度,对输入张量元素进行升序排序(不返回索引)。若dim未设置,则选择输入的最后一个维度。如果descending为True,则元素将按值降序排序。
参数说明
self (Tensor) - 输入张量。
dim (Int, 可选,默认值为-1) - 进行排序的维度。
descending (Bool, 可选,默认值为None) - 排序顺序控制(升序或降序)。
约束说明
目前仅支持输入的最后一个维度(dim=-1)。
示例
>>> x = torch.randn(3, 4).npu()
>>> x
tensor([[-0.0067, 1.7790, 0.5031, -1.7217],
[ 1.1685, -1.0486, -0.2938, 1.3241],
[ 0.1880, -2.7447, 1.3976, 0.7380]], device='npu:0')
>>> sorted_x = torch_npu.npu_sort_v2(x)
>>> sorted_x
tensor([[-1.7217, -0.0067, 0.5029, 1.7793],
[-1.0488, -0.2937, 1.1689, 1.3242],
[-2.7441, 0.1880, 0.7378, 1.3975]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_stride_add",
"""
torch_npu.npu_stride_add(x1, x2, offset1, offset2, c1_len) -> Tensor
功能描述
添加两个张量的partial values,格式为NC1HWC0。
参数说明
x1 (Tensor) - 5HD张量。
x2 (Tensor) - 与“x1”类型相同shape相同(C1值除外)的张量。
offset1 (Scalar) - “x1”中C1的offset value。
offset2 (Scalar) - “x2”中C1的offset value。
c1_len (Scalar) - “y”的C1 len。该值必须小于“x1”和“x2”中C1与offset的差值。
示例
>>> a=torch.tensor([[[[[1.]]]]]).npu()
>>> b=torch_npu.npu_stride_add(a, a, 0, 0, 1)
>>> b
tensor([[[[[2.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]]]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_mhc_pre",
"""
接口原型:
torch_npu.npu_mhc_pre(Tensor x, Tensor phi, Tensor alpha, Tensor bias, *, Tensor? gamma=None, float norm_eps=1e-6, float hc_eps=1e-6, int out_flag=0) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述:
通过一系列计算,可得到 MHC (流形约束超连接)架构中 hidden 层对应的投影矩阵 Hres 和 Hpost,以及作为 Atten 或 MLP 层输入的矩阵 Hin。
输入说明:
x: Tensor类型,必选输入,待计算数据,代表网络中 mHC 层的输入数据。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 3 维 [T, n, D] 和 4 维 [B, S, n, D]。
phi: Tensor类型,必选输入,mHC 的参数矩阵。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度为 2 维 [n^2 + 2n, nD]。
alpha: Tensor类型,必选输入,mHC 的缩放参数。支持的数据类型为 FLOAT32,数据维度为 1 维 [3]。
bias: Tensor类型,必选输入,mHC 的 bias 参数。支持的数据类型为 FLOAT32,数据维度为 1 维 [n^2 + 2n]。
gamma: Tensor类型,可选输入,表示进行 RMSNorm 计算时的缩放因子。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度为 2 维 [n, D]。
norm_eps: Float类型,可选输入,RMSNorm 的防除零参数。
hc_eps: Float类型,可选输入,H_pre 经过 sigmoid 运算后的 eps 参数。
out_flag: Int类型,可选输入,表示是否输出中间结果标识,默认值为0(仅输出最终变换结果)。
n:shape 中的 n 常取 4、6、8。
输出说明:
Hin: Tensor类型,必选输出,输出的 h_in,作为 Atten/MLP 层的输入。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 2 维 [T, D] 和 3 维 [B, S, D]。
Hpost: Tensor类型,必选输出,输出的 mHC 的 h_post 变换矩阵。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n] 和 3 维 [B, S, n]。
Hres: Tensor类型,必选输出,输出的 mHC 的 h_res 变换矩阵,未做 sinkhorn 变换。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度可为 3 维 [T, n, n] 和 4 维 [B, S, n, n]。
invRms: Tensor类型,可选输出,RMSNorm 计算得到的 1 / r。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度可为 1 维 [T] 和 2 维 [B, S]。
hMix: Tensor类型,可选输出,x 与 phi 矩阵乘的结果。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n^2 + 2n] 和 3 维 [B, S, n^2 + 2n]。
hPre: Tensor类型,可选输出,做完 sigmoid 计算之后的 h_pre 矩阵。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n] 和 3 维 [B, S, n]。
约束说明:
该接口支持pytorch调用(torch_npu)。
该接口支持图模式。
支持的PyTorch版本:
PyTorch 2.7.1及更高版本
支持的型号:
- 昇腾950 AI处理器
调用示例:
1. 单算子模式调用:
import torch
import torch_npu
T, n, D = 1024, 8, 2560
x = torch.randn(T, n, D, dtype=torch.bfloat16).npu()
phi = torch.randn(n * n + 2 * n, n * D, dtype=torch.float32).npu()
alpha = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).npu()
bias_pre = torch.full((n,), 0.01, dtype=torch.float32)
bias_post = torch.full((n,), 0.01, dtype=torch.float32)
bias_res = torch.full((n, n), 0.01, dtype=torch.float32)
bias = torch.cat([bias_pre, bias_post, bias_res.reshape(-1)], dim=0).npu()
gamma = torch.randn(n, D, dtype=torch.float32).npu()
out_flag = 1
h_in, h_post, h_res, inv_rms, h_mix, h_pre = torch_npu.npu_mhc_pre(
x,
phi,
alpha,
bias,
gamma=gamma,
out_flag=out_flag
)
2. 图模式调用:
import os
import torch
import torch_npu
import torchair
os.environ["ENABLE_ACLNN"] = "false"
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
class MyModule(torch.nn.Module):
def forward(self, x, phi, alpha, bias, gamma, out_flag):
return torch_npu.npu_mhc_pre(
x,
phi,
alpha,
bias,
gamma=gamma,
out_flag=out_flag
)
T, n, D = 256, 8, 2560
x = torch.randn(T, n, D, dtype=torch.bfloat16).npu()
phi = torch.randn(n * n + 2 * n, n * D, dtype=torch.float32).npu()
alpha = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).npu()
bias_pre = torch.full((n,), 0.01, dtype=torch.float32)
bias_post = torch.full((n,), 0.01, dtype=torch.float32)
bias_res = torch.full((n, n), 0.01, dtype=torch.float32)
bias = torch.cat([bias_pre, bias_post, bias_res.reshape(-1)], dim=0).npu()
gamma = torch.randn(n, D, dtype=torch.float32).npu()
out_flag = 0
model = MyModule().npu().eval()
model = torch.compile(model, backend=npu_backend, dynamic=False)
with torch.no_grad():
outputs = model(x, phi, alpha, bias, gamma, out_flag)
torch.npu.synchronize()
"""
)
_add_torch_npu_docstr(
"npu_mhc_sinkhorn",
"""
接口原型:
torch_npu.npu_mhc_sinkhorn(Tensor x, *, float eps=1e-6, SymInt num_iters=20, int out_flag=0) -> (Tensor, Tensor, Tensor)
功能描述
算子功能:通过将残差分支投影矩阵约束为双随机矩阵,在保持多路残差连接表达的能力的基础上,实现恒等映射兼容、范数保证、训练稳定的深度Transformer特征传递。
参数说明:
x: Tensor类型, 数据类型支持float32, 数据格式支持ND, shape是3维(T, n, n)或者4维(B, S, n, n), 其中n仅支持4, 6, 8. 待计算数据,表示mHC层的输入数据。
eps: Scalar类型, 可选参数. 数据类型支持float32, 表示归一化防除零参数,默认值为1e-6.
num_iters: Scalar类型, 可选参数. 数据类型支持int64, 表示迭代次数, 取值范围[1, 100], 默认值为20.
out_flag: Scalar类型, 可选参数. 数据类型支持int64, 表示是否输出中间结果标识, 默认值为0(仅输出最终变换结果).
输出说明:
一个Tensor类型的输出, 代表mhc_sinkhorn. 数据类型和张量形状与输入x保持一致。数据格式支持ND.
约束说明:
该接口支持pytorch调用(torch_npu).
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.7.1
支持的型号:
Atlas A5 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
device = "npu:0"
x_shape = [1, 128, 4 , 4]
x = torch.rand(x_shape, dtype=torch.float32).clamp(min=1e-4)
x_npu = x.npu()
eps = 1e-6
num_iters = 20
out_flag = 0
y = torch_npu.npu_mhc_sinkhorn(x_npu, eps=eps, num_iters=num_iters, out_flag=out_flag)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=False)
config = CompilerConfig()
config.mode="reduce-overhead"
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class MhcSinkhornModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, eps, num_iters, out_flag):
y = torch_npu.npu_mhc_sinkhorn(x, eps=eps, num_iters=num_iters, out_flag=out_flag)
return y
x_shape = [1, 128, 4 , 4]
x = torch.rand(x_shape, device="npu", dtype=torch.float32)
eps = 1e-6
num_iters = 20
out_flag = 0
mhc_sinkhorn_model = MhcSinkhornModel().npu()
mhc_sinkhorn_model = torch.compile(mhc_sinkhorn_model, backend=npu_backend, dynamic=True)
y = mhc_sinkhorn_model(x, eps=eps, num_iters=num_iters, out_flag=out_flag)
"""
)
_add_torch_npu_docstr(
"npu_mhc_sinkhorn_backward",
"""
接口原型:
torch_npu.npu_mhc_sinkhorn_backward(Tensor grad_y, Tensor norm, Tensor sum) -> Tensor
功能描述
算子功能:MhcSinkhornBackward是MhcSinkhorn的反向算子。mHC(Manifold-Constrained Hyper-Connections)架构中的MhcSinkhorn算子对输入矩阵做sinkhorn变换得到双随机矩阵$\mathbf{H}_{\text{res}}$,输出的双随机矩阵的所有元素≥0、每一行之和为1且每一列之和为1 (具有范数保持、组合封闭性和凸组合几何解释三大特性)。
参数说明:
gradOutput: Tensor类型, 必选参数. 数据类型支持float32, 数据格式支持ND, shape是3维(T, n, n)或者4维(B, S, n, n), 其中n仅支持4, 6, 8. Sinkhorn变换输出的H_res的梯度。
normOut: Tensor类型, 必选参数. 数据类型支持float32, 数据格式支持ND, shape是1维, 表示Sinkhorn变换正向计算保存的中间norm结果。
sumOut: Tensor类型, 必选参数. 数据类型支持float32, 数据格式支持ND, shape是1维, 表示Sinkhorn变换正向计算保存的中间sum结果。
输出说明:
一个Tensor类型的输出, Sinkhorn变换的输入的H_res的梯度. 数据类型和张量形状与输入gradOutput保持一致。数据格式支持ND.
约束说明:
该接口支持pytorch调用(torch_npu).
支持的PyTorch版本
PyTorch 2.7.1
支持的型号:
Atlas A5 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
device = "npu:0"
x_shape = [128, 4 , 4]
T, n, _ = x_shape
x = torch.rand(x_shape, dtype=torch.float32, requires_grad=True).clamp(min=1e-4)
x_npu = x.npu()
eps = 1e-6
num_iters = 20
out_flag = 1
# 正向传播
y, norm_out, sum_out = torch_npu.npu_mhc_sinkhorn(x_npu, eps=eps, num_iters=num_iters, out_flag=out_flag)
# 反向传播
grad_y = torch.randn(T, n, n, dtype=torch.float32)
torch.autograd.backward(tensors=[y], grad_tensors=[grad_y.npu()])
"""
)
_add_torch_npu_docstr(
"npu_mhc_pre_backward",
"""
接口原型:
npu_mhc_pre_backward(Tensor x, Tensor phi, Tensor alpha, Tensor grad_h_in, Tensor grad_h_post, Tensor grad_h_res, Tensor inv_rms, Tensor h_mix, Tensor h_pre, Tensor h_post, Tensor? gamma=None, float hc_eps=1e-6, Tensor? grad_x_post=None) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述:
MhcPreBackward算子是MhcPre的反向算子,可选融合post反向grad_x的add操作,首先说明一下正向算子的功能。
正向算子功能:通过一系列计算,可得到 MHC (流形约束超连接)架构中 hidden 层对应的投影矩阵 Hres 和 Hpost,以及作为 Atten 或 MLP 层输入的矩阵 Hin。
输入说明:
x: Tensor类型,必选输入,待计算数据,代表网络中 mHC 层的输入数据。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 3 维 [T, n, D] 和 4 维 [B, S, n, D]。
phi: Tensor类型,必选输入,mHC 的参数矩阵。数据格式为 ND,支持非连续 Tensor,支持的数据类型为 FLOAT32,数据维度为 2 维 [n^2 + 2n, nD]。
alpha: Tensor类型,必选输入,mHC 的缩放参数。支持的数据类型为 FLOAT32,数据维度为 1 维 [3]。
grad_h_in: Tensor类型,必选输入,h_in作为Atten/MLP层的输入,正向输出h_in对应的梯度。数据格式为 ND,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 2 维 [T, D] 和 3 维 [B, S, D]。
grad_h_post: Tensor类型,必选输入,正向输出h_post的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n] 和 3 维 [B, S, n]。
grad_h_res: Tensor类型,必选输入,正向输出h_res的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 3 维 [T, n, n] 和 4 维 [B, S, n, n]。
inv_rms: Tensor类型,必选输入,正向RmsNorm计算的invRms。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 1 维 [T] 和 2 维 [B, S]。
h_mix: Tensor类型,必选输入,正向计算x与phi矩阵乘的结果。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n^2 + 2n] 和 3 维 [B, S, n^2 + 2n]。
h_pre: Tensor类型,必选输入,做完 sigmoid 计算之后的 h_pre 矩阵。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n] 和 3 维 [B, S, n]。
h_post: Tensor类型,必选输入,正向的h_post输出。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度可为 2 维 [T, n] 和 3 维 [B, S, n]。
gamma: Tensor类型,可选输入,RmsNorm的缩放系数。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度为 2 维 [n, D]。
hc_eps: Float类型,可选输入,H_pre 经过 sigmoid 运算后的 eps 参数。
grad_x_post: Tensor类型,可选输入,post反向输出的grad_x,数据格式为 ND,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 3 维 [T, n, D] 和 4 维 [B, S, n, D]。
n:shape 中的 n 常取 4、6、8。
输出说明:
grad_x: Tensor类型,必选输出,x对应的梯度。数据格式为 ND,支持的数据类型为 BFLOAT16 和 FLOAT16,数据维度可为 3 维 [T, n, D] 和 4 维 [B, S, n, D]。
grad_phi: Tensor类型,必选输出,phi对应的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度为 2 维 [n^2 + 2n, nD]。
grad_alpha: Tensor类型,必选输出,alpha对应的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度为 1 维 [3]。
grad_bias: Tensor类型,必选输出,bias对应的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度为 1 维 [n^2 + 2n]。
grad_gamma: Tensor类型,可选输出,gamma对应的梯度。数据格式为 ND,支持的数据类型为 FLOAT32,数据维度为 2 维 [n, D]。
约束说明:
该接口支持pytorch调用(torch_npu)。
该接口支持单算子模式。
支持的PyTorch版本:
PyTorch 2.7.1及更高版本
支持的型号:
- 昇腾950 AI处理器
调用示例:
单算子模式调用:
import torch
import torch_npu
import numpy as np
B=1
S=1024
n=4
D=2560
x = torch.randn(B, S, n, D).bfloat16()
phi = torch.randn(n*n + 2*n, n*D)
alpha = torch.tensor([1.1, 0.9, 1.05])
bias = torch.randn(n*n + 2*n) * 0.1
gamma = torch.randn(n, D)
grad_x_post = torch.randn(B, S, n, D).bfloat16()
x_ = x.detach().clone().requires_grad_(True)
phi_ = phi.detach().clone().requires_grad_(True)
alpha_ = alpha.detach().clone().requires_grad_(True)
bias_ = bias.detach().clone().requires_grad_(True)
dh_in = torch.randn(B, S, D).bfloat16()
dh_post = torch.randn(B, S, n)
dh_res = torch.randn(B, S, n, n)
# 正向传播
h_in, h_post, h_res, inv_rms, h_mix, h_pre = torch_npu.npu_mhc_pre(
x_.npu(), phi_.npu(), alpha_.npu(), bias_.npu(), gamma=gamma.npu(), out_flag=1
)
# 调用npu接口
dx1, dphi1, da1, db1, dgamma1 = torch_npu.npu_mhc_pre_backward(
x.npu(), phi.npu(), alpha.npu(),
dh_in.npu(), dh_post.npu(), dh_res.npu(),
inv_rms.npu(), h_mix.npu(), h_pre.npu(), h_post.npu(), gamma=gamma.npu(), grad_x_post=grad_x_post.npu()
)
"""
)
_add_torch_npu_docstr(
"npu_transpose",
"""
torch_npu.npu_transpose(self, perm, require_contiguous=True) -> Tensor
功能描述
返回原始张量视图,其维度已permute,结果连续。支持FakeTensor模式。
参数说明
self (Tensor) - 输入张量。
perm (ListInt) - 对应维度排列。
require_contiguous(Bool,默认值为True) - 用户是否需要对输入Tensor做转连续。设置为False时,表示不对输入Tensor做转连续。用户明确输入Tensor为连续Tensor或转置Tensor时,才能设置为True。
示例
>>> x = torch.randn(2, 3, 5).npu()
>>> x.shape
torch.Size([2, 3, 5])
>>> x1 = torch_npu.npu_transpose(x, (2, 0, 1))
>>> x1.shape
torch.Size([5, 2, 3])
"""
)
_add_torch_npu_docstr(
"npu_yolo_boxes_encode",
"""
torch_npu.npu_transpose(self, perm, require_contiguous=True) -> Tensor
功能描述
返回原始张量视图,其维度已permute,结果连续。支持FakeTensor模式。
参数说明
self (Tensor) - 输入张量。
perm (ListInt) - 对应维度排列。
require_contiguous(Bool,默认值为True) - 用户是否显式指定npu_contiguous算子适配需要对输入Tensor做转连续。默认为False,低性能模式。用户明确知道输入Tensor为连续Tensor或转置Tensor时,才能设置为True使用高性能模式。
示例
>>> x = torch.randn(2, 3, 5).npu()
>>> x.shape
torch.Size([2, 3, 5])
>>> x1 = torch_npu.npu_transpose(x, (2, 0, 1))
>>> x1.shape
torch.Size([5, 2, 3])
>>> x2 = x.npu_transpose(2, 0, 1)
>>> x2.shape
torch.Size([5, 2, 3])
"""
)
_add_torch_npu_docstr(
"one_",
"""
torch_npu.one_(self) -> Tensor
用1填充self张量。
参数解释:
self (Tensor) - 输入张量。
约束条件:
无
示例:
>>> x = torch.rand(2, 3).npu()
>>> x
tensor([[0.6072, 0.9726, 0.3475],
[0.3717, 0.6135, 0.6788]], device='npu:0')
>>> torch_npu.one_(x)
tensor([[1., 1., 1.],
[1., 1., 1.]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"npu_swiglu",
"""
接口原型:
torch_npu.npu_swiglu(Tensor self, int dim=-1) -> (Tensor)
功能描述:
提供swiglu的激活函数。
公式如下:
outputs = swiglu\(x, dim = -1) = swish(A) * B = A * sigmoid(A) * B
“x”是输入Tensor。
“dim”是切分维度,默认为-1。
“A”和“B”是x沿dim维度切分的Tensor。
参数说明:
“x”:Tensor类型,shape支持1-8维,dtype支持FP32、FP16或BF16类型。
“dim”:Int类型,默认为-1。
输出说明:
输出为Tensor,计算公式的最终输出outputs。
支持的型号:
Atlas A2 训练系列产品
调用示例:
import torch
import torch_npu
input_tensor = torch.randn(2, 32, 6, 6)
output = torch_npu.npu_swiglu(input_tensor.npu(), dim = -1)
"""
)
_add_torch_npu_docstr(
"npu_trans_quant_param",
"""
功能描述:
完成量化计算参数scale数据类型的转换.
接口原型:
torch_npu.npu_trans_quant_param(Tensor scale, Tensor? offset=None, int? round_mode=0) -> Tensor
参数说明:
scale: Tensor类型, 数据类型支持float32, 数据格式支持ND, shape是1维(t,)或者2维(1, n). 其中t=1或n, 其中n与matmul计算中的右矩阵中的n一致.
offset: Tensor类型, 可选参数. 数据类型支持float32, 数据格式支持ND, shape是1维(t,)或者2维(1, n). t=1或n, 其中n与matmul计算中的右矩阵中的n一致.
round_mode: torch.int8类型,用于指定FP32填充到FP19的模式,可选参数。支持的枚举值为0和1。0表示截断填充,1表示R_INT模式。默认为0。
输出说明:
一个Tensor类型的输出, 代表trans_quant_param的计算结果.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
传入的scale或out不能为空.
scale、offset或out的数据类型和数据格式需要在支持的范围之内.
scale、offset的shape需要为1维(t,)或者2维(1, n). 其中t=1或n, 其中n与matmul计算中的右矩阵中的n一致.
当scale的shape为两维(1, n)时, scale和offset的shape需要保持一致, 且输出shape也为(1, n).
支持的PyTorch版本
PyTorch 2.5
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import logging
import os
scale = torch.randn(16, dtype=torch.float32)
offset = torch.randn(16, dtype=torch.float32)
npu_out = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu(), round_mode=0)
图模式调用
图模式下, npu_trans_quant_param计算出的结果tensor为uint64数据类型. 由于torch不支持该数据类型, 需要搭配其他接口使用, 如示例代码中的npu_quant_matmul.
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, offset, bias):
scale_1 = torch_npu.npu_trans_quant_param(scale, offset, round_mode=0)
return torch_npu.npu_quant_matmul(x1, x2, scale_1, offset=offset, bias=bias)
cpu_model = MyModel()
model = cpu_model.npu()
cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8)
cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8)
scale = torch.randn(1, dtype=torch.float32)
offset = torch.randn(1, dtype=torch.float32)
bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32)
model = torch.compile(model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset.npu(), bias.npu())
"""
)
_add_torch_npu_docstr(
"npu_mhc_post",
"""
接口原型:
torch_npu.npu_mhc_post(Tensor x, Tensor h_res, Tensor h_out, Tensor h_post) -> Tensor
功能描述
算子功能: 融合主分支特征h_out与残差分支特征x_l(或residual),通过门控H_post和经sinkhorn投影后的双随机矩阵H_res机制实现信息流动.
计算公式为:
x_{l+1} = (H_{l}^{res})^{T} \times x_l + h_{l}^{out} \otimes H_{t}^{post}
参数说明
x: Tensor类型, 必选输入, 待计算数据,代表网络中mHC层的输入数据. 数据格式为ND,支持非连续Tensor, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[T, n, D]和4维.[B, S, n, D].
h_res: Tensor类型, 必选输入, mHC的h_res变换矩阵,是做完sinkhorn变换之后的双随机矩阵.数据格式为ND,支持非连续Tensor, 支持的数据类型为FLOAT32, 数据维度可为3维[T, n, n]和4维[B, S, n, n].
h_out: Tensor类型, 必选输入,Attn/MLP层输出.数据格式为ND,支持非连续Tensor, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[B, S, D]和2维[T, D].
h_post: Tensor类型, 必选输入,mHC的h_post变换矩阵.数据格式为ND,支持非连续Tensor, 支持的数据类型为FLOAT32, 数据维度可为3维[B, S, n]和2维[T, n].
n:shape中n常取4,6,8.
输出说明
y: Tensor类型, 必选输入, 网络中mHC层的输出数据. 数据格式为ND,支持非连续Tensor, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[T, n, D]和4维.[B, S, n, D].
约束说明
该接口支持pytorch调用(torch_npu).
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.7.1
支持的型号
-昇腾950 AI处理器
调用示例
单算子模式调用
import torch
import torch_npu
device = "npu:0"
x_shape = [1,1,4,512]
h_res_shape = [1,1,4,4]
h_out_shape = [1,1,512]
h_post_shape = [1,1,2]
x = (torch.rand(x_shape, dtype=torch.bfloat16)).clamp(min=1e-4)
h_res = (torch.rand(h_res_shape, dtype=torch.float32)).clamp(min=1e-4)
h_out = (torch.rand(h_out_shape, dtype=torch.bfloat16)).clamp(min=1e-4)
h_post = (torch.rand(h_post_shape, dtype=torch.float32)).clamp(min=1e-4)
x_npu = x.npu()
h_res_npu = h_res.npu()
h_out_npu = h_out.npu()
h_post_npu = h_post.npu()
y_npu = torch_npu.npu_mhc_post(x_npu, h_res_npu, h_out_npu, h_post_npu)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=False)
config = CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class MhcPostModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, h_res, h_out, h_post):
y = torch_npu.npu_mhc_post(x, h_res, h_out, h_post)
return y
x_shape = [1,1,4,512]
h_res_shape = [1,1,4,4]
h_out_shape = [1,1,512]
h_post_shape = [1,1,4]
x = torch.rand(x_shape, device='npu', dtype=torch.float16)
h_res = torch.rand(h_res_shape, device='npu', dtype=torch.float32)
h_out = torch.rand(h_out_shape, device='npu', dtype=torch.float16)
h_post = torch.rand(h_post_shape, device='npu', dtype=torch.float32)
mhc_post_model = MhcPostModel().npu()
mhc_post_model = torch.compile(mhc_post_model, backend=npu_backend, dynamic=True)
y = mhc_post_model(x, h_res, h_out, h_post)
"""
)
_add_torch_npu_docstr(
"npu_mhc_post_backward",
"""
接口原型:
torch_npu.npu_mhc_post_backward(Tensor grad_output, Tensor x, Tensor h_res, Tensor h_out, Tensor h_post) -> (Tensor, Tensor, Tensor, Tensor)
功能描述
算子功能: MhcPostBackward是MhcPost的反向算子. 融合主分支特征h_out与残差分支特征x_l(或residual),通过门控h_post和经sinkhorn投影后的双随机矩阵h_res机制实现信息流动.
参数说明
grad_output: Tensor类型, 必选输入, 表示前向输出的梯度张量. 数据格式为ND, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[T, n, D]和4维[B, S, n, D].
x: Tensor类型, 必选输入, 待计算数据, 代表网络中mHC层的输入数据. 数据格式为ND, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[T, n, D]和4维[B, S, n, D].
h_res: Tensor类型, 必选输入, mHC的h_res变换矩阵, 是做完sinkhorn变换之后的双随机矩阵. 数据格式为ND, 支持的数据类型为FLOAT32, 数据维度可为3维[T, n, n]和4维[B, S, n, n].
h_out: Tensor类型, 必选输入, Attn/MLP层输出. 数据格式为ND, 支持的数据类型为BFLOAT16和FLOAT16, 数据维度可为3维[B, S, D]和2维[T, D].
h_post: Tensor类型, 必选输入, mHC的h_post变换矩阵. 数据格式为ND, 支持的数据类型为FLOAT32, 数据维度可为3维[B, S, n]和2维[T, n].
n: shape中n仅支持4, 6, 8.
输出说明
grad_x: Tensor类型, 对前向输入x的梯度, shape和数据类型与x一致, 数据格式为ND.
grad_h_res: Tensor类型, 对前向输入h_res的梯度, shape和数据类型与h_res一致, 数据格式为ND.
grad_h_out: Tensor类型, 对前向输入h_out的梯度, shape和数据类型与h_out一致, 数据格式为ND.
grad_h_post: Tensor类型, 对前向输入h_post的梯度, shape和数据类型与h_post一致, 数据格式为ND.
约束说明
该接口支持pytorch调用(torch_npu).
支持的PyTorch版本
PyTorch 2.7.1
支持的型号
-昇腾950 AI处理器
调用示例
单算子模式调用
import torch
import torch_npu
grad_output_shape = (1, 4, 1024)
x_shape = (1, 4, 1024)
h_res_shape = (1, 4, 4)
h_out_shape = (1, 1024)
h_post_shape = (1, 4)
grad_output = (torch.rand(grad_output_shape, dtype=torch.bfloat16)).clamp(min=1e-4)
x = (torch.rand(x_shape, dtype=torch.bfloat16)).clamp(min=1e-4)
h_res = (torch.rand(h_res_shape, dtype=torch.float32)).clamp(min=1e-4)
h_out = (torch.rand(h_out_shape, dtype=torch.bfloat16)).clamp(min=1e-4)
h_post = (torch.rand(h_post_shape, dtype=torch.float32)).clamp(min=1e-4)
grad_output_npu = grad_output.npu()
x_npu = x.npu()
h_res_npu = h_res.npu()
h_out_npu = h_out.npu()
h_post_npu = h_post.npu()
grad_x, grad_h_res, grad_h_out, grad_h_post = torch_npu.npu_mhc_post_backward(grad_output_npu, x_npu, h_res_npu, h_out_npu, h_post_npu)
"""
)
_add_torch_npu_docstr(
"npu_dynamic_quant",
"""
功能描述:
算子功能: 对输入的张量进行per-token对称动态量化.
如果是MoE(Mixture of Experts, 混合专家模型)场景, 会引入group_index, smooth_scales中包含多组smooth向量, 按group_index中的数值作用到x的不同行上. 具体的, 假如x包含m个token, smooth_scales有n行, smooth_scales[0]会作用到x[0:group_index[0]]上, smooth_scales[i]会作用到x[group_index[i-1]: group_index[i]]上, i=1, 2, ..., n-1.
计算公式:
如果smooth_scales不存在:
scale=rowMax(abs(x))/DTYPE_MAX
y=round(x/scale)
如果smooth_scales存在:
scale=rowMax(abs(x×smooth_scales))/DTYPE_MAX
y=round(x×smooth_scales/scale)
owMax表示求一行的最大值, DTYPE_MAX表示常量, 是y输出对应的数据类型的最大值.
接口原型:
torch_npu.npu_dynamic_quant(Tensor x, *, Tensor? smooth_scales=None, Tensor? group_index=None, int? dst_type=None) ->(Tensor, Tensor)
参数说明:
x: Tensor类型, 需要进行量化的源数据张量, 必选输入, 数据类型支持torch.float16、torch.bfloat16, 数据格式支持ND, 支持非连续的Tensor. 输入x的维度必须大于1. 进行int4量化时, 要求x形状的最后一维是8的整数倍.
smooth_scales: Tensor类型, 对x进行scales的张量, 可选输入, 数据类型支持torch.float16、torch.bfloat16, 数据格式支持ND, 支持非连续的Tensor. shape必须是1维, 和x的最后一维相等.
单算子模式: smooth_scales的dtype必须和x保持一致.
group_index: Tensor类型, 对smooth_scales进行分组的下标, 可选输入, 仅在MoE场景下生效. 数据类型支持int32, 数据格式支持ND, 支持非连续的Tensor.
dst_type: int类型, 指定量化输出的类型, 可选输入, 传None时当做torch.int8处理.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持取值torch.int8、torch.quint4x2.
Atlas A3 训练系列产品: 支持取值torch.int8、torch.quint4x2.
输出说明:
y: 量化后的输出Tensor, 数据类型由dst_type指定. 当dst_type是torch.quint4x2时, y的数据类型为int32, 形状最后一维为x最后一维除以8, 其余维度与x一致, 每个int32元素包含8个int4结果. 其他场景下y形状与输入x一致, 数据类型由dst_type指定.
scale: Tensor类型, 非对称动态量化过程中计算出的缩放系数, 数据类型为float32, 形状为x的形状剔除最后一维.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
该接口仅在如下产品支持MoE场景.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
使用smooth_scales时:
若不使用group_index, smooth_scales必须是一维Tensor, 元素数量与x的最后一维大小一致.
若使用group_index, smooth_scales必须是二维Tensor, 第二维元素数量与x的最后一维大小一致, group_index必须是一维数组, 元素数量与smooth_scales第一维一致. group_index中的元素必须是单调递增的, 其最后一个元素的值, 应等于x的元素数量除以x的最后一个维度.
支持的PyTorch版本
PyTorch 2.5
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
只有一个输入x
import torch
import torch_npu
x = torch.rand((3, 3), dtype = torch.float16).to("npu")
output, scale = torch_npu.npu_dynamic_quant(x)
print(output)
print(scale)
使用smooth_scales输入
import torch
import torch_npu
x = torch.rand((3, 3), dtype = torch.float16).to("npu")
smooth_scales = torch.rand((3,), dtype = torch.float16).to("npu")
output, scale = torch_npu.npu_dynamic_quant(x, smooth_scales=smooth_scales)
print(output)
print(scale)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class DynamicQuantModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input_tensor, smooth_scales=None, group_index=None, dst_type=None):
out, scale = torch_npu.npu_dynamic_quant(input_tensor, smooth_scales=smooth_scales, group_index=group_index, dst_type=dst_type)
return out, scale
x = torch.randn((2, 4, 6),device='npu',dtype=torch.float16).npu()
smooth_scales = torch.randn((6),device='npu',dtype=torch.float16).npu()
dynamic_quant_model = DynamicQuantModel().npu()
dynamic_quant_model = torch.compile(dynamic_quant_model, backend=npu_backend, dynamic=True)
out, scale = dynamic_quant_model(x, smooth_scales=smooth_scales)
print(out)
print(scale)
"""
)
_add_torch_npu_docstr(
"npu_dynamic_quant_asymmetric",
"""
功能描述:
算子功能: 对输入的张量进行per-token非对称动态量化. 其中输入的最后一个维度对应一个token, 每个token作为一组进行量化.
计算公式: 假设待量化张量为x,
scale=(rowMax(x)-rowMin(x))/(DST_MAX-DST_MIN)
offset=DST_MAX-rowMax(x)/scale
y=round(x/scale+offset)
owMax、rowMin代表按行取最大值、按行取最小值, 此处的“行”对应x最后一个维度的数据, 即一个token.
DST_MAX、DST_MIN分别对应量化后的最大值和最小值, 在进行int8量化时, 二者分别对应+127、-128, 进行int4量化时, 分别对应+7、-8
若使用smooth quant, 会引入smooth_scales输入, 其形状与x最后一个维度大小一致, 在进行量化前, 会先令x乘以smooth_scales, 再按上述公式进行量化
若使用smooth quant, MoE(Mixture of Experts, 混合专家模型)场景下会引入smooth_scales和group_index, 此时smooth_scales中包含多组smooth向量, 按group_index中的数值作用到x的不同行上. 具体的, 假如x包含m个token, smooth_scales有n行, smooth_scales[0]会作用到x[0:group_index[0]]上, smooth_scales[i]会作用到x[group_index[i-1]: group_index[i]]上, i=[1, 2, ..., n-1].
接口原型:
torch_npu.npu_dynamic_quant_asymmetric(Tensor x, *, Tensor? smooth_scales=None, Tensor? group_index=None, ScalarType? dst_type=None) -> (Tensor, Tensor, Tensor)
参数说明:
x: Tensor类型, 需要进行量化的源数据张量, 必选输入, 数据类型支持float16、bfloat16, 数据格式支持ND, 支持非连续的Tensor. 输入x的维度必须大于1. 进行int4量化时, 要求x形状的最后一维是8的整数倍.
smooth_scales: Tensor类型, 对x进行平滑缩放的张量, 可选输入, 数据类型需要与x保持一致, 数据格式支持ND, 支持非连续的Tensor.
group_index: Tensor类型, 在MoE场景下, 对smooth_scales进行分组的下标, 可选输入, 数据类型支持int32, 数据格式支持ND, 支持非连续的Tensor.
dst_type: ScalarType类型, 用于选择进行int8/int4量化, 可选输入, 输入值只能是torch.int8和torch.quint4x2, 默认为int8量化.
输出说明:
y: 量化后的输出Tensor, 在进行int8量化时, y的数据类型为int8, 形状与x一致; 在进行int4量化时, y的数据类型为int32, 形状最后一维为x最后一维除以8, 其余维度与x一致, 每个int32元素包含8个int4结果.
scale: Tensor类型, 非对称动态量化过程中计算出的缩放系数, 数据类型为float32, 形状为x的形状剔除最后一维.
offset: Tensor类型, 非对称动态量化过程中计算出的偏移系数, 数据类型为float32, 形状为x的形状剔除最后一维.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
使用可选输入smooth_scales、group_index、dst_type时, 必须使用关键字传参.
使用smooth_scales时:
若不使用group_index, smooth_scales必须是一维Tensor, 元素数量与x的最后一维大小一致.
若使用group_index, smooth_scales必须是二维Tensor, 第二维元素数量与x的最后一维大小一致, group_index必须是一维数组, 元素数量与smooth_scales第一维一致. group_index中的元素必须是单调递增的, 其最后一个元素的值, 应等于x的元素数量除以x的最后一个维度.
支持的PyTorch版本
PyTorch2.5
PyTorch2.4
PyTorch2.3
PyTorch2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
只有一个输入x, 进行int8量化
import torch
import torch_npu
x = torch.rand((3, 8), dtype=torch.half).npu()
y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x)
print(y, scale, offset)
只有一个输入x, 进行int4量化
import torch
import torch_npu
x = torch.rand((3, 8), dtype=torch.half).npu()
y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, dst_type=torch.quint4x2)
print(y, scale, offset)
使用smooth_scales输入, 非MoE场景(不使用group_index), 进行int8量化
import torch
import torch_npu
x = torch.rand((3, 8), dtype=torch.half).npu()
smooth_scales = torch.rand((8,), dtype=torch.half).npu()
y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, smooth_scales=smooth_scales)
print(y, scale, offset)
使用smooth_scales输入, MoE场景(使用group_index), 进行int8量化
import torch
import torch_npu
x = torch.rand((3, 8), dtype=torch.half).npu()
smooth_scales = torch.rand((2, 8), dtype=torch.half).npu()
group_index = torch.Tensor([1, 3]).to(torch.int32).npu()
y, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(x, smooth_scales=smooth_scales, group_index=group_index)
print(y, scale, offset)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:4')
torch_npu.npu.set_device(device)
class DynamicQuantModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input_tensor, smooth_scales=None, group_index=None, dst_type=None):
out, scale, offset = torch_npu.npu_dynamic_quant_asymmetric(input_tensor, smooth_scales=smooth_scales, group_index=group_index, dst_type=dst_type)
return out, scale, offset
x = torch.randn((2, 4, 6),device='npu',dtype=torch.float16).npu()
smooth_scales = torch.randn((6),device='npu',dtype=torch.float16).npu()
dynamic_quant_model = DynamicQuantModel().npu()
dynamic_quant_model = torch.compile(dynamic_quant_model, backend=npu_backend, dynamic=True)
out, scale, offset = dynamic_quant_model(x, smooth_scales=smooth_scales)
print(out)
print(scale)
print(offset)
"""
)
_add_torch_npu_docstr(
"npu_quant_matmul_reduce_sum",
"""
功能描述:完成量化的分组矩阵计算,然后所有组的矩阵计算结果相加后输出
计算公式:
out = torch.zeros(m, n)
for i in range(batch):
out += (x1[i, ...] @ x2[i, ...]) * x1Scale[i, :, None] * x2Scale[None, :]
函数原型:
npu_quant_matmul_reduce_sum(x1, x2, *, x1_scale=None, x2_scale=None) -> Tensor
参数说明:
- x1: Tensor类型,必选参数,对应公式中的x1。数据类型支持`int8`,数据格式支持ND,shape支持3维,形状为(batch, m, k)。
- x2: Tensor类型,必选参数,对应公式中的x2。数据类型支持`int8`,数据格式支持NZ,shape支持3维,形状为(batch, k, n)。
- x1_scale: Tensor类型,必选关键字参数。对应公式中的x1Scale。数据类型支持`float32`,数据格式支持ND,shape支持2维,形状为(batch, m)。
- x2_scale: Tensor类型,必选关键字参数。数据类型支持`bfloat16`,数据格式支持ND,shape支持1维,形状为(n,)。
输出说明:
out: Tensor类型,算子的计算结果。输出的数据类型为`bfloat16`,数据格式为ND,shape为2维,形状为(m, n)。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持静态图模式。
- 传入的x1、x2、x1_scale、x2_scale不能是空。
- 输入和输出支持以下数据类型组合:
| x1 | w2 | x1Scale | x2Scale | out |
|------|------|---------|----------|----------|
| int8 | int8 | float32 | bfloat16 | bfloat16 |
支持的PyTorch版本:
PyTorch2.1及以上
支持的型号:
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
- Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
- 单算子调用
import torch
import torch_npu
b,m,k,n = (2,3,4,5)
x1 = torch.ones((b, m, k), dtype=torch.int8).npu()
x2_nd = torch.ones((b, k, n), dtype=torch.int8).npu()
x2 = torch_npu.npu_format_cast(x2_nd.contiguous(), 29)
x1_scale = torch.ones((b, m), dtype=torch.float32).npu()
x2_scale = torch.ones((n,), dtype=torch.bfloat16).npu()
y = torch_npu.npu_quant_matmul_reduce_sum(x1, x2, x1_scale=x1_scale, x2_scale=x2_scale)
- 图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
# "ENABLE_ACLNN"是否使能走aclnn, true: 回调走aclnn, false: 在线编译
os.environ["ENABLE_ACLNN"] = "false"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, pertoken_scale):
return torch_npu.npu_quant_matmul_reduce_sum(x1, x2, x1_scale=pertoken_scale, x2_scale=scale)
cpu_model = MyModel()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
b,m,k,n = (2,3,4,5)
x1 = torch.ones((b, m, k), dtype=torch.int8).npu()
x2_nd = torch.ones((b, k, n), dtype=torch.int8).npu()
x2 = torch_npu.npu_format_cast(x2_nd.contiguous(), 29)
pertoken_scale = torch.ones((b, m), dtype=torch.float32).npu()
scale = torch.ones((n,), dtype=torch.bfloat16).npu()
npu_out = model(x1, x2, scale, pertoken_scale)
print(npu_out)
"""
)
_add_torch_npu_docstr(
"npu_quant_matmul",
"""
功能描述:
完成量化的矩阵乘计算, 最小支持输入维度为2维, 最大支持输入维度为6维.
接口原型:
torch_npu.npu_quant_matmul(Tensor x1, Tensor x2, Tensor scale, *, Tensor? offset=None, Tensor? pertoken_scale=None, Tensor? bias=None, ScalarType? output_dtype=None) -> Tensor
参数说明:
x1: Tensor类型, 数据格式支持ND, shape需要在2-6维范围.
Atlas 推理系列加速卡产品: 数据类型支持int8.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8和int32. 其中int32表示int4类型矩阵乘计算, 每个int32数据存放8个int4数据.
Atlas A3 训练系列产品: 数据类型支持int8和int32. 其中int32表示int4类型矩阵乘计算, 每个int32数据存放8个int4数据.
x2: Tensor类型(weight), 数据格式支持ND, shape需要在2-6维范围.
Atlas 推理系列加速卡产品: 数据类型与x1的数据类型须保持一致.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型与x1的数据类型保持一致.
Atlas A3 训练系列产品: 数据类型与x1的数据类型保持一致.
scale: Tensor类型, 数据格式支持ND, 如需传入int64数据类型的scale, 需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的scale.
Atlas 推理系列加速卡产品: 数据类型支持float32、int64. shape需要是1维(t, ), t=1或n, 其中n与x2的n一致.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、int64、bfloat16. shape需要是1维(t, ), t=1或n, 其中n与x2的n一致.
Atlas A3 训练系列产品: 数据类型支持float32、int64、bfloat16. shape需要是1维(t, ), t=1或n, 其中n与x2的n一致.
offset: Tensor类型, 可选参数. 数据类型支持float32, 数据格式支持ND, shape需要是1维(t,), t=1或n, 其中n与x2的n一致.
当x1数据类型为int8时, 才支持该参数.
pertoken_scale: Tensor类型, 可选参数. 数据类型支持float32, 数据格式支持ND.
Atlas 推理系列加速卡产品: 不支持pertoken_scale.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32. shape需要是1维(m,), 其中m与x1的m一致.
Atlas A3 训练系列产品: 数据类型支持float32. shape需要是1维(m,), 其中m与x1的m一致.
bias: Tensor类型, 可选参数, 数据格式支持ND, shape支持1维(n,)、2维(1, n)或3维(batch, 1, n), n与x2的n一致, 同时batch值需要等于x1和x2 boardcast后推导出的batch值. 当输出是2、4、5、6维情况下, bias的shape必须为1维. 当输出是2维情况下, bias的shape可以为1维或2维. 当输出是3维情况下, bias的shape可以为1维或3维.
Atlas 推理系列加速卡产品: 数据类型支持int32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int32、bfloat16、float16、float32.
Atlas A3 训练系列产品: 数据类型支持int32、bfloat16、float16、float32.
output_dtype: ScalarType类型int类型, 可选参数. 表示输出Tensor的数据类型. 默认值为None, 代表输出Tensor数据类型为int8.
Atlas 推理系列加速卡产品: 支持输入torch.int8、torch.float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持输入torch.int8、torch.float16、torch.bfloat16、torch.int32.
Atlas A3 训练系列产品: 支持输入torch.int8、torch.float16、torch.bfloat16、torch.int32.
输出说明:
result: Tensor类型, 代表量化matmul的计算结果.
如果output_dtype为torch.float16, 输出的数据类型为float16.
如果output_dtype为torch.int8或者None, 输出的数据类型为int8.
如果output_dtype为torch.bfloat16, 输出的数据类型为bfloat16.
如果output_dtype为torch.int32, 输出的数据类型为int32.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
传入的x1、x2、scale不能是空.
x1、x2、bias、scale、offset、pertoken_scale、output_dtype的数据类型和数据格式需要在支持的范围之内.
当x1的数据类型为float8_e4m3fn, x2_dtype为torch_npu.float4_e2m1或torch_npu.float4_e1m2的情况下, x1、x2的k值必须是64的倍数并且大小不能超过65535, x2的n值大小不能超过65535. 其他情况, x1与x2最后一维的shape大小不能超过65535.
目前输出int8或float16且无pertoken_scale情况下, 图模式不支持scale直接传入float32数据类型.
如果在PyTorch图模式中使用本接口, 且环境变量ENABLE_ACLNN=false, 则在调用接口前需要对shape为(n, k//8)的x2数据进行转置, 转置过程应写在图中.
支持将x2转为昇腾亲和的数据排布以提高搬运效率. 需要调用torch_npu.npu_format_cast完成输入x2(weight)为昇腾亲和的数据排布功能.
Atlas 推理系列加速卡产品: 必须先将x2转置后再转亲和format.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 推荐x2不转置直接转亲和format.
Atlas A3 训练系列产品: 推荐x2不转置直接转亲和format.
int4类型计算的额外约束:
当x1、x2的数据类型均为int32, 每个int32类型的数据存放8个int4数据. 输入的int32 shape需要将数据原本int4类型时shape的最后一维缩小8倍. int4数据的shape最后一维应为8的倍数, 例如: 进行(m, k)乘(k, n)的int4类型矩阵乘计算时, 需要输入int32类型、shape为(m, k//8)、(k, n//8)的数据, 其中k与n都应是8的倍数. x1只能接受shape为(m, k//8)且数据排布连续的数据, x2可以接受(k, n[g1] //8)且数据排布连续的数据或shape为(k//8, n)且是由数据连续排布的(n, k//8)转置而来的数据.
数据排布连续是指数组中所有相邻的数, 包括换行时内存地址连续, 使用Tensor.is_contiguous返回值为true则表明tensor数据排布连续.
输入参数间支持的数据类型组合情况如下:
表1 Atlas 推理系列产品
x1:int8, int8
x2:int8, int8
scale:int64/float32, int64/float32
offset:None, float32/None
bias:int32/None, int32/None
pertoken_scale:None, None
output_dtype:float16, int8
表1 (Atlas A2 训练系列产品/Atlas 800I A2 推理产品)(Atlas A3 训练系列产品)
x1:int8, int8, int8, int8, int32, int8
x2:int8, int8, int8, int8, int32, int8
scale:int64/float32, int64/float32, float32/bfloat16, float32, int64/float32, float32/bfloat16
offset:None, float32/None, None, None, None, None
bias:int32/None, int32/None, int32/bfloat16/float32/None, int32/float16/float32/None, int32/None, int32/None
pertoken_scale:None, None, float32/None, float32, None, None
output_dtype:float16, int8, bfloat16, float16, float16, int32
支持的PyTorch版本
PyTorch 2.5
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列加速卡产品
Atlas A3 训练系列产品
调用示例:
单算子调用
int8类型输入场景:
import torch
import torch_npu
import logging
import os
cpu_x1 = torch.randint(-5, 5, (1, 256, 768), dtype=torch.int8)
cpu_x2 = torch.randint(-5, 5, (31, 768, 16), dtype=torch.int8)
scale = torch.randn(16, dtype=torch.float32)
offset = torch.randn(16, dtype=torch.float32)
bias = torch.randint(-5, 5, (31, 1, 16), dtype=torch.int32)
# Method 1: You can directly call npu_quant_matmul
npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), offset=offset.npu(), bias=bias.npu())
# Method 2: You can first call npu_trans_quant_param to convert scale and offset from float32 to int64 when output dtype is not torch.bfloat16 and pertoken_scale is none
scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), offset.npu())
npu_out = torch_npu.npu_quant_matmul(cpu_x1.npu(), cpu_x2.npu(), scale_1, bias=bias.npu())
图模式调用(ND数据格式)
输出float16
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
# "ENABLE_ACLNN"是否使能走aclnn, true: 回调走aclnn, false: 在线编译
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, offset, bias):
return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, output_dtype=torch.float16)
cpu_model = MyModel()
model = cpu_model.npu()
cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8)
cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8)
scale = torch.randn(1, dtype=torch.float32)
# pertoken_scale为空时, 输出fp16必须先调用npu_trans_quant_param, 将scale(offset)从float转为int64.
scale_1 = torch_npu.npu_trans_quant_param(scale.npu(), None)
bias = torch.randint(-1,1, (15, 1, 128), dtype=torch.int32)
# dynamic=True: 动态图模式, dynamic=False: 静态图模式
model = torch.compile(model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale_1, None, bias.npu())
输出bfloat16, 示例代码如下, 仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
return torch_npu.npu_quant_matmul(x1, x2.t(), scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16)
cpu_model = MyModel()
model = cpu_model.npu()
m = 15
k = 11264
n = 6912
bias_flag = True
cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
model = torch.compile(model, backend=npu_backend, dynamic=True)
if bias_flag:
npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, bias.npu(), pertoken_scale.npu())
else:
npu_out = model(cpu_x1.npu(), cpu_x2.npu(), scale.npu(), None, None, pertoken_scale.npu())
图模式调用(高性能数据排布方式)
将x2转置(batch, n, k)后转format, 示例代码如下, 仅支持Atlas 推理系列加速卡产品.
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, offset, bias):
return torch_npu.npu_quant_matmul(x1, x2.transpose(2,1), scale, offset=offset, bias=bias)
cpu_model = MyModel()
model = cpu_model.npu()
cpu_x1 = torch.randint(-1, 1, (15, 1, 512), dtype=torch.int8).npu()
cpu_x2 = torch.randint(-1, 1, (15, 512, 128), dtype=torch.int8).npu()
# Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,n,k) layout
cpu_x2_t_29 = torch_npu.npu_format_cast(cpu_x2.transpose(2,1).contiguous(), 29)
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(1, dtype=torch.float32).npu()
bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
# Process scale from float32 to int64 offline to improve performance
scale_1 = torch_npu.npu_trans_quant_param(scale, offset)
model = torch.compile(model, backend=npu_backend, dynamic=False)
npu_out = model(cpu_x1, cpu_x2_t_29, scale_1, offset, bias)
将x2非转置(batch, k, n)后转format, 示例代码如下, 仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
import numpy as np
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale, offset, bias, pertoken_scale):
return torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias, pertoken_scale=pertoken_scale, output_dtype=torch.bfloat16)
cpu_model = MyModel()
model = cpu_model.npu()
m = 15
k = 11264
n = 6912
bias_flag = True
cpu_x1 = torch.randint(-1, 1, (m, k), dtype=torch.int8)
cpu_x2 = torch.randint(-1, 1, (n, k), dtype=torch.int8)
# Process x2 into a high-bandwidth format(29) offline to improve performance, please ensure that the input is continuous with (batch,k,n) layout
x2_notranspose_29 = torch_npu.npu_format_cast(cpu_x2.npu().transpose(1,0).contiguous(), 29)
scale = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
pertoken_scale = torch.randint(-1,1, (m,), dtype=torch.float32)
bias = torch.randint(-1,1, (n,), dtype=torch.bfloat16)
model = torch.compile(model, backend=npu_backend, dynamic=True)
if bias_flag:
npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, bias.npu(), pertoken_scale.npu())
else:
npu_out = model(cpu_x1.npu(), x2_notranspose_29, scale.npu(), None, None, pertoken_scale.npu())
"""
)
_add_torch_npu_docstr(
"npu_quant_matmul_gelu",
"""
功能描述:
完成量化矩阵乘和GELU激活函数的融合计算, 支持A8W8和A4W4量化. 该接口融合了量化矩阵乘和GELU激活, 减少内存访问, 提升性能.
接口原型:
torch_npu.npu_quant_matmul_gelu(Tensor x1, Tensor x2, Tensor x1_scale, Tensor x2_scale, *, Tensor? bias=None, str? approximate="gelu_erf") -> Tensor
参数说明:
x1: Tensor类型, 输入激活值, 数据格式支持ND, shape需要在2-6维范围.
数据类型支持int8(A8W8量化)、int32(A4W4量化, 每个int32数据存放8个int4数据)、quint4x2(A4W4量化, 直接INT4类型).
x2: Tensor类型(权重), 数据格式支持ND或NZ, shape需要在2-6维范围.
数据类型与x1的数据类型保持一致. 支持昇腾亲和的NZ数据排布格式, 可通过torch_npu.npu_format_cast转换为NZ格式以提升性能(仅A8W8场景).
x1_scale: Tensor类型, x1的量化scale参数, 数据格式支持ND.
数据类型支持float32. shape需要是1维(m,), 其中m与x1的m一致. 采用per-token量化方式, 每个token(行)有一个独立的scale值.
x2_scale: Tensor类型, x2的量化scale参数, 数据格式支持ND.
数据类型支持float32或bfloat16. shape需要是1维(n,)或(1,), 其中n与x2的n一致. 采用per-channel量化方式, 每个输出通道有一个独立的scale值, 或使用per-tensor量化(shape为(1,)).
bias: Tensor类型, 可选参数, 默认值为None, 偏置项, 数据格式支持ND.
数据类型支持int32、float32、bfloat16、float16.
A4W4量化场景下: shape仅支持1维(n,), n与x2的n一致.
A8W8量化场景下: shape支持1维(n,)或3维(batch, 1, n), n与x2的n一致.
approximate: str类型, 可选参数, 默认值为"gelu_erf". 指定GELU激活函数的类型.
支持"gelu_tanh"(GELU的tanh近似版本)和"gelu_erf"(GELU的erf精确版本).
输出说明:
result: Tensor类型, 代表量化矩阵乘融合GELU激活的计算结果.
如果x2_scale的数据类型为float32, 输出的数据类型为float16.
如果x2_scale的数据类型为bfloat16, 输出的数据类型为bfloat16.
输出shape为(batch, m, n), 其中batch根据x1和x2的batch维度广播得到.
约束说明:
该接口支持推理场景下使用.
x1、x2、x1_scale、x2_scale不能为空.
x1、x2的数据类型和数据格式需要在支持的范围之内.
x1、x2最后一维的shape大小不能超过65535.
approximate必须为"gelu_tanh"或"gelu_erf".
对于A4W4量化(INT4/INT32类型输入):
A4W4量化场景支持两种输入类型: quint4x2(直接INT4类型)和int32(打包存储, 每个int32数据存放8个int4数据).
x1和x2的内轴(k轴)必须为偶数.
当x2为int32类型时, x2的shape为(k, n//8), n必须是8的倍数.
当x2为quint4x2类型时, x2的shape为(k, n), n必须是8的倍数.
A4W4量化仅支持ND格式, 不支持NZ格式.
转置信息由算子内部根据tensor的stride自动推导, 无需手动指定.
对于A8W8量化:
支持ND格式和NZ格式.
如果需要使用NZ格式以提升性能, 可以手动调用torch_npu.npu_format_cast完成输入x2(weight)的NZ格式转换.
转置信息由算子内部根据tensor的stride自动推导, 无需手动指定.
支持的PyTorch版本
PyTorch 2.10
PyTorch 2.9
PyTorch 2.8
PyTorch 2.7
PyTorch 2.6
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例:
单算子调用(A8W8量化, ND格式, gelu_tanh激活)
import torch
import torch_npu
# 准备输入数据
m, k, n = 128, 256, 512
x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.float32).abs().npu() * 0.01
# 调用融合算子
output = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale, x2_scale, approximate="gelu_tanh")
print(output.shape) # torch.Size([128, 512])
print(output.dtype) # torch.float16
单算子调用(A8W8量化, ND格式, gelu_erf激活, 带bias)
import torch
import torch_npu
m, k, n = 128, 256, 512
x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.float32).abs().npu() * 0.01
bias = torch.randn(n, dtype=torch.float32).npu() * 0.1
# 使用gelu_erf激活并添加bias
output = torch_npu.npu_quant_matmul_gelu(
x1, x2, x1_scale, x2_scale, bias=bias, approximate="gelu_erf"
)
单算子调用(A8W8量化, NZ格式, gelu_tanh激活)
import torch
import torch_npu
m, k, n = 128, 256, 512
x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
# 将x2转换为NZ格式以提升性能
x2_nz = torch_npu.npu_format_cast(x2.contiguous(), 29) # 29为ACL_FORMAT_FRACTAL_NZ
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.float32).abs().npu() * 0.01
# 自动识别NZ格式并调用对应接口
output = torch_npu.npu_quant_matmul_gelu(x1, x2_nz, x1_scale, x2_scale, approximate="gelu_tanh")
单算子调用(A8W8量化, BF16输出)
import torch
import torch_npu
m, k, n = 64, 128, 256
x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.bfloat16).abs().npu() * 0.01 # BF16 scale
# 输出数据类型由x2_scale的类型决定, 此处输出为bfloat16
output = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale, x2_scale, approximate="gelu_tanh")
print(output.dtype) # torch.bfloat16
单算子调用(A4W4量化)
import torch
import torch_npu
m, k, n = 128, 256, 512
# 生成INT4数据(以INT32格式存储)
x1_fp = torch.randn(m, k, dtype=torch.float32).npu()
x2_fp = torch.randn(k, n, dtype=torch.float32).npu()
# 量化为INT4
scale_tmp = torch.ones(1, dtype=torch.float32).npu()
x1 = torch_npu.npu_quantize(x1_fp, scale_tmp, None, torch.quint4x2, -1, False)
x2 = torch_npu.npu_quantize(x2_fp, scale_tmp, None, torch.quint4x2, -1, False)
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.float32).abs().npu() * 0.01
# A4W4量化仅支持ND格式, 不支持NZ格式
output = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale, x2_scale, approximate="gelu_tanh")
单算子调用(使用默认approximate="gelu_erf")
import torch
import torch_npu
m, k, n = 64, 128, 256
x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
x1_scale = torch.randn(m, dtype=torch.float32).abs().npu() * 0.01
x2_scale = torch.randn(n, dtype=torch.float32).abs().npu() * 0.01
# 不指定approximate参数, 使用默认值"gelu_erf"
output = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale, x2_scale)
print(output.dtype) # torch.float16
"""
)
_add_torch_npu_docstr(
"npu_matmul_compress_dequant",
"""
接口原型:
torch_npu.npu_matmul_compress_dequant(Tensor x1, Tensor x2, Tensor compress_index, Tensor bias, Tensor scale, *, Tensor? offsetW=None, int? offsetX=None) -> Tensor
功能描述:
对压缩存储的权重重建后进行矩阵乘与反量化计算。即使用压缩索引(compress_index)对压缩权重(x2)解压,与输入(x1)做矩阵乘,加上偏置(bias),再按 scale 做反量化,得到 float16 结果。适用于 8x8 块压缩的权重量化推理场景。
参数说明:
x1 (Tensor) - 矩阵乘左输入,2 维,形状 [M, K]。数据类型支持 int8。
x2 (Tensor) - 压缩后的权重重建矩阵,与 compress_index 配合使用,数据类型支持 int8。
compress_index (Tensor) - 压缩索引,用于从压缩格式还原权重。
bias (Tensor) - 偏置,2 维,形状与输出的第二维 N 一致,例如 (1, N)。数据类型支持 int32。
scale (Tensor) - 反量化 scale,2 维,第二维 n 与矩阵乘输出的 N 对应,例如 (1, n)。用于将整型结果反量化为 float16。
offsetW (Tensor, 可选) - 权重量化偏移。当前仅支持传入 None,请勿传入具体张量。
offsetX (int, 可选) - 输入量化偏移。当前仅支持 0,默认值为 0。
输出说明:
返回 Tensor,形状 [M, N],其中 M 为 x1 的第 0 维,N 为 bias 的第 1 维。数据格式为 ND,数据类型为 float16。
约束说明:
x1、scale、bias 必须为 2 维。offsetW 当前仅支持 None;offsetX 当前仅支持 0。依赖 CANN 提供的 aclnnMatmulCompressDequant,请确保 CANN 版本支持该算子。
示例:
>>> import torch
>>> import torch_npu
>>> m, k, n = 16, 256, 128
>>> x1 = torch.ones((m, k), dtype=torch.int8).npu()
>>> x2 = torch.zeros((k, n), dtype=torch.int8).npu() # 实际使用时为压缩权重重建结果
>>> compress_index = torch.zeros(8, dtype=torch.int8).npu() # 压缩索引,shape 与压缩格式一致
>>> bias = torch.zeros((1, n), dtype=torch.int32).npu()
>>> scale = torch.ones((1, n), dtype=torch.float32).npu()
>>> out = torch_npu.npu_matmul_compress_dequant(x1, x2, compress_index, bias, scale)
>>> out.shape
torch.Size([16, 128])
"""
)
_add_torch_npu_docstr(
"npu_weight_quant_batchmatmul",
"""
功能描述:
该接口用于实现矩阵乘计算中weight输入和输出的量化操作, 支持per-tensor、per-channel、per-group多场景量化.
不同产品支持的量化算法不同, 如表 支持的量化场景所示.
表1 支持的量化场景产品型号
量化方式
Atlas 推理系列加速卡产品: per-tensor、per-channel
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: per-tensor、per-channel、per-group
Atlas A3 训练系列产品: per-tensor、per-channel、per-group
接口原型:
torch_npu.npu_weight_quant_batchmatmul(Tensor x, Tensor weight, Tensor antiquant_scale, Tensor? antiquant_offset=None, Tensor? quant_scale=None, Tensor? quant_offset=None, Tensor? bias=None, int antiquant_group_size=0, int inner_precise=0) -> Tensor
参数说明:
x : Tensor类型, 即矩阵乘中的x. 数据格式支持ND, 支持带transpose的非连续的Tensor, 支持输入维度为两维(M, K) .
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16.
weight: Tensor类型, 即矩阵乘中的weight. 支持带transpose的非连续的Tensor, 支持输入维度为两维(K, N), 维度需与x保持一致. 当数据格式为ND时, per-channel场景下为提高性能推荐使用transpose后的weight输入.
Atlas 推理系列加速卡产品: 数据类型支持int8. 数据格式支持ND、FRACTAL_NZ, 其中FRACTAL_NZ格式只在“图模式”有效, 需依赖接口torch_npu.npu_format_cast完成ND到FRACTAL_NZ的转换, 可参考调用示例.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、int32(通过int32承载int4的输入, 可参考7.2.1.74-torch_npu.npu_convert_weight_to_int4pack调用示例). 数据格式支持ND、FRACTAL_NZ.
Atlas A3 训练系列产品: 数据类型支持int8、int32(通过int32承载int4的输入, 可参考7.2.1.74-torch_npu.npu_convert_weight_to_int4pack调用示例). 数据格式支持ND、FRACTAL_NZ.
antiquant_scale: Tensor类型, 反量化的scale, 用于weight矩阵反量化, 数据格式支持ND. 支持带transpose的非连续的Tensor. antiquant_scale支持的shape与量化方式相关:
per_tensor模式: 输入shape为(1,)或(1, 1).
per_channel模式: 输入shape为(1, N)或(N,).
per_group模式: 输入shape为(ceil(K, antiquant_group_size), N).
antiquant_scale支持的dtype如下: Atlas 推理系列加速卡产品: 数据类型支持float16, 其数据类型需与x保持一致. Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int64. 若输入为float16、bfloat16, 其数据类型需与x保持一致. 若输入为int64, x数据类型必须为float16且不带transpose输入, 同时weight数据类型必须为int8、数据格式为ND、带transpose输入, 可参考调用示例. 此时只支持per-channel场景, M范围为[1, 96], 且K和N要求64对齐. Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int64. 若输入为float16、bfloat16, 其数据类型需与x保持一致. 若输入为int64, x数据类型必须为float16且不带transpose输入, 同时weight数据类型必须为int8、数据格式为ND、带transpose输入, 可参考调用示例. 此时只支持per-channel场景, M范围为[1, 96], 且K和N要求64对齐.
antiquant_offset: Tensor类型, 反量化的offset, 用于weight矩阵反量化. 可选参数, 默认值为None, 数据格式支持ND, 支持带transpose的非连续的Tensor, 支持输入维度为两维(1, N)或一维(N, )、(1, ).
Atlas 推理系列加速卡产品: 数据类型支持float16, 其数据类型需与antiquant_scale保持一致.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int32. per-group场景shape要求为(ceil_div(K, antiquant_group_size), N).
若输入为float16、bfloat16, 其数据类型需与antiquant_scale保持一致.
若输入为int32, antiquant_scale的数据类型必须为int64.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int32. per-group场景shape要求为(ceil_div(K, antiquant_group_size), N).
若输入为float16、bfloat16, 其数据类型需与antiquant_scale保持一致.
若输入为int32, antiquant_scale的数据类型必须为int64.
quant_scale: Tensor类型, 量化的scale, 用于输出矩阵的量化, 可选参数, 默认值为None, 仅在weight格式为ND时支持. 数据类型支持float32、int64, 数据格式支持ND, 支持输入维度为两维(1, N)或一维(N, )、(1, ). 当antiquant_scale的数据类型为int64时, 此参数必须为空.
Atlas 推理系列加速卡产品: 暂不支持此参数.
quant_offset: Tensor类型, 量化的offset, 用于输出矩阵的量化, 可选参数, 默认值为None, 仅在weight格式为ND时支持. 数据类型支持float32, 数据格式支持ND, 支持输入维度为两维(1, N)或一维(N, )、(1, ). 当antiquant_scale的数据类型为int64时, 此参数必须为空.
Atlas 推理系列加速卡产品: 暂不支持此参数.
bias: Tensor类型, 即矩阵乘中的bias, 可选参数, 默认值为None, 数据格式支持ND, 不支持非连续的Tensor, 支持输入维度为两维(1, N)或一维(N, )、(1, ).
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float32. 当x数据类型为bfloat16, bias需为float32; 当x数据类型为float16, bias需为float16.
Atlas A3 训练系列产品: 数据类型支持float16、float32. 当x数据类型为bfloat16, bias需为float32; 当x数据类型为float16, bias需为float16.
antiquant_group_size: int类型, 用于控制per-group场景下group大小, 其他量化场景不生效. 可选参数. 默认值为0, per-group场景下要求传入值的范围为[32, K-1]且必须是32的倍数.
Atlas 推理系列加速卡产品: 暂不支持此参数.
inner_precise: int类型, 计算模式选择, 默认为0. 0表示高精度模式, 1表示高性能模式, 可能会影响精度. 当weight以int32类型且以FRACTAL_NZ格式输入, M不大于16的per-group场景下可以设置为1, 提升性能. 其他场景不建议使用高性能模式.
输出说明:
输出为Tensor类型, 代表计算结果. 当输入存在quant_scale时输出数据类型为int8, 当输入不存在quant_scale时输出数据类型和输入x一致.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式. 当输入weight为FRACTAL_NZ格式时暂不支持单算子调用, 只支持图模式调用.
x和weight后两维必须为(M, K)和(K, N)格式, K、N的范围为[1, 65535]; 在x为非转置时, M的范围为[1, 2^31-1], 在x为转置时, M的范围为[1, 65535].
不支持空Tensor输入.
antiquant_scale和antiquant_offset的输入shape要保持一致.
quant_scale和quant_offset的输入shape要保持一致, 且quant_offset不能独立于quant_scale存在.
如需传入int64数据类型的quant_scale, 需要提前调用torch_npu.npu_trans_quant_param接口将数据类型为float32的quant_scale和quant_offset转换为数据类型为int64的quant_scale输入, 可参考调用示例.
当输入weight为FRACTAL_NZ格式且类型为int32时, per-channel场景需满足weight为转置输入; per-group场景需满足x为转置输入, weight为非转置输入, antiquant_group_size为64或128, K为antiquant_group_size对齐, N为64对齐.
不支持输入weight shape为(1, 8)且类型为int4, 同时weight带有transpose的场景, 否则会报错x矩阵和weight矩阵K轴不匹配, 该场景建议走非量化算子获取更高精度和性能.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列加速卡产品
调用示例:
单算子模式调用
weight非transpose+quant_scale场景, 仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
import torch
import torch_npu
# 输入int8+ND
cpu_x = torch.randn((8192, 320),dtype=torch.float16)
cpu_weight = torch.randint(low=-8, high=8, size=(320, 256),dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256),dtype=torch.float16)
cpu_antiquantoffset = torch.randn((1, 256),dtype=torch.float16)
cpu_quantscale = torch.randn((1, 256),dtype=torch.float32)
cpu_quantoffset = torch.randn((1, 256),dtype=torch.float32)
quantscale= torch_npu.npu_trans_quant_param(cpu_quantscale.npu(), cpu_quantoffset.npu())
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(),quantscale.npu())
weight transpose+antiquant_scale场景
import torch
import torch_npu
# 输入int8+ND
cpu_x = torch.randn((96, 320),dtype=torch.float16)
cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8)
cpu_antiquantscale = torch.randn((256,1),dtype=torch.float16)
cpu_antiquantoffset = torch.randint(-128, 127, (256,1), dtype=torch.float16)
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.npu().transpose(-1, -2), cpu_antiquantscale.npu().transpose(-1, -2), cpu_antiquantoffset.npu().transpose(-1, -2))
weight transpose+antiquant_scale场景 , 仅支持如下产品:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列加速卡产品
import torch
import torch_npu
cpu_x = torch.randn((96, 320),dtype=torch.float16)
cpu_weight = torch.randint(low=-8, high=8, size=(256, 320),dtype=torch.int8)
cpu_antiquantscale = torch.randn((256),dtype=torch.float16)
# 构建int64类型的scale参数
antiquant_scale = torch_npu.npu_trans_quant_param(cpu_antiquantscale.to(torch.float32).npu()).reshape(256, 1)
cpu_antiquantoffset = torch.randint(-128, 127, (256, 1), dtype=torch.int32)
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), cpu_weight.transpose(-1,-2).npu(), antiquant_scale.transpose(-1,-2).npu(), cpu_antiquantoffset.transpose(-1,-2).npu())
图模式调用
weight输入为ND格式
# 图模式
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
cpu_x = torch.randn((8192, 320),device='npu',dtype=torch.bfloat16)
cpu_weight = torch.randn((320, 256),device='npu',dtype=torch.int8)
cpu_antiquantscale = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
cpu_antiquantoffset = torch.randn((1, 256),device='npu',dtype=torch.bfloat16)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size):
return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size)
cpu_model = MyModel()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x.npu(), cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
Atlas 推理系列加速卡产品: weight输入为FRACTAL_NZ格式
import torch_npu
import torch
from torchair.configs.compiler_config import CompilerConfig
import torchair as tng
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
class NPUQuantizedLinearA16W8(torch.nn.Module):
def __init__(self,
weight,
antiquant_scale,
antiquant_offset,
quant_offset=None,
quant_scale=None,
bias=None,
transpose_x=False,
transpose_weight=True,
w4=False):
super().__init__()
self.dtype = torch.float16
self.weight = weight.to(torch.int8).npu()
self.transpose_weight = transpose_weight
if self.transpose_weight:
self.weight = torch_npu.npu_format_cast(self.weight.contiguous(), 29)
else:
self.weight = torch_npu.npu_format_cast(self.weight.transpose(0, 1).contiguous(), 29) # n,k ->nz
self.bias = None
self.antiquant_scale = antiquant_scale
self.antiquant_offset = antiquant_offset
self.quant_offset = quant_offset
self.quant_scale = quant_scale
self.transpose_x = transpose_x
def forward(self, x):
x = torch_npu.npu_weight_quant_batchmatmul(x.transpose(0, 1) if self.transpose_x else x,
self.weight.transpose(0, 1),
self.antiquant_scale.transpose(0, 1),
self.antiquant_offset.transpose(0, 1),
self.quant_scale,
self.quant_offset,
self.bias)
return x
m, k, n = 4, 1024, 4096
cpu_x = torch.randn((m, k),dtype=torch.float16)
cpu_weight = torch.randint(1, 10, (k, n),dtype=torch.int8)
cpu_weight = cpu_weight.transpose(-1, -2)
cpu_antiquantscale = torch.randn((1, n),dtype=torch.float16)
cpu_antiquantoffset = torch.randn((1, n),dtype=torch.float16)
cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
model = NPUQuantizedLinearA16W8(cpu_weight.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
model = torch.compile(model, backend=npu_backend, dynamic=True)
out = model(cpu_x.npu())
"""
)
_add_torch_npu_docstr(
"npu_transpose_batchmatmul",
"""
功能描述:
该接口用于实现矩阵乘计算输入和输出的transpose操作。
接口原型:
torch_npu.npu_transpose_batchmatmul(Tensor input, Tensor weight, *, Tensor? bias=None, Tensor? scale=None, int[]? perm_x1=None, int[]? perm_x2=None, int[]? perm_y=None, int? batch_split_factor=1) -> Tensor
参数说明:
- input(Tensor, 计算输入): 必选参数, 一个3D的Device侧Tensor输入,矩阵计算的左矩阵。数据类型支持float32、float16、bfloat16,数据格式支持ND。
- weight(Tensor, 计算输入): 必选参数, 一个3D的Device侧Tensor输入,矩阵计算的右矩阵。数据类型支持float32、float16、bfloat16,数据格式支持ND。
- bias(Tensor, 计算输入): 可选参数, 一个1D的Device侧Tensor输入,矩阵计算的bias参数。数据类型支持float32、float16、bfloat16,数据格式支持ND。
- scale(Tensor, 计算输入): 可选参数, 一个1D的Device侧Tensor输入,矩阵计算量化参数。数据类型支持int64、uint64,数据格式支持ND。
- perm_x1(ListInt, 计算输入): 可选参数, int类型列表,将input在矩阵乘之前排列成[B, M, K]。
- perm_x2(ListInt, 计算输入): 可选参数, int类型列表,将weight在矩阵乘之前排列成[B, K, N]。
- perm_y(ListInt, 计算输入): 可选参数, int类型列表,将y在矩阵乘后重新排列。
- batch_split_factor(Int, 计算输入): 可选参数,声明output_batch的系数,默认是1。
- y(Tensor, 计算输出): 一个3D的Tensor,输出数据类型支持float32、float16、int8、bfloat16。
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 推理系列产品
调用示例:
# 单算子调用
import torch
import torch_npu
M, K, N, Batch = 32, 512, 128, 32
x1 = torch.randn((M, Batch, K), dtype=torch.float16)
x2 = torch.randn((Batch, K, N), dtype=torch.float16)
scale = torch.rand((Batch * N, ), dtype=torch.float32)
scale = torch_npu.npu_trans_quant_param(scale.npu(), round_mode=1)
y = torch_npu.npu_transpose_batchmatmul(x1.npu(), x2.npu(), scale=scale.npu(),
perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
# 图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, scale):
scale = torch_npu.npu_trans_quant_param(scale, round_mode=1)
output = torch_npu.npu_transpose_batchmatmul(x1, x2, scale=scale,
perm_x1=(1, 0, 2), perm_x2=(0, 1, 2),
perm_y=(1, 0, 2))
return output
M, K, N, Batch = 32, 512, 128, 32
x1 = torch.randn((M, Batch, K), dtype=torch.float16)
x2 = torch.randn((Batch, K, N), dtype=torch.float16)
scale = torch.rand((Batch * N, ), dtype=torch.float32)
model = Model().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
y = model(x1.npu(), x2.npu(), scale.npu())
"""
)
_add_torch_npu_docstr(
"npu_transpose_quant_batchmatmul",
"""
"""
)
_add_torch_npu_docstr(
"npu_convert_weight_to_int4pack",
"""
功能描述:
将int32类型的输入tensor打包为int4存放, 每8个int4数据通过一个int32数据承载, 并进行交叠排放.
接口原型:
torch_npu.npu_convert_weight_to_int4pack(Tensor weight, int inner_k_tiles=0) -> Tensor
参数说明:
weight : Tensor类型, 输入的weight, 数据格式支持ND、FRACTAL_NZ, 数据类型支持int32, 不支持非连续的Tensor; 维度支持2维, shape支持(k, n)、 (n, k), 最后一维度需要8个元素对齐, 元素的值需要在int4的表示范围内, 即[-8, 7].
inner_k_tiles: int类型, 用于制定内部打包格式中, 多少个K-tiles被打包在一起, 默认值为0. 预留参数, 暂未使用.
输出说明:
输出为Tensor类型, 代表int4打包后的输出, 数据类型为int32, shape为(k, n/8), (n, k/8), 数据格式支持ND.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3.1
PyTorch 2.0
PyTorch 2.1
PyTorch 2.2
PyTorch 1.11
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
m = 128
k = 64
n = 32
trans_weight = False
cpu_x = torch.randn((m, k), dtype=torch.float16)
if trans_weight:
cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
cpu_antiquantscale = torch.randn((n, 1), dtype=torch.float16)
cpu_antiquantoffset = torch.randn((n, 1), dtype=torch.float16)
else:
cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)
weight_int4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())
if trans_weight:
cpu_weight = cpu_weight.transpose(-1, -2)
weight_int4 = weight_int4.transpose(-1, -2)
cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
npu_out = torch_npu.npu_weight_quant_batchmatmul(cpu_x.npu(), weight_int4.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
图模式调用
import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
m = 16
k = 17
n = 72
trans_weight = False
is_weight_nz = False
cpu_x = torch.randn((m, k),dtype=torch.float16)
if trans_weight:
cpu_weight = torch.randint(low=-8, high=8, size=(n, k) ,dtype=torch.int32)
cpu_antiquantscale = torch.ones((n, 1),dtype=torch.float16)
cpu_antiquantoffset = torch.zeros((n, 1),dtype=torch.float16)
else:
cpu_weight = torch.randint(low=-8, high=8, size=(k, n) ,dtype=torch.int32)
cpu_antiquantscale = torch.ones((1, n),dtype=torch.float16)
cpu_antiquantoffset = torch.zeros((1, n),dtype=torch.float16)
npu_weight = cpu_weight.npu()
if is_weight_nz:
# nd to fractal_nz
npu_weight = torch_npu.npu_format_cast(npu_weight.npu(), 29)
# int32 to int4pack
weight_int4pack = torch_npu.npu_convert_weight_to_int4pack(npu_weight)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, antiquant_scale, antiquant_offset, quant_scale,quant_offset, bias, antiquant_group_size):
if trans_weight:
weight = weight.transpose(-1, -2)
antiquant_scale = antiquant_scale.transpose(-1, -2)
antiquant_offset = antiquant_offset.transpose(-1, -2)
return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale ,quant_offset, bias, antiquant_group_size)
cpu_model = MyModel()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=True, fullgraph=True)
npu_out = model(cpu_x.npu(), weight_int4pack, cpu_antiquantscale.npu(), cpu_antiquantoffset.npu(), None, None, None, 0)
"""
)
_add_torch_npu_docstr(
"npu_grouped_matmul",
"""
功能描述:
算子功能: npu_grouped_matmul是一种对多个矩阵乘法(matmul)操作进行分组计算的高效方法. 该API实现了对多个矩阵乘法操作的批量处理, 通过将具有相同形状或相似形状的矩阵乘法操作组合在一起, 减少内存访问开销和计算资源的浪费, 从而提高计算效率.
计算公式:
非量化场景(公式1):
y_{i}=x_{i}×weight_{i}+bias_{i}
per-channel量化场景 (公式2):
y_{i}=(x_{i}×weight_{i}+bias_{i})×scale_{i}+offset_{i}
per-token量化场景 (公式3):
y_{i}=(x_{i}×weight_{i}+bias_{i})×scale_{i}+pertokenscale_{i}
伪量化场景 (公式4):
y_{i}=x_{i}×(weight_{i}+antiquant_offset_{i})×antiquantscale_{i}+bias_{i}
接口原型:
npu_grouped_matmul(x, weight, *, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, per_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=None, group_list_type=0, act_type=0, output_dtype=None, int[]? tuning_config) -> List[torch.Tensor]
参数说明:
x (List[torch.Tensor]): 输入矩阵列表, 表示矩阵乘法中的左矩阵.
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: torch.float16、torch.float32、torch.bfloat16和torch.int8.
Atlas 推理系列产品: torch.float16. .
列表最大长度为128.
当split_item=0时, 张量支持2至6维输入; 其他情况下, 张量仅支持2维输入.
weight (List[torch.Tensor]): 权重矩阵列表, 表示矩阵乘法中的右矩阵.
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:
当group_list输入类型为List[int]时, 支持torch.float16、torch.float32、torch.bfloat16和torch.int8.
当group_list输入类型为torch.Tensor时, 支持torch.float16、torch.float32、torch.bfloat16、int4和torch.int8.
Atlas 推理系列产品: torch.float16.
列表最大长度为128.
每个张量支持2维或3维输入.
bias (List[torch.Tensor]): 每个分组的矩阵乘法输出的独立偏置项.
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: torch.float16、torch.float32和torch.int32.
Atlas 推理系列产品: torch.float16.
列表长度与weight列表长度相同.
每个张量仅支持1维输入.
scale (List[torch.Tensor]): 用于缩放原数值以匹配量化后的范围值, 代表量化参数中的缩放因子, 对应公式(2)、公式(3)和公式(5).
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:
当group_list输入类型为List[int]时, 支持torch.int64.
当group_list输入类型为torch.Tensor时, 支持torch.float32、torch.bfloat16和torch.int64.
Atlas 推理系列产品: 仅支持传入None. .
列表长度与weight列表长度相同.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品每个张量仅支持1维输入.
offset (List[torch.Tensor]): 用于调整量化后的数值偏移量, 从而更准确地表示原始浮点数值, 对应公式(2). 当前仅支持传入None.
antiquant_scale (List[torch.Tensor]): 用于缩放原数值以匹配伪量化后的范围值, 代表伪量化参数中的缩放因子, 对应公式(4).
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: torch.float16、torch.bfloat16.
Atlas 推理系列产品: 仅支持传入None.
列表长度与weight列表长度相同.
每个张量支持输入维度如下(其中g为matmul组数, G为per-group数, Gi为第i个tensor的per-group数):
伪量化per-channel场景, weight为单tensor时, shape限制为[g, n]; weight为多tensor时, shape限制为[ni].
伪量化per-group场景, weight为单tensor时, shape限制为[g, G, n]; weight为多tensor时, shape限制为[Gi, ni].
antiquant_offset (List[torch.Tensor]): 用于调整伪量化后的数值偏移量, 从而更准确地表示原始浮点数值, 对应公式(4).
支持的数据类型如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: torch.float16、torch.bfloat16.
Atlas 推理系列产品: 仅支持传入None.
列表长度与weight列表长度相同.
每个张量输入维度和antiquant_scale输入维度一致.
per_token_scale (List[torch.Tensor]): 用于缩放原数值以匹配量化后的范围值, 代表per-token量化参数中由x量化引入的缩放因子, 对应公式(3)和公式(5).
group_list输入类型为List[int]时, 当前只支持传入None.
group_list输入类型为torch.Tensor时:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持torch.float32.
列表长度与x列表长度相同.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 每个张量仅支持1维输入.
group_list (List[int]/torch.Tensor): 用于指定分组的索引, 表示x的第0维矩阵乘法的索引情况. 数据类型支持torch.int64.
Atlas 推理系列产品: 仅支持torch.Tensor类型.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持List[int]或torch.Tensor类型.
Atlas 推理系列产品: 每个张量仅支持1维输入, 长度与weight列表长度相同.
和Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 每个张量仅支持1维输入, 长度与weight列表长度相同.
配置值要求如下:
group_list输入类型为List[int]时, 配置值必须为非负递增数列, 且长度不能为1.
group_list输入类型为torch.Tensor时:
当group_list_type为0时, group_list必须为非负单调非递减数列.
当group_list_type为1时, group_list必须为非负数列, 且长度不能为1.
activation_input (List[torch.Tensor]): 代表激活函数的反向输入, 当前仅支持传入None.
activation_quant_scale (List[torch.Tensor]): 预留参数, 当前只支持传入None.
activation_quant_offset (List[torch.Tensor]): 预留参数, 当前只支持传入None.
split_item (int): 用于指定切分模式. 数据类型支持torch.int32.
0/1: 输出为多个张量, 数量与weight相同.
2/3: 输出为单个张量.
group_type (int): 代表需要分组的轴. 数据类型支持torch.int32.
group_list输入类型为List[int]时仅支持传入None.
group_list输入类型为torch.Tensor时, 若矩阵乘为C[m,n]=A[m,k]xB[k,n], group_type支持的枚举值为: -1代表不分组; 0代表m轴分组; 1代表n轴分组.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当前支持取-1、0.
Atlas 推理系列产品: 当前只支持取0.
group_list_type (int): 代表group_list的表达形式. 数据类型支持torch.int32.
group_list输入类型为List[int]时仅支持传入None.
group_list输入类型为torch.Tensor时:
可取值0或1, 0代表group_list_type中数值为分组轴大小的cumsum结果(累积和), 1代表group_list_type中数值为分组轴上每组大小.
act_type (int): 代表激活函数类型. 数据类型支持torch.int32.
group_list输入类型为List[int]时仅支持传入None.
group_list输入类型为torch.Tensor时, 支持的枚举值包括: 0代表不激活; 1代表RELU激活; 2代表GELU_TANH激活; 3代表暂不支持; 4代表FAST_GELU激活; 5代表SILU激活.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 取值范围为0-5.
Atlas 推理系列产品: 当前只支持传入0.
output_dtype (torch.dtype): 输出数据类型. 支持的配置包括:
None: 默认值, 表示输出数据类型与输入x的数据类型相同.
与输出y数据类型一致的类型, 具体参考约束说明.
输出说明:
List[torch.Tensor]: 当split_item为0或1时, 返回的张量数量与weight相同. 当split_item为2或3时, 返回的张量数量为1.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品的内轴限制InnerLimit为65536. x和weight中每一组tensor的最后一维大小都应小于InnerLimit. xi的最后一维指当x不转置时xi的K轴或当x转置时xi的M轴. weighti的最后一维指当weight不转置时weighti的N轴或当weight转置时weighti的K轴.
各场景输入与输出数据类型使用约束:
group_list输入类型为List[int]时, Atlas A2 训练系列产品/Atlas 800I A2 推理产品数据类型使用约束:
表1 数据类型约束场景
非量化
x: torch.float16, torch.bfloat16, torch.float32
weight: torch.float16, torch.bfloat16, torch.float32
bias: torch.float16, torch.float32, torch.float32
scale: 无需赋值, 无需赋值, 无需赋值
antiquant_scale: 无需赋值, 无需赋值, 无需赋值
antiquant_offset: 无需赋值, 无需赋值, 无需赋值
output_dtype: torch.float16, torch.bfloat16, torch.float32
y: torch.float16, torch.bfloat16, torch.float32
per-channel量化
x: torch.int8
weight: torch.int8
bias: torch.int32
scale: torch.int64
antiquant_scale: 无需赋值
antiquant_offset: 无需赋值
output_dtype: torch.int8
y: torch.int8
伪量化
x: torch.float16, torch.bfloat16
weight: torch.int8, torch.int8
bias: torch.float16, torch.float32
scale: 无需赋值, 无需赋值
antiquant_scale: torch.float16, torch.bfloat16
antiquant_offset: torch.float16, torch.bfloat16
output_dtype: torch.float16, torch.bfloat16
y: torch.float16, torch.bfloat16
group_list输入类型为torch.Tensor时, 数据类型使用约束:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:
表1 数据类型约束场景
非量化
x: torch.float16, torch.bfloat16, torch.float32
weight: torch.float16, torch.bfloat16, torch.float32
bias: torch.float16, torch.float32, torch.float32
scale: 无需赋值, 无需赋值, 无需赋值
antiquant_scale: 无需赋值, 无需赋值, 无需赋值
antiquant_offset: 无需赋值, 无需赋值, 无需赋值
per_token_scale: 无需赋值, 无需赋值, 无需赋值
output_dtype: None/torch.float16, None/torch.bfloat16, None/torch.float32(仅x/weight/y均为单张量)
y: torch.float16, torch.bfloat16,torch.float32
per-channel量化
x: torch.int8, torch.int8, torch.int8
weight: torch.int8, torch.int8, torch.int8
bias: torch.int32, torch.int32, torch.int32
scale: torch.int64, torch.bfloat16, torch.float32
antiquant_scale: 无需赋值, 无需赋值, 无需赋值
antiquant_offset: 无需赋值, 无需赋值, 无需赋值
per_token_scale: 无需赋值, 无需赋值, 无需赋值
output_dtype: None/torch.int8, torch.bfloat16, torch.float16
y: torch.int8, torch.bfloat16, torch.float16
per-token量化
x: torch.int8, torch.int8
weight: torch.int8, torch.int8
bias: torch.int32, torch.int32
scale: torch.bfloat16, torch.float32
antiquant_scale: 无需赋值, 无需赋值
antiquant_offset: 无需赋值, 无需赋值
per_token_scale: torch.float32, torch.float32
output_dtype: torch.bfloat16, torch.float16
y: torch.bfloat16, torch.float16
伪量化
x: torch.float16, torch.bfloat16
weight: torch.int8/int4, torch.int8/int4
bias: torch.float16, torch.float32
scale: 无需赋值, 无需赋值
antiquant_scale: torch.float16, torch.bfloat16
antiquant_offset: torch.float16, torch.bfloat16
per_token_scale: 无需赋值, 无需赋值
output_dtype: None/torch.float16, None/torch.bfloat16
y: torch.float16, torch.bfloat16
伪量化场景, 若weight的类型为torch.int8, 仅支持per-channel模式; 若weight的类型为int4, 支持per-channel和per-group两种模式. 若为per-group, per-group数G或Gi必须要能整除对应的ki. 若weight为多tensor, 定义per-group长度si = ki / Gi, 要求所有si(i=1,2,...g)都相等.
伪量化场景, 若weight的类型为int4, 则weight中每一组tensor的最后一维大小都应是偶数. weighti的最后一维指weight不转置时weighti的N轴或当weight转置时weighti的K轴. 并且在per-group场景下, 当weight转置时, 要求per-group长度si是偶数. tensor转置: 指若tensor shape为[M,K]时, 则stride为[1,M],数据排布为[K,M]的场景, 即非连续tensor.
当前PyTorch不支持int4类型数据, 需要使用时可以通过torch_npu.npu_quantize接口使用torch.int32数据表示int4.
Atlas 推理系列产品:
表1 数据类型约束
x: torch.float16
weight: torch.float16
bias: torch.float16
scale: 无需赋值
antiquant_scale: 无需赋值
antiquant_offset: 无需赋值
per_token_scale: torch.float32
output_dtype: torch.float16
y: torch.float16
根据输入x、输入weight与输出y的Tensor数量不同, 支持以下几种场景. 场景中的“单”表示单个张量, “多”表示多个张量. 场景顺序为x、weight、y, 例如“单多单”表示x为单张量, weight为多张量, y为单张量.
group_list输入类型为List[int]时, Atlas A2 训练系列产品/Atlas 800I A2 推理产品各场景的限制.
场景说明
多多多: x和weight为多张量, y为多张量. 每组数据的张量是独立的.
单多单: x为单张量, weight为多张量, y为单张量.
单多多: x为单张量, weight为多张量, y为多张量.
多多单: x和weight为多张量, y为单张量. 每组矩阵乘法的结果连续存放在同一个张量中.
场景限制
多多多: 仅支持split_item为0或1. x中tensor要求维度一致, 支持2-6维, weight中tensor需为2维, y中tensor维度和x保持一致. x中tensor大于2维, group_list必须传空. x中tensor为2维且传入group_list, group_list的差值需与x中tensor的第一维一一对应.
单多单: 仅支持split_item为2或3. 必须传group_list, 且最后一个值与x中tensor的第一维相等. x、weight、y中tensor需为2维. weight中每个tensor的N轴必须相等.
单多多: 仅支持split_item为0或1. 必须传group_list, group_list的差值需与y中tensor的第一维一一对应. x、weight、y中tensor需为2维.
多多单: 仅支持split_item为2或3. x、weight、y中tensor需为2维. weight中每个tensor的N轴必须相等. 若传入group_list, group_list的差值需与x中tensor的第一维一一对应.
group_list输入类型为torch.Tensor时, 各场景的限制.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:
量化、伪量化仅支持group_type为-1和0场景.
仅per-token量化场景支持激活函数计算.
group_type
-1: 多多多, x和weight为多张量, y为多张量. 每组数据的张量是独立的.
0: 单单单, x、weight与y均为单张量.
0: 单多单, x为单张量, weight为多张量, y为单张量.
0: 多多单, x和weight为多张量, y为单张量. 每组矩阵乘法的结果连续存放在同一个张量中.
场景限制
-1: 仅支持split_item为0或1. x中tensor要求维度一致, 支持2-6维, weight中tensor需为2维, y中tensor维度和x保持一致. group_list必须传空. 支持weight转置, 但weight中每个tensor是否转置需保持统一. x不支持转置.
0: 仅支持split_item为2或3. weight中tensor需为3维, x、y中tensor需为2维. 必须传group_list, 且当group_list_type为0时, 最后一个值与x中tensor的第一维相等, 当group_list_type为1时, 数值的总和与x中tensor的第一维相等. group_list第1维最大支持1024, 即最多支持1024个group. 支持weight转置. x不支持转置.
0: 仅支持split_item为2或3. 必须传group_list, 且当group_list_type为0时, 最后一个值与x中tensor的第一维相等, 当group_list_type为1时, 数值的总和与x中tensor的第一维相等, 长度最大为128. x、weight、y中tensor需为2维. weight中每个tensor的N轴必须相等. 支持weight转置, 但weight中每个tensor是否转置需保持统一. x不支持转置.
0: 仅支持split_item为2或3. x、weight、y中tensor需为2维. weight中每个tensor的N轴必须相等. 若传入group_list, 当group_list_type为0时, group_list的差值需与x中tensor的第一维一一对应, 当group_list_type为1时, group_list的数值需与x中tensor的第一维一一对应, 且长度最大为128. 支持weight转置, 但weight中每个tensor是否转置需保持统一. x不支持转置.
Atlas 推理系列产品:
输入输出只支持float16的数据类型, 输出y的n轴大小需要是16的倍数.
group_type
0: 单单单, x、weight与y均为单张量
场景限制
0: 仅支持split_item为2或3. weight中tensor需为3维, x、y中tensor需为2维. 必须传group_list, 且当group_list_type为0时, 最后一个值与x中tensor的第一维相等, 当group_list_type为1时, 数值的总和与x中tensor的第一维相等. group_list第1维最大支持1024, 即最多支持1024个group. 支持weight转置, 不支持x转置.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例:
单算子模式调用
通用调用示例
import torch
import torch_npu
x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
x = [x1, x2, x3]
weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
weight = [weight1, weight2, weight3]
bias1 = torch.randn(256, device='npu', dtype=torch.float16)
bias2 = torch.randn(1024, device='npu', dtype=torch.float16)
bias3 = torch.randn(128, device='npu', dtype=torch.float16)
bias = [bias1, bias2, bias3]
group_list = None
split_item = 0
npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, group_list=group_list, split_item=split_item, group_type=-1)
图模式调用
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class GMMModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight):
return torch_npu.npu_grouped_matmul(x, weight, group_type=-1)
def main():
x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
x = [x1, x2, x3]
weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
weight = [weight1, weight2, weight3]
model = GMMModel().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
custom_output = model(x, weight)
if __name__ == '__main__':
main()
"""
)
_add_torch_npu_docstr(
"npu_rms_norm_quant",
"""
接口原型
npu_rms_norm_quant(Tensor x, Tensor gamma, Tensor beta, Tensor scale, Tensor offset, float epsilon=1e-06) -> Tensor
功能描述
RmsNormQuant算子将RmsNorm算子以及RmsNorm后的Quantize算子融合起来,减少搬入搬出的操作。
参数说明
x: Device侧的Tensor类型,标准化输入张量。shape支持1-8维,数据类型支持FLOAT16、BFLOAT16,格式支持ND。不支持空Tensor。
gamma: Device侧的Tensor类型,归一化权重张量。shape为1-2维,需与x最后一维一致,数据类型与x一致。格式支持ND。不支持空Tensor。
beta: Device侧的Tensor类型,归一化偏置项。shape和数据类型与x一致。格式支持ND。不支持空Tensor。
scale: Device侧的Tensor类型,量化过程中得到y进行的scale张量,shape为1,维度为1.格式支持ND。不支持空Tensor。
offset: Device侧的Tensor类型,量化过程中得到y进行的offset张量.shape与scale保持一致,格式支持ND。不支持空Tensor。
epsilon: double类型,防止除0错误,默认值为1e-6.
输出说明
y: Device侧的Tensor类型。数据类型支持INT8。shape、数据格式需要与入参x保持一致。支持非连续的Tensor,不支持空Tensor。
约束说明
x、y的尾轴长度,以及gamma的尾轴长度必大于等于32Bytes.
支持的型号
Atlas A3训练系列产品/Atlas A3推理系列产品
Atlas A2训练系列产品/Atlas 800I A2推理产品/A200I A2 Box异构组件
调用示例
import torch
import torch_npu
eps = 1e-6
x = torch.randn(16, dtype = torch.float16).npu()
gamma = torch.randn(16, dtype = torch.float16).npu()
beta = torch.zeros(16, dtype = torch.float16).npu()
scale = torch.ones(1, dtype = torch.float16).npu()
offset = torch.zeros(1, dtype = torch.int8).npu()
y = torch_npu.npu_rms_norm_quant(x, gamma, beta, scale, offset, eps)
_ = y.cpu().numpy()
"""
)
_add_torch_npu_docstr(
"npu_grouped_matmul_finalize_routing",
"""
功能描述:
GroupedMatmul和MoeFinalizeRouting的融合算子,GroupedMatmul计算后的输出按照索引做combine动作。
接口原型:
torch_npu.npu_grouped_matmul_finalize_routing(Tensor x, Tensor weight, Tensor group_list, *, Tensor? scale=None, Tensor? bias=None, Tensor? pertoken_scale=None, Tensor? shared_input=None, Tensor? logit=None, Tensor? row_index=None, ScalarType? dtype=None, float? shared_input_weight=1.0, int? shared_input_offset=0, int? output_bs=0, int? group_list_type=1, int[]? tuning_config) -> Tensor
参数说明:
- x(Tensor, 计算输入): 必选参数,一个2D的Device侧Tensor输入,矩阵计算的左矩阵,不支持非连续的Tensor。数据类型支持int8,数据格式支持ND,维度为(m,k)。m取值范围为[1, 16*1024*8],K取值为16整倍数。
- weight(Tensor, 计算输入): 必选参数,一个3D的Device侧Tensor输入,矩阵计算的右矩阵,不支持非连续的Tensor。数据类型支持int8、int4。a8w8场景下数据格式支持NZ,维度为(e,k,n),e取值范围为[1, 256],n取值为32整数倍且大于等于256,a8w4场景下数据格式支持ND,维度为(e,k,n),k只支持2048,n只支持7168。
- group_list(Tensor, 计算输入): 必选参数,一个1D的Device侧Tensor输入,GroupedMatMul的各分组大小值,不支持非连续的Tensor。数据类型支持int64,数据格式支持ND,维度为(e,),group_list的值的总和要求小于等于m。
- scale(Tensor, 计算输入): 可选参数,Device侧Tensor输入,矩阵计算反量化参数,对应weight矩阵,不支持非连续的Tensor。a8w8场景下是2D的Tensor,数据类型支持float32,数据格式支持ND,支持per-channel量化方式,维度为(e,n);a8w4场景下是3D的Tensor,数据类型支持int64,维度为(e,1,n)。
- bias(Tensor, 计算输入): 可选参数,一个2D的Device侧Tensor输入,矩阵计算的bias参数,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,只支持a8w4场景。
- offset(Tensor, 计算输入): 可选参数,Device侧Tensor输入,矩阵计算量化参数的偏移量,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,只支持a8w4场景。
- pertoken_scale(Tensor, 计算输入): 可选参数,一个1D的Device侧Tensor输入,矩阵计算的反量化参数,对应x矩阵,per-token量化方式,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,维度为(m,)。
- shared_input(Tensor, 计算输入): 可选参数,一个2D的Device侧Tensor输入,moe计算中共享专家的输出,需要与moe专家的输出进行combine操作,不支持非连续的Tensor。数据类型支持bfloat16,数据格式支持ND,维度为(batch/dp,n),batch/dp取值范围[1, 2*1024],batch取值范围[1, 16*1024]。
- logit(Tensor, 计算输入): 可选参数,一个1D的Device侧Tensor输入,moe专家对各个token的logit大小,矩阵乘的计算输出与该logit做乘法,然后索引进行combine,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,维度为(m,)。
- row_index(Tensor*, 计算输入): 可选参数,一个1D的Device侧Tensor输入,moe专家输出按照该rowIndex进行combine,其中的值即为combine做scatter add的索引,不支持非连续的Tensor。数据类型支持int32、int64,数据格式支持ND,维度为(m,)。
- dtype(torch.dtype, 计算输入): 可选参数,指定GroupedMatMul计算的输出类型。枚举值含义:0表示float32,1表示float16,2表示bfloat16。默认值为0。
- shared_input_weight(float, 计算输入): 可选参数,float类型,指共享专家与moe专家进行combine的系数,shared_input先与该参数相乘,然后再和moe专家结果累加。默认为1.0。
- shared_input_offset(int, 计算输入): 可选参数,共享专家输出在总输出中的偏移。默认值为0.
- output_bs(int, 计算输入): 可选参数,输出的最高维大小。默认值为0。
- group_list_type(int, 计算输入): 可选参数,GroupedMatMul的分组模式,0为cumsum模式,1为count模式,默认为1。
- tuning_config:(ListInt, 计算输入): 可选参数,数组中第一个值表示各个专家处理的token数的预期值,算子tiling时会按照该预期值进行最优tiling。
- y(Tensor, 计算输出): 2D的Tensor,不支持非连续的Tensor,输出的数据类型固定为float32,维度为(batch, n)。
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 推理系列产品
调用示例:
# 单算子调用
import numpy as np
import torch
import torch_npu
from scipy.special import softmax
m, k, n = 576, 2048, 7168
batch = 72
topK = 8
group_num = 8
x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)
scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
pertoken_scale = np.random.normal(0, 0.01, (m, )).astype(np.float32)
group_list = np.array([batch] * group_num, dtype=np.int64)
shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)
logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
logit = softmax(logit_ori[np.arange(batch).reshape(-1, 1).repeat(topK, axis=1), routing], axis=1).astype(np.float32)
logit = logit.reshape(m)
row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)
x_clone = torch.from_numpy(x).npu()
weight_clone = torch.from_numpy(weight).npu()
weightNz = torch_npu.npu_format_cast(weight_clone, 29)
scale_clone = torch.from_numpy(scale).npu()
pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
group_list_clone = torch.from_numpy(group_list).npu()
shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
logit_clone = torch.from_numpy(logit).npu()
row_index_clone = torch.from_numpy(row_index).npu()
shared_input_offset = batch // 2
output_bs = batch
y = torch_npu.npu_grouped_matmul_finalize_routing(x_clone, weightNz,
group_list_clone, scale=scale_clone, pertoken_scale=pertoken_scale_clone,
shared_input=shared_input_clone, logit=logit_clone, row_index=row_index_clone,
shared_input_offset=shared_input_offset, output_bs=output_bs)
# 图模式调用
import numpy as np
import torch
import torch_npu
import torchair as tng
from scipy.special import softmax
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, group_list, scale, pertoken_scale, shared_input, logit, row_index, shared_input_offset, output_bs):
output = torch_npu.npu_grouped_matmul_finalize_routing(x, weight, group_list,
scale=scale, pertoken_scale=pertoken_scale, shared_input=shared_input,
logit=logit, row_index=row_index, shared_input_offset=shared_input_offset, output_bs=output_bs)
return output
m, k, n = 576, 2048, 7168
batch = 72
topK = 8
group_num = 8
x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)
scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
pertoken_scale = np.random.normal(0, 0.01, (m, )).astype(np.float32)
group_list = np.array([batch] * group_num, dtype=np.int64)
shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)
logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
logit = softmax(logit_ori[np.arange(batch).reshape(-1, 1).repeat(topK, axis=1), routing], axis=1).astype(np.float32)
logit = logit.reshape(m)
row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)
x_clone = torch.from_numpy(x).npu()
weight_clone = torch.from_numpy(weight).npu()
weightNz = torch_npu.npu_format_cast(weight_clone, 29)
scale_clone = torch.from_numpy(scale).npu()
pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
group_list_clone = torch.from_numpy(group_list).npu()
shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
logit_clone = torch.from_numpy(logit).npu()
row_index_clone = torch.from_numpy(row_index).npu()
shared_input_offset = batch // 2
output_bs = batch
model = Model().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
y = model(x_clone, weightNz, group_list_clone, scale_clone, pertoken_scale_clone, shared_input_clone,
logit_clone, row_index_clone, shared_input_offset, output_bs)
"""
)
_add_torch_npu_docstr(
"npu_quant_scatter",
"""
功能描述:
先将updates进行量化, 然后将updates中的值按指定的轴axis和索引indices更新input中的值, 并将结果保存到输出tensor, input本身的数据不变.
接口原型:
torch_npu.npu_quant_scatter(Tensor input, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=-2, int quant_axis=-1, str reduce='update', int? dst_type=None, str? round_mode='rint') -> Tensor
参数说明:
input: Tensor类型, 必选输入, 源数据张量, 数据类型支持int8, 数据格式支持ND, 支持非连续的Tensor, 维数只能是3~8维.
indices: Tensor类型, 必选输入, 索引张量, 数据类型支持int32, 数据格式支持ND, 支持非连续的Tensor.
updates: Tensor类型, 必选输入, 更新数据张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、float16.
quant_scales: Tensor类型, 必选输入, 量化缩放张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、float32.
quant_zero_points: Tensor类型, 可选输入, 量化偏移张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持int32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、int32.
axis: int类型, 可选参数, updates上用来更新的轴, 默认值为-2.
quant_axis: int类型, 可选参数, updates上用来量化的轴, 默认值为-1.
reduce: 字符串类型, 可选参数, 表示数据操作方式; 当前只支持'update', 即更新操作.
输出说明:
一个Tensor类型的输出, 代表input被更新后的结果.
约束说明:
该接口支持图模式.
indices的维数只能是1维或者2维; 如果是2维, 其第2维的大小必须是2; 不支持索引越界, 索引越界不校验; indices映射的input数据段不能重合, 若重合则会因为多核并发原因导致多次执行结果不一样.
updates的维数需要与input的维数一样; 其第1维的大小等于indices的第1维的大小, 且不大于input的第1维的大小; 其axis轴的大小不大于input的axis轴的大小; 其余维度的大小要跟input对应维度的大小相等; 其最后一维的大小必须32B对齐.
quant_scales的元素个数需要等于updates在quant_axis轴的大小.
quant_zero_points的元素个数需要等于updates在quant_axis轴的大小.
axis不能为updates的第1维或最后1维.
quant_axis只能为updates的最后1维.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
var = torch.from_numpy(data_var).to(torch.int8).npu()
data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()
data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()
data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()
data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()
axis = -2
quant_axis = -1
reduce = "update"
out = torch_npu.npu_quant_scatter(var, indices, updates, quant_scales, quant_zero_points, axis=axis, quant_axis=quant_axis, reduce=reduce)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
import numpy as np
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
dtype_list2 =["fp16","int8","int32","fp32","fp16"]
dtype_list =[np.float16,np.int8,np.int32,np.float32]
updates_shape =[1,11,1,32]
var_shape =[1,11,1,32]
indices_shape =[1,2]
quant_scales_shape =[1,1,1,32]
quant_zero_points_shape =[1,1,1,32]
axis =-2
quant_axis=-1
reduce = "update"
updates_data = np.random.uniform(-1,1,size=updates_shape)
var_data = np.random.uniform(1,2,size=var_shape).astype(dtype_list[1])
quant_scales_data = np.random.uniform(1,2,size=quant_scales_shape)
indices_data = np.random.uniform(0,1,size=indices_shape).astype(dtype_list[2])
quant_zero_points_data = np.random.uniform(0,1,size=quant_zero_points_shape)
updates_npu = torch.from_numpy(updates_data).npu().to(torch.bfloat16).npu()
var_npu = torch.from_numpy(var_data).npu()
quant_scales_npu = torch.from_numpy(quant_scales_data).npu().to(torch.bfloat16).npu()
quant_zero_points_npu = torch.from_numpy(quant_zero_points_data).to(torch.bfloat16).npu()
indices_npu = torch.from_numpy(indices_data).npu()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_quant_scatter(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_quant_scatter(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
print("single op output with mask:", single_op[0], single_op[0].shape)
print("graph output with mask:", graph_output[0], graph_output[0].shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
[[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
[[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1,
0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
[[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0,
1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
[[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
[[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,
0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
[[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0,
1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
[[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
[[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
[[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1,
0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
[[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1,
1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
graph output with mask: tensor([[[ 1, 1, 0, 1, 0, -1, 0, 0, 0, 1, 0, 1, 0, -1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 2, 1, 0, 0]],
[[ 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]],
[[ 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 1, 1, 1, 1,
0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2, 0, 1, 1]],
[[ 1, 1, 0, 1, 0, -1, 0, 1, 0, 1, 0, 0, -1, 0, 1, 0, 0,
1, 0, 2, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]],
[[ 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]],
[[ 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,
0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]],
[[ 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, -1, 1, 0, 0,
1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1]],
[[ 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1]],
[[ 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, -1, 0,
1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1]],
[[ 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1,
0, 1, 1, 1, -1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]],
[[ 1, 0, -1, 1, 0, 0, 1, 0, 1, 2, 0, 1, 0, -1, 1, 1, 1,
1, 0, 0, 2, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
"""
)
_add_torch_npu_docstr(
"npu_quant_scatter_",
"""
功能描述:
先将updates进行量化, 然后将updates中的值按指定的轴axis和索引indices更新input中的值, input中的数据被改变.
接口原型:
torch_npu.npu_quant_scatter_(Tensor(a!) input, Tensor indices, Tensor updates, Tensor quant_scales, Tensor? quant_zero_points=None, int axis=-2, int quant_axis=-1, str reduce='update', int? dst_type=None, str? round_mode='rint') -> Tensor(a!)
参数说明:
input: Tensor类型, 必选输入, 源数据张量, 数据类型支持int8, 数据格式支持ND, 支持非连续的Tensor, 维数只能是3~8维.
indices: Tensor类型, 必选输入, 索引张量, 数据类型支持int32, 数据格式支持ND, 支持非连续的Tensor.
updates: Tensor类型, 必选输入, 更新数据张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、float16.
quant_scales: Tensor类型, 必选输入, 量化缩放张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、float32.
quant_zero_points: Tensor类型, 可选输入, 量化偏移张量, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持int32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持bfloat16、int32.
axis: int类型, 可选参数, updates上用来更新的轴, 默认值为-2.
quant_axis: int类型, 可选参数, updates上用来量化的轴, 默认值为-1.
reduce: 字符串类型, 可选参数, 表示数据操作方式; 当前只支持'update', 即更新操作.
输出说明:
一个Tensor类型的输出, 代表input被更新后的结果.
约束说明:
该接口支持图模式.
indices的维数只能是1维或者2维; 如果是2维, 其第2维的大小必须是2; 不支持索引越界, 索引越界不校验; indices映射的input数据段不能重合, 若重合则会因为多核并发原因导致多次执行结果不一样.
updates的维数需要与input的维数一样; 其第1维的大小等于indices的第1维的大小, 且不大于input的第1维的大小; 其axis轴的大小不大于input的axis轴的大小; 其余维度的大小要跟input对应维度大小相等; 其最后一维的大小必须32B对齐.
quant_scales的元素个数需要等于updates在quant_axis轴的大小.
quant_zero_points的元素个数需要等于updates在quant_axis轴的大小.
axis不能为updates的第1维或最后1维.
quant_axis只能为updates的最后1维.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [24, 4096, 128]).astype(np.int8)
var = torch.from_numpy(data_var).to(torch.int8).npu()
data_indices = np.random.uniform(0, 1, [24]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()
data_updates = np.random.uniform(1, 2, [24, 1, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.bfloat16).npu()
data_quant_scales = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_scales = torch.from_numpy(data_quant_scales).to(torch.bfloat16).npu()
data_quant_zero_points = np.random.uniform(0, 1, [1, 1, 128]).astype(np.float16)
quant_zero_points = torch.from_numpy(data_quant_zero_points).to(torch.bfloat16).npu()
axis = -2
quant_axis = -1
reduce = "update"
torch_npu.npu_quant_scatter_(var, indices, updates, quant_scales, quant_zero_points, axis=axis, quant_axis=quant_axis, reduce=reduce)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
import numpy as np
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
dtype_list2 =["fp16","int8","int32","fp32","fp16"]
dtype_list =[np.float16,np.int8,np.int32,np.float32]
updates_shape =[1,11,1,32]
var_shape =[1,11,1,32]
indices_shape =[1,2]
quant_scales_shape =[1,1,1,32]
quant_zero_points_shape =[1,1,1,32]
axis =-2
quant_axis=-1
reduce = "update"
updates_data = np.random.uniform(-1,1,size=updates_shape)
var_data = np.random.uniform(1,2,size=var_shape).astype(dtype_list[1])
quant_scales_data = np.random.uniform(1,2,size=quant_scales_shape)
indices_data = np.random.uniform(0,1,size=indices_shape).astype(dtype_list[2])
quant_zero_points_data = np.random.uniform(0,1,size=quant_zero_points_shape)
updates_npu = torch.from_numpy(updates_data).npu().to(torch.bfloat16).npu()
var_npu = torch.from_numpy(var_data).npu()
quant_scales_npu = torch.from_numpy(quant_scales_data).npu().to(torch.bfloat16).npu()
quant_zero_points_npu = torch.from_numpy(quant_zero_points_data).to(torch.bfloat16).npu()
indices_npu = torch.from_numpy(indices_data).npu()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_quant_scatter_(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_quant_scatter_(var_npu, indices_npu, updates_npu, quant_scales_npu, quant_zero_points_npu, axis=axis, quant_axis=quant_axis, reduce=reduce)
print("single op output with mask:", single_op[0], single_op[0].shape)
print("graph output with mask:", graph_output[0], graph_output[0].shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[ 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1]],
[[ 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0,
1, 1, 0, 1, 1, 0, 0, -1, 0, 1, 0, 1, 0, 1, 0]],
[[ 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,
1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1]],
[[ 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]],
[[ 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0,
1, 1, 0, 1, 1, 1, 1, -1, 0, 0, 0, 1, 1, 1, 0]],
[[ 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0,
1, 1, 1, 0, 0, 1, 0, -1, 0, 0, 0, 1, 1, 1, 0]],
[[ 0, -1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 2, 1, 0,
1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]],
[[ 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, 1, 0,
1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1]],
[[ 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 2, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0]],
[[ 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, -1, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0]],
[[ 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,
1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
graph output with mask: tensor([[[ 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1]],
[[ 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0,
1, 1, 0, 1, 1, 0, 0, -1, 0, 1, 0, 1, 0, 1, 0]],
[[ 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,
1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1]],
[[ 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]],
[[ 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0,
1, 1, 0, 1, 1, 1, 1, -1, 0, 0, 0, 1, 1, 1, 0]],
[[ 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0,
1, 1, 1, 0, 0, 1, 0, -1, 0, 0, 0, 1, 1, 1, 0]],
[[ 0, -1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 2, 1, 0,
1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]],
[[ 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 2, 0, 1, 0,
1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1]],
[[ 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 2, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0]],
[[ 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, -1, 0, 1, 1, 0, 1,
1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0]],
[[ 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,
1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1]]],
device='npu:0', dtype=torch.int8) torch.Size([11, 1, 32])
"""
)
_add_torch_npu_docstr(
"npu_scatter_nd_update",
"""
功能描述:
将updates中的值按指定的索引indices更新input中的值,并将结果保存到输出tensor,input本身的数据不变。
接口原型:
torch_npu.npu_scatter_nd_update(Tensor input, Tensor indices, Tensor updates) -> Tensor
参数说明:
input:Tensor类型,必选输入,源数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
Atlas 训练系列产品:数据类型支持float32、float16、bool。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
indices:Tensor类型,必选输入,索引张量,数据类型支持int32、int64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
updates:Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与input一致。
Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
Atlas 训练系列产品:数据类型支持float32、float16、bool。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
输出说明:
一个Tensor类型的输出,代表input被更新后的结果。
约束说明:
该接口支持图模式。
indices至少是2维,其最后1维的大小不能超过input的维度大小。
假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上input除前a维外的shape。举例:input的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。
支持的PyTorch版本:
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas 训练系列产品
Atlas 推理系列产品
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()
data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()
data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()
out = torch_npu.npu_scatter_nd_update(var, indices, updates)
图模式调用
import os
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch.nn as nn
import torch
import numpy as np
import numpy
torch_npu.npu.set_compile_mode(jit_compile=True)
os.environ["ENABLE_ACLNN"] = "false"
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, var, indices, update):
# 调用目标接口
res = torch_npu.npu_scatter_nd_update(var, indices, update)
return res
npu_mode = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
dtype = np.float32
x = [33 ,5]
indices = [33,25,1]
update = [33,25,5]
data_x = np.random.uniform(0, 1, x).astype(dtype)
data_indices = np.random.uniform(0, 10, indices).astype(dtype)
data_update = np.random.uniform(0, 1, update).astype(dtype)
tensor_x = torch.from_numpy(data_x).to(torch.float16)
tensor_indices = torch.from_numpy(data_indices).to(torch.int32)
tensor_update = torch.from_numpy(data_update).to(torch.float16)
# 传参
print(npu_mode(tensor_x.npu(), tensor_indices.npu(), tensor_update.npu()))
"""
)
_add_torch_npu_docstr(
"npu_recurrent_gated_delta_rule",
"""
功能描述:
完成变步长的Recurrent Gated Delta Rule计算。
接口原型:
npu_recurrent_gated_delta_rule(Tensor query, Tensor key, Tensor value, Tensor(a!) state, *, Tensor? beta=None, float? scale=None, Tensor? actual_seq_lengths=None, Tensor? ssm_state_indices=None, Tensor? num_accepted_tokens=None, Tensor? g=None, Tensor? gk=None) -> Tensor
参数说明:
令 $B$ 表示batch size,$L_i$ 表示第i个序列的长度,$T=\sum_i^B L_i$ 表示累积序列长度。$N_k$ 表示key的头数,$N_v$ 表示value的头数,$D_k$ 表示key向量的维度,$D_v$ 表示value向量的维度。
- query (Tensor):必选输入,对应公式中的q,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- key (Tensor):必选输入,对应公式中的k,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- value (Tensor):必选输入,对应公式中的v,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv, Dv)。
- state (Tensor):必选输入&输出,对应公式中的状态矩阵S,数据类型支持bfloat16,数据格式支持ND,shape为(BlockNum, Nv, Dv, Dk)。
- beta (Tensor):必选输入,对应公式中的β,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv)。
- scale (Scalar):必选输入,query的缩放因子,对应公式中的 $1/\sqrt{d_k}$。数据类型支持float32。
- actual_seq_lengths (Tensor):必选输入,各batch的输入序列长度。数据类型支持int32,数据格式支持ND,shape为(B,)。
- ssm_state_indices (Tensor):必选输入,输入序列到状态矩阵的映射索引。state[ssm_state_indices[i]]表示第i个token的状态矩阵。数据类型支持int32,数据格式支持ND,shape为(T,)。
- num_accepted_tokens (Tensor):可选输入,投机推理每个batch接受的token数量。默认为None,表示每个batch接受的token数为1。数据类型支持int32,数据格式支持ND,shape为(B,)。
- g (Tensor):可选输入,衰减系数,对应公式中的α=e^g。默认为None,表示全0。数据类型支持float32,数据格式支持ND,shape为(T, Nv)。
- gk (Tensor):可选输入,衰减系数,当前版本暂不支持,传None即可。
输出说明:
注意力计算结果。输出的数据类型为bfloat16,数据格式为ND,形状为(T, Nv, Dv)。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持静态图模式。
- 输入shape大小需满足约束:$L_i \le 8$,$N_k \le 256$,$N_v \le 256$,$D_k \le 256$,$D_v \le 256$。
支持的PyTorch版本:
PyTorch 2.1 及更高版本
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
# 构造输入
bs, mtp, nk, nv, dk, dv = (2, 3, 4, 8, 128, 128)
actual_seq_lengths = (torch.ones(bs) * mtp).npu().to(torch.int32)
T = int(torch.sum(actual_seq_lengths))
state = torch.rand((T, nv, dv, dk), dtype=torch.bfloat16).npu()
query = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
key = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
value = torch.rand((T, nv, dv), dtype=torch.bfloat16).npu()
g = torch.rand((T, nv), dtype=torch.float32).npu() * (-1.0)
beta = torch.rand((T, nv), dtype=torch.bfloat16).npu()
ssm_state_indices = (torch.arange(T).npu()).to(torch.int32)
query = torch.nn.functional.normalize(query, p=2, dim=-1)
key = torch.nn.functional.normalize(key, p=2, dim=-1)
scale = dk ** -0.5
num_accepted_tokens = (torch.randint(1, mtp + 1, (bs,)).npu()).to(torch.int32)
# 调用算子
o = torch_npu.npu_recurrent_gated_delta_rule(
query, key, value, state, beta=beta, scale=scale,
actual_seq_lengths=actual_seq_lengths, ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens, g=g, gk=None)
print(o)
"""
)
_add_torch_npu_docstr(
"npu_recurrent_gated_delta_rule_functional",
"""
功能描述:
完成变步长的Recurrent Gated Delta Rule计算。
接口原型:
npu_recurrent_gated_delta_rule_functional(
Tensor query,
Tensor key,
Tensor value,
Tensor state,
*,
Tensor? beta=None,
float? scale=None,
Tensor? actual_seq_lengths=None,
Tensor? ssm_state_indices=None,
Tensor? num_accepted_tokens=None,
Tensor? g=None, Tensor? gk=None) -> (Tensor, Tensor)
参数说明:
令 $B$ 表示batch size,$L_i$ 表示第i个序列的长度,$T=\sum_i^B L_i$ 表示累积序列长度。$N_k$ 表示key的头数,$N_v$ 表示value的头数,$D_k$ 表示key向量的维度,$D_v$ 表示value向量的维度。
- query (Tensor):必选输入,对应公式中的q,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- key (Tensor):必选输入,对应公式中的k,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- value (Tensor):必选输入,对应公式中的v,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv, Dv)。
- state (Tensor):必选输入&输出,对应公式中的状态矩阵S,数据类型支持bfloat16,数据格式支持ND,shape为(BlockNum, Nv, Dv, Dk)。
- beta (Tensor):必选输入,对应公式中的β,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv)。
- scale (Scalar):必选输入,query的缩放因子,对应公式中的 $1/\sqrt{d_k}$。数据类型支持float32。
- actual_seq_lengths (Tensor):必选输入,各batch的输入序列长度。数据类型支持int32,数据格式支持ND,shape为(B,)。
- ssm_state_indices (Tensor):必选输入,输入序列到状态矩阵的映射索引。state[ssm_state_indices[i]]表示第i个token的状态矩阵。数据类型支持int32,数据格式支持ND,shape为(T,)。
- num_accepted_tokens (Tensor):可选输入,投机推理每个batch接受的token数量。默认为None,表示每个batch接受的token数为1。数据类型支持int32,数据格式支持ND,shape为(B,)。
- g (Tensor):可选输入,衰减系数,对应公式中的α=e^g。默认为None,表示全0。数据类型支持float32,数据格式支持ND,shape为(T, Nv)。
- gk (Tensor):可选输入,衰减系数,当前版本暂不支持,传None即可。
输出说明:
注意力计算结果。输出的数据类型为bfloat16,数据格式为ND,形状为(T, Nv, Dv)。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持静态图模式。
- 输入shape大小需满足约束:$L_i \le 8$,$N_k \le 256$,$N_v \le 256$,$D_k \le 256$,$D_v \le 256$。
支持的PyTorch版本:
PyTorch 2.1 及更高版本
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
# 构造输入
bs, mtp, nk, nv, dk, dv = (2, 3, 4, 8, 128, 128)
actual_seq_lengths = (torch.ones(bs) * mtp).npu().to(torch.int32)
T = int(torch.sum(actual_seq_lengths))
state = torch.rand((T, nv, dv, dk), dtype=torch.bfloat16).npu()
query = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
key = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
value = torch.rand((T, nv, dv), dtype=torch.bfloat16).npu()
g = torch.rand((T, nv), dtype=torch.float32).npu() * (-1.0)
beta = torch.rand((T, nv), dtype=torch.bfloat16).npu()
ssm_state_indices = (torch.arange(T).npu()).to(torch.int32)
query = torch.nn.functional.normalize(query, p=2, dim=-1)
key = torch.nn.functional.normalize(key, p=2, dim=-1)
scale = dk ** -0.5
num_accepted_tokens = (torch.randint(1, mtp + 1, (bs,)).npu()).to(torch.int32)
# 调用算子
o, stateOut = torch_npu.npu_recurrent_gated_delta_rule_functional(
query, key, value, state, bata=beta, scale=scale,
actual_seq_lengths=actual_seq_lengths, ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens, g=g, gk=None)
print(o)
"""
)
_add_torch_npu_docstr(
"npu_scatter_nd_update_",
"""
功能描述:
将updates中的值按指定的索引indices更新input中的值,并将结果保存到输出tensor,input中的数据被改变。
接口原型:
torch_npu.npu_scatter_nd_update_(Tensor(a!) input, Tensor indices, Tensor updates) -> Tensor(a!)
参数说明:
input:Tensor类型,必选输入,源数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
Atlas 训练系列产品:数据类型支持float32、float16、bool。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
indices:Tensor类型,必选输入,索引张量,数据类型支持int32、int64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
updates:Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与input一致。
Atlas 推理系列加速卡产品:数据类型支持float32、float16、bool。
Atlas 训练系列产品:数据类型支持float32、float16、bool。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float32、float16、bool、bfloat16、int64、int8。
Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32、float16、bool、bfloat16、int64、int8。
输出说明:
一个Tensor类型的输出,代表input被更新后的结果。
约束说明:
该接口支持图模式。
indices至少是2维,其最后1维的大小不能超过input的维度大小。
假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上input除前a维外的shape。举例:input的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。
支持的PyTorch版本:
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas 训练系列产品
Atlas 推理系列产品
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()
data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()
data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()
torch_npu.npu_scatter_nd_update_(var, indices, updates)
图模式调用
import os
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch.nn as nn
import torch
import numpy as np
import numpy
torch_npu.npu.set_compile_mode(jit_compile=True)
os.environ["ENABLE_ACLNN"] = "false"
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, var, indices, update):
# 调用目标接口
res = torch_npu.npu_scatter_nd_update_(var, indices, update)
return res
npu_mode = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
dtype = np.float32
x = [33 ,5]
indices = [33,25,1]
update = [33,25,5]
data_x = np.random.uniform(0, 1, x).astype(dtype)
data_indices = np.random.uniform(0, 10, indices).astype(dtype)
data_update = np.random.uniform(0, 1, update).astype(dtype)
tensor_x = torch.from_numpy(data_x).to(torch.float16)
tensor_indices = torch.from_numpy(data_indices).to(torch.int32)
tensor_update = torch.from_numpy(data_update).to(torch.float16)
# 传参
print(npu_mode(tensor_x.npu(), tensor_indices.npu(), tensor_update.npu()))
"""
)
_add_torch_npu_docstr(
"npu_scatter_pa_kv_cache",
"""
功能描述:
- 更新KvCache中指定位置的key和value。
- 输入输出支持以下场景:
- 场景一:
key:[batch, num_head, k_head_size]
value:[batch, num_head, v_head_size]
key_cache:[num_blocks, num_head * k_head_size // last_dim_k, block_size, last_dim_k]
value_cache:[num_blocks, num_head * v_head_size // last_dim_k, block_size, last_dim_k]
slot_mapping:[batch]
- 场景二:
key:[batch*seq_len, num_head, k_head_size]
value:[batch*seq_len, num_head, v_head_size]
keyCache:[num_blocks, block_size, num_head, k_head_size]
valueCache:[num_blocks, block_size, num_head, v_head_size]
slotMapping:[batch * seq_len]
其中k_head_size与v_head_size 可以不同,也可以相同。
- 场景三:
key:[batch, seq_len, num_head, k_head_size]
value:[batch, seq_len, num_head, v_head_size]
key_cache:[num_blocks, block_size, 1, k_head_size]
value_cache:[num_blocks, block_size, 1, k_head_size]
slot_mapping:[batch, num_head]
compress_lens:[batch, num_head]
seq_lens:[batch]
compress_seq_offsets:[batch * num_head]
- 上述场景根据构造的参数来区别,符合第一种入参构造走场景一,符合第二种构造走场景二,符合第三种构造走场景三。
场景一、场景二没有compress_lens、seq_lens、compress_seq_offsets这三个可选参数。
接口原型:
torch_npu.npu_scatter_pa_kv_cache(Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, Tensor slot_mapping, *, Tensor? compress_lens=None, Tensor? compress_seq_offsets=None, Tensor? seq_lens=None, str? cache_mode='PA_NZ') -> ()
参数说明:
- key(Tensor):必选参数。表示待更新的key值,当前step多个token的key,支持3维或4维。数据类型支持float16、float、bfloat16、int8、uint8、int16、uint16、int32、uint32、hifloat8、float8_e5m2、float8_e4m3fn,数据格式支持ND。
- value(Tensor):必选参数,表示待更新的value值,当前step多个token的value,支持3维或4维,数据类型和数据格式与key保持一致。
- key_cache(Tensor):必选参数,表示需要更新的key cache,当前layer的key cache,只支持4维,数据类型和数据格式与key保持一致。
- value_cache(Tensor):必选参数,表示需要更新的value cache,当前layer的value cache。只支持4维,数据类型和数据格式与key保持一致。
- slot_mapping(Tensor):必选参数,表示每个token key或value在cache中的存储偏移,数据类型支持int32和int64,数据格式支持ND。
- compress_lens(Tensor):可选参数,表示压缩量,数据类型与slot_mapping一致,数据格式支持ND。
- compress_seq_offsets(Tensor):可选参数,表示每个batch每个head的压缩起点,数据类型与slot_mapping一致,数据格式支持ND。
- seq_lens(Tensor):可选参数,表示每个batch的实际seqLens,数据类型与slot_mapping一致,数据格式支持ND。
- cache_mode(str):可选参数,表示key_cache和value_cache的内存排布格式。当传None或'Norm'时,仅支持ND内存排布格式。当传入'PA_NZ'时,仅支持NZ内存排布格式,默认值为'PA_NZ'。
输出说明:
- key_cache(Tensor):表示Key输出到kv_cache中的Tensor(本质in-place更新)。
- value_cache(Tensor):表示value输出到kv_cache中的Tensor(本质in-place更新)。
约束说明:
- 输入参数不支持非连续;
- key、value、key_cache、value_cache的数据类型必须一致;
- slot_mapping、compress_lens、compress_seq_offset、seq_lens的数据类型必须一致;
- slot_mapping的值范围[0,num_blocks*block_size-1],且slot_mapping内的元素值保证不重复,重复时不保证正确性;
- 当key和value都是3维,则key和value的前两维shape必须相同;
- 当key和value都是4维,则key和value的前三维shape必须相同,且key_cache和value_cache的第三维必须是1;
- 当key和value是4维时,compress_lens、seq_lens为必选参数;当key和value是3维时,compress_lens、compress_seq_offsets、seq_lens为可选参数;
- 当key和value都是4维时,slot_mapping是二维,且slot_mapping的第一维值等于key的第一维为batch,slot_mapping的第二维值等于key的第三维为num_head(对应场景三);
- 当key和value都是4维时,seq_lens是一维,且seq_lens的值等于key的第一维为batch(对应场景三);
- seq_lens和compress_lens里面的每个元素值必须满足公式:reduceSum(seq_lens[i] - compress_lens[i]) <= num_blocks * block_size (对应场景三);
支持的型号:
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
调用示例:
# 单算子调用方式
import numpy as np
import torch
import torch_npu
# 生成随机数据, 并发送到npu
bs = 16
num_head = 4
k_head_size = 32
v_head_size = 64
num_blocks = 2
lastDim_k = 16
block_size = 32
key = np.random.randn(bs, num_head, k_head_size).astype(np.float16)
value = np.random.randn(bs, num_head, v_head_size).astype(np.float16)
key_cache = np.random.randn(
num_blocks, num_head * k_head_size // lastDim_k, block_size, lastDim_k).astype(np.float16)
value_cache = np.zeros(
(num_blocks, num_head * v_head_size // lastDim_k, block_si ze, lastDim_k)).astype(np.float16)
slot_mapping = np.random.choice(num_blocks * block_size, bs, replace=False).astype(np.int32)
key_npu = torch.from_numpy(key).npu()
value_npu = torch.from_numpy(value).npu()
key_cache_npu = torch.from_numpy(key_cache).npu()
value_cache_npu = torch.from_numpy(value_cache).npu()
key_cache_npu_cast = torch_npu.npu_format_cast(key_cache_npu.contiguous(), 29)
value_cache_npu_cast = torch_npu.npu_format_cast(value_cache_npu.contiguous(), 29)
slot_mapping_npu = torch.from_numpy(slot_mapping).npu()
# 调用ScatterPaKvCache算子
torch_npu.npu_scatter_pa_kv_cache(key_npu, value_npu, key_cache_npu_cast, value_cache_npu_cast, slot_mapping_npu)
# 执行上述代码,打印key_cache_npu_cast或value_cache_npu_cast输出类似如下
tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
...,
[[ 1.8271, 1.4551, 1.3154, ..., 1.9854, 1.4365, 1.0732],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]],
...,
[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 1.9492, 1.6455, 1.6504, ..., 1.5957, 1.6201, 1.4385],
[ 0.0742, 0.1982, 0.8945, ..., 0.4912, 0.6753, 0.1120],
...,
[[ 0.1113, 0.6255, 0.7686, ..., 0.0247, 0.2490, 0.6909],
[ 0.4312, 0.7954, 0.7339, ..., 0.1154, 0.6440, 0.3342],
[ 0.9570, 0.2869, 0.6489, ..., 0.7451, 0.0234, 0.8843]],
...,
[[ 1.8271, 1.4551, 1.3154, ..., 1.9854, 1.4365, 1.0732],
[ 1.9492, 1.6455, 1.6504, ..., 1.5957, 1.6201, 1.4385],
[ 0.0742, 0.1982, 0.8945, ..., 0.4912, 0.6753, 0.1120]]]]
device='npu:0', dtype=torch.float16)
# 入图方式
import numpy as np
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
bs = 16
num_head = 4
k_head_size = 32
v_head_size = 64
num_blocks = 2
lastDim_k = 16
block_size = 32
class Model(torch.nn.Module):
def init(self):
super().init()
def forward(self, key, value, slot_mapping, key_cache, value_cache):
torch_npu.npu_scatter_pa_kv_cache(key, value, key_cache, value_cache, slot_mapping)
if name=="main":
torch_npu.npu.set_device(0)
key = np.random.randn(bs, num_head, k_head_size).astype(np.float16)
value = np.random.randn(bs, num_head, v_head_size).astype(np.float16)
key_cache = np.random.randn(
num_blocks, num_head * k_head_size // lastDim_k, block_size, lastDim_k).astype(np.float16)
value_cache = np.zeros(
(num_blocks, num_head * v_head_size // lastDim_k, block_size, lastDim_k)).astype(np.float16)
slot_mapping = np.random.choice(num_blocks * block_size, bs, replace=False).astype(np.int32)
key_npu = torch.from_numpy(key).npu()
value_npu = torch.from_numpy(value).npu()
key_cache_npu = torch.from_numpy(key_cache).npu()
value_cache_npu = torch.from_numpy(value_cache).npu()
key_cache_npu_cast = torch_npu.npu_format_cast(key_cache_npu.contiguous(), 29)
value_cache_npu_cast = torch_npu.npu_format_cast(value_cache_npu.contiguous(), 29)
slot_mapping_npu = torch.from_numpy(slot_mapping).npu()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = Model().npu()
# 图模式调用
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
model(key_npu, value_npu, slot_mapping_npu, key_cache_npu_cast, value_cache_npu_cast)
# 执行上述代码,打印key_cache_npu_cast或value_cache_npu_cast输出类似如下
tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
...,
[[ 1.8271, 1.4551, 1.3154, ..., 1.9854, 1.4365, 1.0732],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]],
...,
[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 1.9492, 1.6455, 1.6504, ..., 1.5957, 1.6201, 1.4385],
[ 0.0742, 0.1982, 0.8945, ..., 0.4912, 0.6753, 0.1120],
...,
[[ 0.1113, 0.6255, 0.7686, ..., 0.0247, 0.2490, 0.6909],
[ 0.4312, 0.7954, 0.7339, ..., 0.1154, 0.6440, 0.3342],
[ 0.9570, 0.2869, 0.6489, ..., 0.7451, 0.0234, 0.8843]],
...,
[[ 1.8271, 1.4551, 1.3154, ..., 1.9854, 1.4365, 1.0732],
[ 1.9492, 1.6455, 1.6504, ..., 1.5957, 1.6201, 1.4385],
[ 0.0742, 0.1982, 0.8945, ..., 0.4912, 0.6753, 0.1120]]]]
device='npu:0', dtype=torch.float16)
"""
)
_add_torch_npu_docstr(
"npu_anti_quant",
"""
功能描述:
算子功能: 对张量x进行反量化操作, 即将整数恢复为浮点数.
计算公式为: anti_quant(x)=quant((x+offset)*scale)
接口原型:
torch_npu.npu_anti_quant(Tensor x, Tensor scale, *, Tensor? offset=None, ScalarType? dst_dtype=None, ScalarType? src_dtype=None) -> Tensor
参数说明:
x: Tensor类型, 即输入参数中的x. 数据格式支持ND, 支持非连续的Tensor. 输入最大支持8维.
Atlas 推理系列产品: 数据类型支持int8.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、int32, 其中int32类型数据的每个值是由8个int4数值拼成.
Atlas A3 训练系列产品: 数据类型支持int8、int32, 其中int32类型数据的每个值是由8个int4数值拼成.
scale: Tensor类型, 反量化参数scale. 仅支持1维Tensor, shape为(n,). 其中n可以为1, 如果n不为1, 当x为int8类型, 必须与输入x的尾轴维度大小相同; 当x为int32类型时, 必须为输入x的尾轴维度大小的8倍. 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、bfloat16.
Atlas A3 训练系列产品: 数据类型支持float32、bfloat16.
offset: Tensor类型, 可选参数, 反量化参数offset. 仅支持1维Tensor, 数据类型和shape必须与scale一致. 数据格式支持ND, 支持非连续的Tensor.
dst_dtype: ScalarType类型, 可选参数, 指定输出的数据类型, 默认值为float16.
Atlas 推理系列产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16.
src_dtype: ScalarType类型, 可选参数, 指定源输入的数据类型, 默认值为int8.
Atlas 推理系列产品: 数据类型支持int8.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持quint4x2或int8.
Atlas A3 训练系列产品: 数据类型支持quint4x2或int8.
输出说明:
一个Tensor类型的输出, 代表antiquant的计算结果.
约束说明:
该接口支持推理、训练场景下使用.
该接口支持图模式.
x、scale这两个输入中不能含有None.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
x_tensor = torch.tensor([1,2,3,4], dtype=torch.int8).npu()
scale = torch.tensor([2.0], dtype=torch.float).npu()
offset = torch.tensor([2.0], dtype=torch.float).npu()
out = torch_npu.npu_anti_quant(x_tensor, scale, offset=offset, dst_dtype=torch.float16)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = 'pbtxt'
npu_backend = tng.get_npu_backend(compiler_config=config)
x_tensor = torch.tensor([1,2,3,4], dtype=torch.int8).npu()
scale = torch.tensor([2.0], dtype=torch.float).npu()
offset = torch.tensor([2.0], dtype=torch.float).npu()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self,x,scale,offset):
return torch_npu.npu_anti_quant(x, scale, offset=offset, dst_dtype=torch.float16)
cpu_model = Model()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
output = model(x_tensor,scale,offset)
"""
)
_add_torch_npu_docstr(
"npu_mm_all_reduce_base",
"""
功能描述:
TP切分场景下, 实现mm和all_reduce的融合, 融合算子内部实现计算和通信流水并行.
使用该接口时, 请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本, 否则将会引发报错, 比如BUS ERROR等.
接口原型:
torch_npu.npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None Tensor? pertoken_scale=None, Tensor? comm_quant_scale_1=None, Tensor? comm_quant_scale_2=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor
参数说明:
x1: Tensor类型, 数据类型支持int8、float16、bfloat16. 数据格式支持ND, 输入shape支持2维或者3维.
x2: Tensor类型, 数据类型支持float16、int8、bfloat16, 数据格式支持NZ(昇腾亲和排布格式)、ND. 非量化场景, 数据类型需要和x1保持一致, 输入shape维度第0维和x1的最后一维保持一致.
hcom: String类型, 通信域handle名, 通过get_hccl_comm_name接口获取.
*: 代表其之前的变量是位置相关, 按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
reduce_op: String类型, reduce操作类型, 当前版本仅支持'sum', 默认值: 'sum'.
bias: Tensor类型, 可选输入, 数据类型支持int32、float16、bfloat16, 数据格式支持ND. bias当前仅支持一维, 且维度大小与output/x2的最后一维大小相同.
antiquant_scale: Tensor类型, 可选输入, 伪量化场景对x2进行去量化的系数, 数据类型支持float16、bfloat16, 数据格式支持ND. 伪量化场景数据类型需要和x1保持一致.
per-tensor场景: shape为[1].
per-channel场景: shape为[1,n]或者[n], n为x2最后一维的大小.
per-group场景: shape为[ceil(k, antiquant_group_size), n]. 其中k为x2第一维的大小, n为x2最后一维的大小, antiquant_group_size为伪量化场景对输入x2进行反量化计算的groupSize输入.
ceil(k, antiquant_group_size)的计算逻辑为: (k+antiquant_group_siz-1)/antiquant_group_size, 并对计算结果取整数部分.
antiquant_offset: Tensor类型, 可选输入, 伪量化场景对x2进行去量化的系数, 数据类型支持float16、bfloat16, 数据格式支持ND. 数据类型、shape需要和antiquant_scale保持一致.
x3: Tensor类型, 可选输入, matmul计算后的偏移.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16, 数据格式支持ND. 数据类型、shape需要和输出output保持一致.
dequant_scale: Tensor类型, 可选输入, matmul计算后的去量化系数. 数据类型支持int64、uint64、bfloat16、float32; 数据格式支持ND.
per-tensor场景: shape为[1].
per-channel场景: shape为[n]/[1,n], n为x2最后一维的大小.
pertoken_scale: Tensor类型, 可选输入, matmul计算后的per-token去量化系数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32. 当x1为[m,k]时pertoken_scale shape为[m]; 当x1为[b, s, k]时pertoken_scale shape为[b*s].
comm_quant_scale_1: Tensor类型, 可选输入, alltoall通信前后的量化、去量化系数. 支持float16、bfloat16, 支持ND格式. x2为[k, n]时shape为[1, n]或[n], 用户需保证每张卡上数据保持一致且正确.
comm_quant_scale_2: Tensor类型, 可选输入, allgather通信前后的量化、去量化系数. 支持float16、bfloat16, 支持ND格式. x2为[k, n]时shape为[1, n]或[n], 用户需保证每张卡上数据保持一致且正确.
comm_turn: int类型, 表示rank间通信切分粒度, 默认值: 0, 表示默认的切分方式. 当前版本仅支持输入0.
antiquant_group_size: int类型, 表示伪量化pre-group算法模式下, 对输入x2进行反量化计算的groupSize输入, 描述一组反量化参数对应的待反量化数据量在k轴方向的大小. 当伪量化算法模式不为pre-group时传入0; 当伪量化算法模式为pre-group时传入值的范围为[32, min(k-1, INT_MAX)]且值要求是32的倍数, 其中k为x2第一维的大小. 默认值0, 为0则表示非per-group场景.
输出说明
Tensor类型, 数据类型非量化场景以及伪量化场景与x1保持一致, 全量化场景输出数据类型为float16或bfloat16. shape第0维度和x1的0维保持一致, 若x1为2维, shape第1维度和x2的1维保持一致, 若x1为3维, shape第1维度和x1的1维保持一致, shape第2维度和x2的1维保持一致.
约束说明
增量场景不使能该融合算子, 全量场景使能该融合算子.
该接口支持推理场景下使用.
该接口支持图模式.
输入x1可为2维或者3维、x2必须是2维, 分别为(b, s, k)/(m, k), (k, n), k轴满足mm算子入参要求, k轴相等. bias当前仅支持一维, 且维度大小与output的最后一维大小相同. x3的shape与output的shape相同.
x1不支持输入转置后的tensor, x2转置后输入, 需要满足shape的第一维大小与x1的最后一维相同, 满足matmul的计算条件.
antiquant_group_size中k值的范围与matmul一致, 为[1,65535], INT_MAX大于(k-1).
Atlas A2 训练系列产品/Atlas 800I A2 推理产品:
数据类型支持bfloat16.
x1、x2不支持为空tensor.
支持1、2、4、8卡, 并且仅支持hccs链路all mesh组网.
非量化场景下, m、k、n的取值范围均为[1, 2147483647].
comm_quant_scale_1, comm_quant_scale_2的shape应保持一致, dtype与输出的dtype保持一致, 且只在全量化场景支持.
全量化场景: m取值范围均为[1, 2147483647], x1、x2的最后一维范围为[1, 65535], 即k的取值范围为[1, 65535]、仅当x2(shape=[n,k])为转置时n可以大于65535.
伪量化场景: m取值范围均为[1, 2147483647], k、n的取值范围为[1, 65535].
Atlas A2 训练系列产品: 一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce), 仅支持相同通信域.
在长序列场景, 随着b/s或者m的增大, 可能出现内存不足或者计算超时.
不同场景下数据类型支持情况:
表1 非量化场景产品型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
x1: float16
x2: float16
bias: float16
x3: float16
output(输出): float16
antiquant_scale: None
antiquant_offset: None
dequant_scale: None
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
x1: bfloat16
x2: bfloat16
bias: bfloat16
x3: bfloat16
output(输出): bfloat16
antiquant_scale: None
antiquant_offset: None
dequant_scale: None
表2 伪量化场景产品型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
x1: float16
x2: int8
bias: float16
x3: float16
output(输出): float16
antiquant_scale: float16
antiquant_offset: float16
dequant_scale: None
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
x1: bfloat16
x2: int8
bias: bfloat16
x3: bfloat16
output(输出): bfloat16
antiquant_scale: bfloat16
antiquant_offset: bfloat16
dequant_scale: None
表3 全量化场景产品型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
x1: int8, int8, int8, int8
x2: int8, int8, int8, int8
bias: int32, int32, int32, int32
x3: float16, bfloat16, float16, bfloat16
output(输出): float16, bfloat16, float16, bfloat16
antiquant_scale: None, None, None, None
antiquant_offset: None, None, None, None
dequant_scale: uint64或int64, bfloat16, float32, bfloat16
pertoken_scale: None, None, float32, float32
全量化场景: 若dequant_scale需要以FP32类型传入, 在调用torch_npu.npu_mm_all_reduce_base前, 需通过torch_npu.npu_trans_quant_param接口对dequant_scale进行处理为int64类型(处理方法见对应的接口使用说明).
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.1
PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例:
单算子模式调用
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
input_ = torch.randn(x1_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')
print("output: ", output)
if __name__ == "__main__":
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
图模式调用
非量化、伪量化、全量化使能NZ调用示例如下:
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
class MM_ALLREDUCE_GRAPH_Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
x2=x2,
hcom=hcom,
reduce_op=reduce_op,
bias=bias,
antiquant_scale=antiquant_scale,
antiquant_offset=antiquant_offset,
x3=x3,
dequant_scale=dequant_scale
)
return output_npu
class MM_ALLREDUCE_A8W8_GRAPH_Model(MM_ALLREDUCE_GRAPH_Model):
def __init__(self):
super().__init__()
def forward(self, x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale):
output_npu = torch_npu.npu_mm_all_reduce_base(x1=x1,
x2=x2.t(),
hcom=hcom,
reduce_op=reduce_op,
bias=bias,
antiquant_scale=antiquant_scale,
antiquant_offset=antiquant_offset,
x3=x3,
dequant_scale=dequant_scale
)
return output_npu
def define_model(model, graph_type):
import torchair
if graph_type == 1: # 传统入图模式, 静态shape+在线编译场景
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=False)
elif graph_type == 2: # ACLNN入图模式, 动态shape+二进制
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=True)
else:
print("Error type")
return model
def get_graph(input, weight, hcomm_info, dequant_scale, bias, antiquant_scale, antiquant_offset, x3):
model = MM_ALLREDUCE_A8W8_GRAPH_Model()
model = define_model(model, 2) # 1:静态入图;2:动态入图;
output = model(x1=input, x2=weight, hcom=hcomm_info, reduce_op="sum", bias=bias, antiquant_scale=antiquant_scale,
antiquant_offset=antiquant_offset, x3=x3, dequant_scale=dequant_scale)
return output
def run_mc2_a16w16(x1_shape, x2_shape, hcom_info):
np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.float16)
input = torch.tensor(np_input).npu()
weight = torch.tensor(np_weight).npu()
output_a16w16 = get_graph(input, weight, hcom_info, None, None, None, None, None)
return output_a16w16
def run_mc2_a8w8(x1_shape, x2_shape, hcom_info):
np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.int8)
np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
input = torch.tensor(np_input).npu()
weight = torch.tensor(np_weight).npu()
weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
dequant_scale = torch.randn(x2_shape[0], dtype=torch.float32).uniform_(float(-10), float(10)).npu()
dequant_scale = torch_npu.npu_trans_quant_param(dequant_scale)
output_a8w8 = get_graph(input, weight_nz, hcom_info, dequant_scale, None, None, None, None)
return output_a8w8
def run_mc2_a16w8(x1_shape, x2_shape, hcom_info):
np_input = np.random.uniform(float(-3), float(3), size=x1_shape).astype(np.float16)
np_weight = np.random.uniform(float(-3), float(3), size=x2_shape).astype(np.int8)
input = torch.tensor(np_input).npu()
weight = torch.tensor(np_weight).npu()
weight_nz = torch_npu.npu_format_cast(weight.contiguous(), 29)
antiquant_scale = torch.randn(x2_shape[0], dtype=torch.float16).uniform_(float(-1), float(1)).npu()
antiquant_offset = torch.ones(x2_shape[0], dtype=torch.float16).npu()
output_a16w8 = get_graph(input, weight_nz, hcom_info, None, None, antiquant_scale, antiquant_offset, None)
return output_a16w8
def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, op_type):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
output = None
# 非量化调用
if op_type == "a16w16":
output = run_mc2_a16w16(x1_shape, x2_shape, hcom_info)
# 伪量化调用
if op_type == "a16w8":
output = run_mc2_a16w8(x1_shape, x2_shape, hcom_info)
# 全量化调用
if op_type == "a8w8":
output = run_mc2_a8w8(x1_shape, x2_shape, hcom_info)
print("output:", output)
if __name__ == "__main__":
worksize = 2
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [1280, 5120]
x2_shape = [640, 5120]
op_type = "a16w8" # Options: a16w16, a16w8, a8w8
mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, op_type), nprocs=worksize)
"""
)
_add_torch_npu_docstr(
"npu_ffn",
"""
功能描述:
算子功能: 该FFN算子提供MoeFFN和FFN的计算功能. 在没有专家分组(expert_tokens为空)时是FFN, 有专家分组时是MoeFFN.
计算公式:
out=activation(xW1+b1)W2+b2
激活层为geglu/swiglu/reglu时, 性能使能需要满足门槛要求, 即整网中FFN结构所对应的小算子中vector耗时30us且占比10%以上的用例方可尝试FFN融合算子; 或在不知道小算子性能的情况下, 尝试使能FFN, 若性能劣化则不使能FFN.
接口原型:
torch_npu.npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, int[]? expert_tokens=None, int[]? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None, Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None, Tensor? antiquant_scale1=None, Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None, Tensor? antiquant_offset2=None, int? inner_precise=None, ScalarType? output_dtype=None) -> Tensor
参数说明:
x: Tensor类型, 输入参数, 公式中的x, 数据类型支持float16、bfloat16、int8, 数据格式支持ND, 支持输入的维度最少是2维[M, K1], 最多是8维.
weight1: Tensor类型, 专家的权重数据, 公式中的W1, 数据类型支持float16、bfloat16、int8, 数据格式支持ND, 输入在有/无专家时分别为[E, K1, N1]/[K1, N1].
weight2: Tensor类型, 专家的权重数据, 公式中的W2, 数据类型支持float16、bfloat16、int8, 数据格式支持ND, 输入在有/无专家时分别为[E, K2, N2]/[K2, N2].
M表示token个数, 对应transform中的BS(B: Batch, 表示输入样本批量大小, S: Seq-Length, 表示输入样本序列长度); K1表示第一个matmul的输入通道数, 对应transform中的H(Head-Size, 表示隐藏层的大小); N1表示第一个matmul的输出通道数; K2表示第二个matmul的输入通道数; N2表示第二个matmul的输出通道数, 对应transform中的H; E表示有专家场景的专家数.
expert_tokens: List类型, 可选参数. 代表各专家的token数, 数据类型支持int32, 数据格式支持ND, 若不为空时可支持的最大长度为256个.
expert_tokens_index: List类型, 可选参数. 代表各专家计算token的索引值, 数据类型支持int32, 数据格式支持ND, 若不为空时可支持的最大长度为256个.
bias1: Tensor类型, 可选参数. 权重数据修正值, 公式中的b1, 数据类型支持float16、float32、int32, 数据格式支持ND, 输入在有/无专家时分别为[E, N1]/[N1].
bias2: Tensor类型, 可选参数. 权重数据修正值, 公式中的b2, 数据类型支持float16、float32、int32, 数据格式支持ND, 输入在有/无专家时分别为[E, N2]/[N2].
activation: string类型, 代表使用的激活函数, 即输入参数中的activation. 当前仅支持fastgelu、gelu、relu、silu、geglu、swiglu、reglu.
scale: Tensor类型, 可选参数, 量化参数, 量化缩放系数, 数据类型支持float32, 数据格式支持ND. per-tensor下输入在有/无专家时均为一维向量, 输入元素个数在有/无专家时分别为[E]/[1]; per-channel下输入在有/无专家时为二维向量/一维向量, 输入元素个数在有/无专家时分别为[E, N1]/[N1].
offset: Tensor类型, 可选参数, 量化参数, 量化偏移量, 数据类型支持float32, 数据格式支持ND, 一维向量, 输入元素个数在有/无专家时分别为[E]/[1].
deq_scale1: Tensor类型, 可选参数, 量化参数, 第一组matmul的反量化缩放系数, 数据类型支持int64、float32、bfloat16, 数据格式支持ND, 输入在有/无专家时分别为[E, N1]/[N1].
deq_scale2: Tensor类型, 可选参数, 量化参数, 第二组matmul的反量化缩放系数, 数据类型支持int64、float32、bfloat16, 数据格式支持ND, 输入在有/无专家时分别为[E, N2]/[N2].
antiquant_scale1: Tensor类型, 可选参数, 伪量化参数, 第一组matmul的缩放系数, 数据类型支持float16、bfloat16, 数据格式支持ND, per-channel下输入在有/无专家时分别为[E, N1]/[N1].
antiquant_scale2: Tensor类型, 可选参数, 伪量化参数, 第二组matmul的缩放系数, 数据类型支持float16、bfloat16, 数据格式支持ND, per-channel下输入在有/无专家时分别为[E, N2]/[N2].
antiquant_offset1: Tensor类型, 可选参数, 伪量化参数, 第一组matmul的偏移量, 数据类型支持float16、bfloat16, 数据格式支持ND, per-channel下输入在有/无专家时分别为[E, N1]/[N1].
antiquant_offset2: Tensor类型, 可选参数, 伪量化参数, 第二组matmul的偏移量, 数据类型支持float16、bfloat16, 数据格式支持ND, per-channel下输入在有/无专家时分别为[E, N2]/[N2].
inner_precise: int类型, 可选参数, 表示高精度或者高性能选择. 数据类型支持int64. 该参数仅对float16生效, bfloat16和int8不区分高精度和高性能.
inner_precise为0时, 代表开启高精度模式, 算子内部采用float32数据类型计算.
inner_precise为1时, 代表高性能模式.
inner_precise参数在bfloat16非量化场景, 只能配置为0; float16非量化场景, 可以配置为0或者1; 量化或者伪量化场景, 0和1都可配置, 但是配置后不生效.
output_dtype: ScalarType类型, 可选参数, 该参数只在量化场景生效, 其他场景不生效. 表示输出Tensor的数据类型, 支持输入float16、bfloat16. 默认值为None, 代表输出Tensor数据类型为float16.
输出说明:
一个Tensor类型的输出, 公式中的输出y, 数据类型支持float16、bfloat16, 数据格式支持ND, 输出维度与x一致.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
有专家时, 专家数据的总数需要与x的M保持一致.
激活层为geglu/swiglu/reglu时, 仅支持无专家分组时的float16高性能场景(float16场景指类型为Tensor的必选参数数据类型都为float16的场景), 且N1=2*K2.
激活层为gelu/fastgelu/relu/silu时, 支持有专家或无专家分组的float16高精度及高性能场景, bfloat16场景, 量化场景及伪量化场景, 且N1=K2.
所有场景下需满足K1=N2、K1<65536、K2<65536、M轴在32Byte对齐后小于int32的最大值.
非量化场景不能输入量化参数和伪量化参数, 量化场景不能输入伪量化参数, 伪量化场景不能输入量化参数.
量化场景参数类型: x为int8、weight为int8、bias为int32、scale为float32、offset为float32, 其余参数类型根据y不同分两种情况:
y为float16, deqScale支持数据类型uint64、int64、float32.
y为bfloat16, deqScale支持数据类型bfloat16.
要求deqScale1与deqScale2的数据类型保持一致.
量化场景支持scale的per-channel模式参数类型: x为int8、weight为int8、bias为int32、scale为float32、offset为float32, 其余参数类型根据y不同分两种情况:
y为float16, deqScale支持数据类型uint64、int64.
y为bfloat16, deqScale支持数据类型bfloat16.
要求deqScale1与deqScale2的数据类型保持一致.
伪量化场景支持两种不同参数类型:
y为float16、x为float16、bias为float16、antiquant_scale为float16、antiquant_offset为float16、weight支持数据类型int8.
y为bfloat16、x为bfloat16、bias为float32、antiquant_scale为bfloat16、antiquant_offset为bfloat16、weight支持数据类型int8.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11.0
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import logging
import os
cpu_x = torch.randn((1, 1280), device='npu', dtype=torch.float16)
cpu_weight1 = torch.randn(1280, 10240, device='npu', dtype=torch.float16)
cpu_weight2 = torch.randn(10240, 1280, device='npu', dtype=torch.float16)
activation = "fastgelu"
npu_out = torch_npu.npu_ffn(cpu_x.npu(), cpu_weight1.npu(), cpu_weight2.npu(), activation, inner_precise=1)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
import os
os.environ["ENABLE_ACLNN"] = "true"
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight1, weight2, activation, expert):
return torch_npu.npu_ffn(x, weight1, weight2, activation, expert_tokens=expert, inner_precise=1)
cpu_model = MyModel()
cpu_x = torch.randn((1954, 2560),device='npu',dtype=torch.float16)
cpu_weight1 = torch.randn((16, 2560, 5120),device='npu',dtype=torch.float16)
cpu_weight2 = torch.randn((16, 5120, 2560),device='npu',dtype=torch.float16)
activation = "fastgelu"
expert = [227, 62, 78, 126, 178, 27, 122, 1, 19, 182, 166, 118, 66, 217, 122, 243]
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=True)
npu_out = model(cpu_x.npu(), cpu_weight1.npu(), cpu_weight2.npu(), activation, expert)
"""
)
_add_torch_npu_docstr(
"npu_incre_flash_attention",
"""
功能描述:
增量FA实现, 实现对应公式:
atten_out=softmax(scale*(query*key)+atten_mask)*value
接口原型:
torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor
参数说明:
query: Tensor类型, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
key: Tensor类型, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16、bfloat16、int8.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
value: Tensor类型, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16、int8.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
*: 代表其之前的变量是位置相关, 需要按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
padding_mask: Tensor类型, 预留参数, 暂未使用, 默认值为None.
pse_shift: Tensor类型, 表示在attention结构内部的位置编码参数, 数据格式支持ND. 如不使用该功能时可不传或传入None.
Atlas 推理系列加速卡产品: 仅支持None.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
atten_mask: Tensor类型, 取值为1代表该位不参与计算(不生效), 为0代表该位参与计算, 默认值为None, 即全部参与计算; 数据类型支持bool、int8、uint8, 数据格式支持ND.
actual_seq_lengths: int型数组, 其shape为(B,)或(1,), 形如[1, 2, 3], 代表key、value中有效的S序列长度, 默认值为None, 即全部有效, 类型为List int; 数据类型为int64, 数据格式支持ND.
dequant_scale1: Tensor类型, 数据类型支持float32, 数据格式支持ND, 表示BMM1后面反量化的量化因子, 支持per-tensor(scalar). 如不使用该功能时可不传或传入None. Atlas 推理系列加速卡产品暂不使用该参数.
quant_scale1: Tensor类型, 数据类型支持float32, 数据格式支持ND, 表示BMM2前面量化的量化因子, 支持per-tensor(scalar). 如不使用该功能时可不传或传入None. Atlas 推理系列加速卡产品暂不使用该参数.
dequant_scale2: Tensor类型, 数据类型支持float32, 数据格式支持ND, 表示BMM2后面反量化的量化因子, 支持per-tensor(scalar). 如不使用该功能时可不传或传入None. Atlas 推理系列加速卡产品暂不使用该参数.
quant_scale2: Tensor类型, 数据格式支持ND, 表示输出量化的量化因子, 支持per-tensor(scalar)和per-channel(list). 如不使用该功能时可不传或传入None.
Atlas 推理系列加速卡产品: 当前版本不支持.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、bfloat16.
quant_offset2: Tensor类型, 数据格式支持ND, 表示输出量化的量化偏移, 支持per-tensor(scalar)和per-channel(list). 如不使用该功能时可不传或传入None.
Atlas 推理系列加速卡产品: 当前版本不支持.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、bfloat16.
antiquant_scale: Tensor类型, 数据格式支持ND, 表示量化因子, 支持per-channel(list), 由shape决定, BNSD场景下shape为(2, N, 1, D), BSH场景下shape为(2, H), BSND场景下shape为(2, N, D). 如不使用该功能时可不传或传入None.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
antiquant_offset: Tensor类型, 数据格式支持ND, 表示量化偏移, 支持per-channel(list), 由shape决定, BNSD场景下shape为(2, N, 1, D), BSH场景下shape为(2, H), BSND场景下shape为(2, N, D). 如不使用该功能时可不传或传入None.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16.
block_table: Tensor类型, 数据类型支持int32, 数据格式支持ND. block_table为2维Tensor, 表示PageAttention中KV存储使用的block映射表, 具体约束和使用方法可见约束说明. 如不使用该功能时可不传或传入None.
kv_padding_size: Tensor类型, 数据类型支持int64, 数据格式支持ND, 表示kv左padding场景使能时, 最后一个有效token到S的距离. 如不使用该功能时可传入None.
num_heads: int类型, 代表query的头数, 即query的N, 默认值为1; 数据类型为int64.
scale_value: float类型, 代表缩放系数, 用来约束梯度, 其默认值为1.0, 典型值为$\frac{1}{\sqrt{D}}$; 数据类型为float32.
input_layout: 字符串类型, 代表query、key、value的布局, 根据输入的query、key、value的shape确定, 三维Tensor是BSH, 四维Tensor是BNSD或BSND, 默认值为BSH, 不支持其他值; 数据类型为string.
query、key、value数据排布格式支持从多种维度解读, 其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸, 且满足D=H/N.
num_key_value_heads: int类型, 代表key、value的头数, 用于支持GQA(Grouped-Query Attention, 分组查询注意力)场景, 默认值为0, 表示与query的头数相同, 否则表示key、value的头数, 且num_heads需要能被num_key_value_heads整除; num_heads与num_key_value_heads的比值不能大于64. 数据类型为int64.
block_size: int类型, PageAttention中KV存储每个block中最大的token个数, 默认为0, 通常为128、256等值, 数据类型支持int64.
inner_precise: int类型, 代表高精度/高性能选择, 0代表高精度, 1代表高性能, 默认值为1(高性能), 数据类型支持int64.
输出说明:
atten_out: Tensor类型, 计算的最终结果, shape与query保持一致.
非量化场景下, 输出数据类型与query的数据类型保持一致.
量化场景下, 若传入quant_scale2, 则输出数据类型为int8.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
query、key、value的维度必须保持一致, key、value的shape必须保持一致.
num_heads的值要等于query的N.
input_layout的值与query的shape相关, 三维是BSH, 四维是BNSD或BSND.
num_key_value_heads的值要等于key、value的N, 且num_heads需要能被num_key_value_heads整除.
query, key, value输入, 功能使用限制如下:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品支持B轴小于等于65535, 支持N轴小于等于256, 支持S轴小于等于262144, 支持D轴小于等于512.
Atlas 推理系列加速卡产品支持B轴小于等于256, 支持N轴小于等于256, 支持S轴小于等于65536, 支持D轴小于等于512.
query、key、value输入均为int8的场景暂不支持.
int8量化相关入参数量与输入、输出数据格式的综合限制:
query、key、value输入为float16, 输出为int8的场景: 入参quant_scale2必填, quant_offset2可选, 不能传入dequant_scale1、quant_scale1、dequant_scale2(即为None)参数.
pse_shift功能使用限制如下:
pse_shift数据类型需与query数据类型保持一致.
仅支持D轴对齐, 即D轴可以被16整除.
page attention使用限制:
page attention使能必要条件是block_table存在且有效, 且传入每个batch对应的actual_seq_lengths. page attention使能场景下, key、value是按照block_table中的索引在一片连续内存中排布, 支持key、value数据类型为float16、bfloat16、int8.
page attention使能场景下, 输入kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)或(blocknum, blocksize, H), blocknum不应小于每个batch所需block个数的总和. 通常情况下, kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)时, 性能比kv cache排布格式为(blocknum, blocksize, H)时更好.
page attention使能场景下, 支持kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims), 但此时query layout仅支持BNSD.
page attention使能场景下, 当输入kv cache排布格式为(blocknum, blocksize, H), 且H(H=numKvHeads * headDims)超过64k时, 受硬件指令约束, 会被拦截报错.
page attention场景下, 必须传入输入actual_seq_lengths, 每个batch的actualSeqLength表示每个batch对sequence真实长度, 该值除以属性输入blocksize即表示每个batch所需block数量.
page attention场景下, block_table必须为二维Tensor, 第一维长度需等于batch数, 第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actual_seq_lengths对应的block数量). 例如, batch数为2, 属性blocksize=128, 当每个batch的actualSeqLength为512时, 表明每个batch至少需要4个block, 因此block_table的排布可以为(2, 4).
page attention使能场景下, block_size是用户自定义的参数, 该参数的取值会影响page attention的性能, 通常为128或256. key、value输入类型为float16、bfloat16时block_size需要16对齐; key、value输入类型为int8时block_size需要32对齐. 通常情况下, page attention可以提高吞吐量, 但会带来性能上的下降.
quant_scale2、quant_offset2为一组参数, 其中quant_offset2可选, 传入该组参数后算子输出数据类型会推导为int8, 若不期望int8输出, 请勿传入该组参数.
kv左padding场景使用限制:
kvCache的搬运起点计算公式为: Smax-kv_padding_size-actual_seq_lengths. kvCache的搬运终点计算公式为: Smax-kv_padding_size. 其中kvCache的搬运起点或终点小于0时, 返回数据结果为全0.
kv左padding场景kv_padding_size小于0时将被置为0.
kv左padding场景使能需要同时存在kv_padding_size和actual_seq_lengths参数, 否则默认为kv右padding场景.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列加速卡产品
调用示例:
单算子调用
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
# 调用IFA算子
out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale)
# 执行上述代码的输出类似如下
tensor([[[ 0.3149, -0.2460, 0.7939, ..., 0.5737, -0.4929, -0.1500]],
[[ 0.8115, 1.3789, 0.6484, ..., -0.9092, -0.6206, -0.7412]]],
device='npu:0', dtype=torch.float16)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
atten = torch.randn(2, 1, 1, 2048).bool().npu()
scale_value = 1/math.sqrt(128.0)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
print("single op output with mask:", single_op, single_op.shape)
print("graph output with mask:", graph_output, graph_output.shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[ 0.2488, -0.6572, 1.0928, ..., 0.1694, 0.1142, -2.2266]],
[[-0.9595, -0.9609, -0.6602, ..., 0.7959, 1.7920, 0.0783]]],
device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
graph output with mask: tensor([[[ 0.2488, -0.6572, 1.0928, ..., 0.1694, 0.1142, -2.2266]],
[[-0.9595, -0.9609, -0.6602, ..., 0.7959, 1.7920, 0.0783]]],
device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
"""
)
_add_torch_npu_docstr(
"npu_prompt_flash_attention",
"""
功能描述:
全量FA实现, 实现对应公式:
atten_out=softmax(scale*(Q*K)+atten_mask)*V
接口原型:
torch_npu.npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, padding_mask=None, Tensor? atten_mask=None, int[]? actual_seq_lengths=None, Tensor? deq_scale1=None, Tensor? quant_scale1=None, Tensor? deq_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147483647, int next_tokens=0, str input_layout="BSH", int num_key_value_heads=0, int[]? actual_seq_lengths_kv=None, int sparse_mode=0) -> Tensor
参数说明:
query、key、value数据排布格式支持从多种维度解读, 其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸, 且满足D=H/N、T表示所有Batch输入样本序列长度的累加和.
query: Tensor类型, 公式中的输入Q, 数据类型与key的数据类型需满足数据类型推导规则, 即保持与key、value的数据类型一致. 不支持非连续的Tensor, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8.
key: Tensor类型, 公式中的输入K, 数据类型与query的数据类型需满足数据类型推导规则, 即保持与query、value的数据类型一致. 不支持非连续的Tensor, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8.
value: Tensor类型, 公式中的输入V, 数据类型与query的数据类型需满足数据类型推导规则, 即保持与query、key的数据类型一致. 不支持非连续的Tensor, 数据格式支持ND.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8.
*: 代表其之前的变量是位置相关, 需要按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
pse_shift: Tensor类型, 可选参数. 不支持非连续的Tensor, 数据格式支持ND. 输入shape类型需为(B, N, Q_S, KV_S)或(1, N, Q_S, KV_S), 其中Q_S为query的shape中的S, KV_S为key和value的shape中的S. 对于pse_shift的KV_S为非32字节对齐的场景, 建议padding到32字节来提高性能, 多余部分的填充值不做要求. 如不使用该功能时可传入None. 综合约束请见约束说明.
Atlas 推理系列加速卡产品: 暂不支持该参数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16. 当pse_shift为float16时, 要求query为float16或int8; 当pse_shift为bfloat16时, 要求query为bfloat16. 在query、key、value为float16且pse_shift存在的情况下, 默认走高精度模式.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16. 当pse_shift为float16时, 要求query为float16或int8; 当pse_shift为bfloat16时, 要求query为bfloat16. 在query、key、value为float16且pse_shift存在的情况下, 默认走高精度模式.
padding_mask: 预留参数, 暂未使用, 默认值为None.
atten_mask: Tensor类型, 代表下三角全为0上三角全为负无穷的倒三角mask矩阵, 数据类型支持bool、int8和uint8. 数据格式支持ND, 不支持非连续的Tensor. 如果不使用该功能可传入None. 通常建议shape输入(Q_S, KV_S)、(B, Q_S, KV_S)、(1, Q_S, KV_S)、(B, 1, Q_S, KV_S)、(1, 1, Q_S, KV_S), 其中Q_S为query的shape中的S, KV_S为key和value的shape中的S, 对于attenMask的KV_S为非32字节对齐的场景, 建议padding到32字节对齐来提高性能, 多余部分填充成1. 综合约束请见7.2.1.79-约束说明.
actual_seq_lengths: int类型数组, 代表不同Batch中query的有效seqlen, 数据类型支持int64. 如果不指定seqlen可以传入None, 表示和query的shape的s长度相同. 限制: 该入参中每个batch的有效Sequence Length应该不大于query中对应batch的seqlen. seqlen的传入长度为1时, 每个Batch使用相同seqlen; 传入长度大于等于Batch数时取seqlen的前Batch个数. 其它长度不支持.
Atlas 推理系列加速卡产品: 暂不支持该参数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持TND格式. 当query的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
Atlas A3 训练系列产品: 支持TND格式. 当query的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
deq_scale1: Tensor类型, 表示BMM1后面的反量化因子, 支持per-tensor. 数据类型支持uint64、float32, 数据格式支持ND. 如不使用该功能时可传入None. Atlas 推理系列加速卡产品暂不支持该参数.
quant_scale1: Tensor类型, 数据类型支持float32. 数据格式支持ND, 表示BMM2前面的量化因子, 支持per-tensor. 如不使用该功能时可传入None. Atlas 推理系列加速卡产品暂不支持该参数.
deq_scale2: Tensor类型, 数据类型支持uint64、float32. 数据格式支持ND, 表示BMM2后面的反量化因子, 支持per-tensor. 如不使用该功能时可传入None. Atlas 推理系列加速卡产品暂不支持该参数.
quant_scale2: Tensor类型, 数据格式支持ND, 表示输出的量化因子, 支持per-tensor、per-channel. 如不使用该功能时可传入None.
Atlas 推理系列加速卡产品: 暂不支持该参数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、bfloat16. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32 . per-channel格式, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于N*D(建议输出layout为BSH时, quant_scale2 shape传入(1, 1, H)或(H,); 输出为BNSD时, 建议传入(1, N, 1, D)或(N, D); 输出为BSND时, 建议传入(1, 1, N, D)或(N, D)).
Atlas A3 训练系列产品: 数据类型支持float32、bfloat16. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32 . per-channel格式, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于N*D(建议输出layout为BSH时, quant_scale2 shape传入(1, 1, H)或(H,); 输出为BNSD时, 建议传入(1, N, 1, D)或(N, D); 输出为BSND时, 建议传入(1, 1, N, D)或(N, D)).
quant_offset2: Tensor类型, 数据格式支持ND, 表示输出的量化偏移, 支持per-tensor、per-channel. 若传入quant_offset2, 需保证其类型和shape信息与 quant_scale2一致. 如不使用该功能时可传入None.
Atlas 推理系列加速卡产品: 暂不支持该参数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float32、bfloat16.
Atlas A3 训练系列产品: 数据类型支持float32、bfloat16.
num_heads: int类型数组, 代表query的head个数, 数据类型支持int64.
scale_value: 浮点型, 公式中d开根号的倒数, 代表缩放系数, 作为计算流中Muls的scalar值, 数据类型支持float. 数据类型与query的数据类型需满足数据类型推导规则. 用户不特意指定时可传入默认值1.0.
pre_tokens: int类型, 用于稀疏计算, 表示attention需要和前几个Token计算关联, 数据类型支持int64. 用户不特意指定时可传入默认值2147483647. Atlas 推理系列加速卡产品仅支持默认值2147483647.
next_tokens: int类型, 用于稀疏计算, 表示attention需要和后几个Token计算关联. 数据类型支持int64. 用户不特意指定时可传入默认值0. Atlas 推理系列加速卡产品仅支持0和2147483647.
input_layout: 字符串类型, 用于标识输入query、key、value的数据排布格式, 当前支持BSH、BSND、BNSD、BNSD、BNSD_BSND(输入为BNSD时, 输出格式为BSND). 用户不特意指定时可传入默认值"BSH". 支持TND(不支持pse、全量化、后量化).
num_key_value_heads: int类型, 代表key、value中head个数, 用于支持GQA(Grouped-Query Attention, 分组查询注意力)场景, 数据类型支持int64. 用户不特意指定时可传入默认值0, 表示key/value和query的head个数相等. 限制: 需要满足num_heads整除num_key_value_heads, num_heads与num_key_value_heads的比值不能大于64, 且在BSND、BNSD、BNSD_BSND场景下, 需要与shape中的key/value的N轴shape值相同, 否则报错. Atlas 推理系列加速卡产品仅支持默认值0.
actual_seq_lengths_kv: int类型数组, 代表不同batch中key/value的有效seqlenKV. 数据类型支持int64. 限制: 该入参中每个batch的有效seqlenKV应该不大于key/value中对应batch的seqlenKV. seqlenKV的传入长度为1时, 每个Batch使用相同seqlenKV; 传入长度大于等于Batch数时取seqlenKV的前Batch个数, 其它长度不支持.
Atlas 推理系列加速卡产品: 暂不支持该参数.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持TND格式. 当key/value的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
Atlas A3 训练系列产品: 支持TND格式. 当key/value的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
sparse_mode: int类型, 表示sparse的模式, 数据类型支持int64. Atlas 推理系列加速卡产品仅支持默认值0.
sparse_mode为0时, 代表defaultMask模式, 如果atten_mask未传入则不做mask操作, 忽略preTokens和nextTokens(内部赋值为INT_MAX); 如果传入, 则需要传入完整的atten_mask矩阵(S1 * S2), 表示pre_tokens和next_tokens之间的部分需要计算.
sparse_mode为1时, 代表allMask.
sparse_mode为2时, 代表leftUpCausal模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为3时, 代表rightDownCausal模式的mask, 均对应以左顶点为划分的下三角场景, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为4时, 代表band模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为5、6、7、8时, 分别代表prefix、global、dilated、block_local, 均暂不支持. 用户不特意指定时可传入默认值0.
输出说明
atten_out: Tensor类型, 计算的最终结果, shape与query保持一致.
Atlas 推理系列加速卡产品: 数据类型支持float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
该接口与PyTorch配合使用时, 需要保证CANN相关包与PyTorch相关包的版本匹配.
入参为空的处理: 算子内部需要判断参数query是否为空, 如果是空则直接返回. 参数query不为空Tensor, 参数key、value为空tensor(即S2为0), 则填充全零的对应shape的输出(填充attention_out). attention_out为空Tensor时, AscendCLNN框架会处理.
query、key、value输入, 功能使用限制如下:
轴约束
Atlas 推理系列加速卡产品: 支持B轴小于等于128. 支持N轴小于等于256. 支持S轴小于等于65535(64k). 支持D轴小于等于512.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品&Atlas A3 训练系列产品:
{支持B轴小于等于65536(64k), D轴32byte不对齐时仅支持到128.
支持N轴小于等于256.
S支持小于等于20971520(20M). 长序列场景下, 如果计算量过大可能会导致PFA算子执行超时(aicore error类型报错, errorStr为timeout or trap error), 此场景下建议做S切分处理, 注: 这里计算量会受B、S、N、D等的影响, 值越大计算量越大. 典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
B=1, Q_N=20, Q_S=1048576, D = 256, KV_N=1, KV_S=1048576.
B=1, Q_N=2, Q_S=10485760, D = 256, KV_N=2, KV_S=10485760.
B=20, Q_N=1, Q_S=1048576, D = 256, KV_N=1, KV_S=1048576.
B=1, Q_N=10, Q_S=1048576, D = 512, KV_N=1, KV_S=1048576.
支持D轴小于等于512. input_layout为BSH或者BSND时, 要求N*D小于65535.
TND场景下query, key, value输入的综合限制:
B=1, Q_N=20, Q_S=1048576, D = 256, KV_N=1, KV_S=1048576.
T小于等于65536;
N等于8/16/32/64/128, 且Q_N、K_N、V_N相等;
Q_D、K_D等于192, V_D等于128/192;
数据类型仅支持BFLOAT16;
sparse模式仅支持sparse=0且不传mask, 或sparse=3且传入mask;
当sparse=3时, 要求每个batch单独的actualSeqLengths < actualSeqLengthsKv. }
参数sparse_mode当前仅支持值为0、1、2、3、4的场景, 取其它值时会报错.
sparse_mode=0时, atten_mask如果为None, 则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX).
sparse_mode=2、3、4时, atten_mask的shape需要为(S, S)或(1, S, S)或(1, 1, S, S), 其中S的值需要固定为2048, 且需要用户保证传入的atten_mask为下三角, 不传入atten_mask或者传入的shape不正确报错.
sparse_mode=1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值.
int8量化相关入参数量与输入、输出数据格式的综合限制:
输入为int8, 输出为int8的场景: 入参deq_scale1、quant_scale1、deq_scale2、quant_scale2需要同时存在, quant_offset2可选, 不传时默认为0.
输入为int8, 输出为float16的场景: 入参deq_scale1、quant_scale1、deq_scale2需要同时存在, 若存在入参quant_offset2或quant_scale2(即不为None), 则报错并返回.
输入为float16或bfloat16, 输出为int8的场景: 入参quant_scale2需存在, quant_offset2可选, 不传时默认为0, 若存在入参deq_scale1或quant_scale1或deq_scale2(即不为None), 则报错并返回.
入参quant_offset2和quant_scale2支持per-tensor/per-channel两种格式和float32/bfloat16两种数据类型. 若传入quant_offset2, 需保证其类型和shape信息与quant_scale2一致. 当输入为bfloat16时, 同时支持float32和bfloat16, 否则仅支持float32. per-channel格式, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于N*D. 当输出layout为BSH时, quant_scale2 shape建议传入(1, 1, H)或(H,); 当输出为BNSD时, 建议传入(1, N, 1, D)或(N, D); 当输出为BSND时, 建议传入(1, 1, N, D)或(N, D). per-tensor格式, 建议D轴对齐到32Byte.
per-channel格式, 入参quant_scale2和quant_offset2暂不支持左padding、Ring Attention或者D非32Byte对齐的场景.
输出为int8时, 暂不支持sparse为band且pre_tokens/next_tokens为负数.
pse_shift功能使用限制如下:
支持query数据类型为float16或bfloat16或int8场景下使用该功能.
query, key, value数据类型为float16且pse_shift存在时, 强制走高精度模式, 对应的限制继承自高精度模式的限制.
Q_S需大于等于query的S长度, KV_S需大于等于key的S长度.
输出为int8, 入参quant_offset2传入非None和非空tensor值, 并且sparse_mode、pre_tokens和next_tokens满足以下条件, 矩阵会存在某几行不参与计算的情况, 导致计算结果误差, 该场景会拦截:
sparseMode=0, atten_mask如果非None, 每个batch actual_seq_lengths-actual_seq_lengths_kv-pre_tokens>0或nextTokens<0时, 满足拦截条件.
sparseMode=1或2, 不会出现满足拦截条件的情况.
sparseMode=3, 每个batch actual_seq_lengths_kv- actual_seq_lengths<0, 满足拦截条件.
sparseMode= 4, preTokens<0或每个batch next_tokens+actual_seq_lengths_kv-actual_seq_lengths<0时, 满足拦截条件.
kv伪量化参数分离当前暂不支持.
暂不支持D不对齐场景.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
Atlas 推理系列加速卡产品
调用示例:
单算子调用
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
actseqlen = [164]
actseqlenkv = [1024]
# 调用PFA算子
out = torch_npu.npu_prompt_flash_attention(q, k, v,
actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv,
num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出类似如下
tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
print("single op output with mask:", single_op, single_op.shape)
print("graph output with mask:", graph_output, graph_output.shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
"""
)
_add_torch_npu_docstr(
"npu_fused_infer_attention_score",
"""
功能描述:
算子功能: 适配增量&全量推理场景的FlashAttention算子, 既可以支持全量计算场景(PromptFlashAttention), 也可支持增量计算场景(IncreFlashAttention). 当Query矩阵的S为1, 进入IncreFlashAttention分支, 其余场景进入PromptFlashAttention分支.
计算公式:
attention_out = softmax(scale*(query*key)+atten_mask)*value
接口原型:
torch_npu.npu_fused_infer_attention_score(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, SymInt[]? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, Tensor? actual_shared_prefix_len=None,Tensor? query_rope=None, Tensor? key_rope=None, Tensor? key_rope_antiquant_scale=None, int num_heads=1, float scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, bool softmax_lse_flag=False, int key_antiquant_mode=0, int value_antiquant_mode=0) -> (Tensor, Tensor)
参数说明:
query、key、value数据排布格式支持从多种维度解读, 其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸, 且满足D=H/N、T表示所有Batch输入样本序列长度的累加和.
query: Tensor类型, attention结构的Query输入, 数据类型支持float16、bfloat16、int8, 不支持非连续的Tensor, 数据格式支持ND.
key: Tensor类型, attention结构的Key输入, 不支持非连续的Tensor, 数据格式支持ND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8、int4(int32).
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8、int4(int32).
value: Tensor类型, attention结构的Value输入, 不支持非连续的Tensor, 数据格式支持ND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8、int4(int32).
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8、int4(int32).
*: 代表其之前的变量是位置相关, 需要按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
pse_shift: Tensor类型, 在attention结构内部的位置编码参数, 数据类型支持float16、bfloat16, 数据类型与query的数据类型需满足数据类型推导规则. 不支持非连续的Tensor, 数据格式支持ND. 如不使用该功能时可传入None.
Q_S大于1, 要求在pse_shift为float16类型时, 此时的query为float16或int8类型; 而在pse_shift为bfloat16类型时, 要求此时的query为bfloat16类型. 输入shape类型需为(B, Q_N, Q_S, KV_S)或(1, Q_N, Q_S, KV_S), 其中Q_S为query的shape中的S, KV_S为key和value的shape中的S. 对于pse_shift的KV_S为非32对齐的场景, 建议padding到32字节来提高性能, 多余部分的填充值不做要求.
Q_S为1, 要求在pse_shift为float16类型时, 此时的query为float16类型; 而在pse_shift为bfloat16类型时, 要求此时的query为bfloat16类型. 输入shape类型需为(B, Q_N, 1, KV_S)或(1, Q_N, 1, KV_S), KV_S为key和value的shape中的S. 对于pse_shift的KV_S为非32对齐的场景, 建议padding到32字节来提高性能, 多余部分的填充值不做要求.
atten_mask: Tensor类型, 对QK的结果进行mask, 用于指示是否计算Token间的相关性, 数据类型支持bool、int8和uint8. 不支持非连续的Tensor, 数据格式支持ND. 如果不使用该功能可传入None.
sparse_mode为0、1时
支持shape传入(1,Q_S,KV_S)、(B,1,Q_S,KV_S)、(1,1,Q_S,KV_S)。
当输入input_layout为BSH、BSND、BNSD、BNSD_BSND时,且query、key、value的D相等,并且不传query_rope和key_rope时,Q_S为1可支持传入(B,KV_S),Q_S大于1时可支持传入(Q_S,KV_S)。
如果Q_S、KV_S非16或32对齐,可以向上取到对齐的S。综合约束请见约束声明。
sparse_mode为2、3、4时,shape输入支持(2048,2048)或(1,2048,2048)或(1,1,2048,2048)。
actual_seq_lengths: int类型数组, 代表不同Batch中query的有效seqlen, 数据类型支持int64. 如果不指定seqlen可以传入None, 表示和query的shape的s长度相同. 限制: 该入参中每个batch的有效seqlen应该不大于query中对应batch的seqlen, Q_S为1时该参数无效. seqlen的传入长度为1时, 每个Batch使用相同seqlen; 传入长度大于等于Batch时取seqlen的前Batch个数. 其他长度不支持. 当query的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
actual_seq_lengths_kv: int类型数组, 代表不同Batch中key/value的有效seqlenKv, 数据类型支持int64. 如果不指定None, 表示和key/value的shape的S长度相同. 不同O_S值有不同的约束, 具体参见约束说明.
dequant_scale1: Tensor类型, 数据类型支持uint64、float32. 数据格式支持ND, 表示BMM1后面的反量化因子, 支持per-tensor. 如不使用该功能时传入None.
quant_scale1: Tensor类型, 数据类型支持float32. 数据格式支持ND, 表示BMM2前面的量化因子, 支持per-tensor. 如不使用该功能时可传入None, 综合约束请见约束说明.
dequant_scale2: Tensor类型, 数据类型支持uint64、float32. 数据格式支持ND, 表示BMM2后面的反量化因子, 支持per-tensor. 如不使用该功能时传入None.
quant_scale2: Tensor类型, 数据类型支持float32、bfloat16. 数据格式支持ND, 表示输出的量化因子, 支持per-tensor、per-channel. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32 . per-channel格式, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于Q_N*D(建议输出layout为BSH时, quant_scale2shape传入(1, 1, H)或(H,); 输出为BNSD时, 建议传入(1, Q_N, 1, D)或(Q_N, D); 输出为BSND时, 建议传入(1, 1, Q_N, D)或(Q_N, D)). 如不使用该功能时可传入None, 综合约束请见约束说明.
quant_offset2: Tensor类型, 数据类型支持float32、bfloat16. 数据格式支持ND, 表示输出的量化偏移, 支持per-tensor、per-channel. 若传入quant_offset2, 需保证其类型和shape信息与quantScale2 一致. 如不使用该功能时可传入None, 综合约束请见约束说明.
antiquant_scale: Tensor类型, 数据类型支持float16、bfloat16. 数据格式支持ND, 表示伪量化因子, 支持per-tensor、per-channel, Q_S为1时只支持per-channel, Q_S大于等于2时只支持float16, 如不使用该功能时可传入None, 综合约束请见约束说明.
antiquant_offset: Tensor类型, 数据类型支持float16、bfloat16. 数据格式支持ND, 表示伪量化偏移, 支持per-tensor、per-channel, Q_S为1时只支持per-channel, Q_S大于等于2时只支持float16, 如不使用该功能时可传入None, 综合约束请见约束说明.
block_table: Tensor类型, 数据类型支持int32. 数据格式支持ND. 表示PageAttention中KV存储使用的block映射表, 如不使用该功能可传入None.
query_padding_size: Tensor类型, 数据类型支持int64. 数据格式支持ND. 表示Query中每个batch的数据是否右对齐, 且右对齐的个数是多少. 仅支持Q_S大于1, 其余场景该参数无效. 用户不特意指定时可传入默认值None.
kv_padding_size: Tensor类型, 数据类型支持int64. 数据格式支持ND. 表示key、value中每个batch的数据是否右对齐, 且右对齐的个数是多少. 表示key、value中每个batch的数据是否右对齐, 且右对齐的个数是多少. 用户不特意指定时可传入默认值None.
key_antiquant_scale: Tensor类型. 数据格式支持ND, kv伪量化参数分离时表示key的反量化因子. 如不使用该功能时可传入None, 综合约束请见约束说明. 通常支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理scale、per-token叠加per head并使用page attention模式管理scale.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、float32.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、float32.
key_antiquant_offset: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示key的反量化偏移. 支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理offset、per-token叠加per head并使用page attention模式管理offset. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明.
value_antiquant_scale: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示value的反量化因子. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明. 通常支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理scale、per-token叠加per head并使用page attention模式管理scale.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、float32.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、float32.
value_antiquant_offset: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示value的反量化偏移, 支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理offset、per-token叠加per head并使用page attention模式管理offset. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明.
key_shared_prefix: Tensor类型, attention结构中Key的系统前缀部分的参数, 数据类型支持float16、bfloat16、int8, 不支持非连续的Tensor, 数据格式支持ND. 综合约束请见约束说明.
value_shared_prefix: Tensor类型, attention结构中Value的系统前缀部分的输入, 数据类型支持float16、bfloat16、int8, 不支持非连续的Tensor, 数据格式支持ND. 综合约束请见约束说明.
actual_shared_prefix_len: Tensor类型, 代表key_shared_prefix/value_shared_prefix的有效Sequence Length. 数据类型支持: int64. 如果不指定seqlen可以传入None, 表示和key_shared_prefix/value_shared_prefix的s长度相同. 限制: 该入参中的有效Sequence Length应该不大于key_shared_prefix/value_shared_prefix中的Sequence Length.
query_rope: Tensor类型, 表示MLA(Multi-head Latent Attention)结构中的query的rope信息, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND.
key_rope: Tensor类型, 表示MLA(Multi-head Latent Attention)结构中的key的rope信息, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND.
key_rope_antiquant_scale: Tensor类型, 预留参数, 暂未使用, 使用默认值即可. 表示MLA(Multi-head Latent Attention)结构中的key Rope对应的反量化因子, 支持per-channel, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND, D维度与key_rope的D维度保持一致. 仅支持Q_S等于1-16, 其余场景该参数无效.
num_heads: 整型, 代表query的head个数, 数据类型支持int64, 在BNSD场景下, 需要与shape中的query的N轴shape值相同, 否则执行异常.
scale: 浮点型, 公式中d开根号的倒数, 代表缩放系数, 作为计算流中Muls的scalar值, 数据类型支持float. 数据类型与query的数据类型需满足数据类型推导规则. 用户不特意指定时可传入默认值1.0.
pre_tokens: 整型, 用于稀疏计算, 表示attention需要和前几个Token计算关联, 数据类型支持int64. 用户不特意指定时可传入默认值2147483647, Q_S为1时该参数无效.
next_tokens: 整型, 用于稀疏计算, 表示attention需要和后几个Token计算关联. 数据类型支持int64. 用户不特意指定时可传入默认值2147483647, Q_S为1时该参数无效.
input_layout: 字符串类型, 用于标识输入query、key、value的数据排布格式, 用户不特意指定时可传入默认值"BSH".
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持BSH、BSND、BNSD、BNSD_BSND、TND(不支持左padding、tensorlist、pse、page attention、prefix、伪量化、全量化、后量化, 综合约束请见约束说明). 当为TND时, 不支持图模式配置Tiling调度优化功能(tiling_schedule_optimize=True).
Atlas A3 训练系列产品: 支持BSH、BSND、BNSD、BNSD_BSND、TND(不支持左padding、tensorlist、pse、page attention、prefix、伪量化、全量化、后量化, 综合约束请见约束说明). 当为TND时, 不支持图模式配置Tiling调度优化功能(tiling_schedule_optimize=True).
其中BNSD_BSND含义指当输入为BNSD, 输出格式为BSND, 仅支持Q_S大于1.
num_key_value_heads: 整型, 代表key、value中head个数, 用于支持GQA(Grouped-Query Attention, 分组查询注意力)场景, 数据类型支持int64. 用户不特意指定时可传入默认值0, 表示key/value和query的head个数相等, 需要满足num_heads整除num_key_value_heads, num_heads与num_key_value_heads的比值不能大于64. 在BSND、BNSD、BNSD_BSND(仅支持Q_S大于1)场景下, 还需要与shape中的key/value的N轴shape值相同, 否则执行异常.
sparse_mode: 整型, 表示sparse的模式, 默认值为0. 数据类型支持int64. Q_S为1且不带rope输入时该参数无效. 当前仅支持取值0、1、2、3、4、9, 取值5、6、7、8(分别代表prefix、global、dilated、block_local)暂未实现, 请勿使用. input_layout为TND、TND_NTD、NTD_TND时, 综合约束请见约束说明.
sparse_mode为0时, 代表defaultMask模式, 如果atten_mask未传入则不做mask操作, 忽略pre_tokens和next_tokens(内部赋值为INT_MAX); 如果传入, 则需要传入完整的atten_mask矩阵(S1*S2), 表示pre_tokens和next_tokens之间的部分需要计算.
sparse_mode为1时, 代表allMask, 必须传入完整的attenmask矩阵(S1*S2).
sparse_mode为2时, 代表leftUpCausal模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为3时, 代表rightDownCausal模式的mask, 对应以右顶点为划分的下三角场景, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为4时, 代表band模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为9时, 代表treeMask模式, 用于推测解码场景的树形注意力掩码. 需传入自定义tree mask, 仅MLA场景(query_rope和key_rope不为空)支持. 不支持左padding、pse_shift、sharedPrefix, 输出dtype不支持int8, 每个batch需满足Q_S <= KV_S. 综合约束请见约束说明.
inner_precise: 整型, 一共4种模式: 0、1、2、3. 一共两位bit位, 第0位(bit0)表示高精度或者高性能选择, 第1位(bit1)表示是否做行无效修正. 数据类型支持int64. Q_S>1时, sparse_mode为0或1, 并传入用户自定义mask的情况下, 建议开启行无效; Q_S为1时该参数仅支持innerPrecise为0和1. 综合约束请见约束说明.
inner_precise为0时, 代表开启高精度模式, 且不做行无效修正.
inner_precise为1时, 代表高性能模式, 且不做行无效修正.
inner_precise为2时, 代表开启高精度模式, 且做行无效修正.
inner_precise为3时, 代表高性能模式, 且做行无效修正.
bfloat16和int8不区分高精度和高性能, 行无效修正对float16、bfloat16和int8均生效. 当前0、1为保留配置值, 当计算过程中“参与计算的mask部分”存在某整行全为1的情况时, 精度可能会有损失. 此时可以尝试将该参数配置为2或3来使能行无效功能以提升精度, 但是该配置会导致性能下降.
block_size: 整型, PageAttention中KV存储每个block中最大的token个数, 默认为0, 数据类型支持int64.
antiquant_mode: 整型, 表示伪量化方式, 传入0时表示为per-channel(per-channel包含per-tensor), 传入1时表示per-token. Q_S大于等于2时该参数无效, 用户不特意指定时可传入默认值0, 传入0和1之外的其他值会执行异常.
softmax_lse_flag: 布尔型, 表示是否输出softmax_lse, 支持S轴外切(增加输出). true表示输出softmax_lse, false表示不输出; 用户不特意指定时可传入默认值false.
key_antiquant_mode: 整型, 表示key的伪量化方式. Q_S大于等于2时仅支持传入值为1, 用户不特意指定时可传入默认值0, 取值除了key_antiquant_mode为0并且value_antiquant_mode为1的场景外, 需要与value_antiquant_mode一致. 综合约束请见约束说明.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持取值0、1、2、3、4、5.
Atlas A3 训练系列产品: 支持取值0、1、2、3、4、5.
key_antiquant_mode为0时, 代表per-channel模式(per-channel包含per-tensor).
key_antiquant_mode为1时, 代表per-token模式.
key_antiquant_mode为2时, 代表per-tensor叠加per-head模式.
key_antiquant_mode为3时, 代表per-token叠加per-head模式.
key_antiquant_mode为4时, 代表per-token叠加使用page attention模式管理scale/offset模式.
key_antiquant_mode为5时, 代表per-token叠加per head并使用page attention模式管理scale/offset模式.
value_antiquant_mode: 整型, 表示value的伪量化方式, 模式编号与key_antiquant_mode一致. Q_S大于等于2时仅支持传入值为1, 用户不特意指定时可传入默认值0, 取值除了key_antiquant_mode为0并且value_antiquant_mode为1的场景外, 需要与key_antiquant_mode一致. 综合约束请见约束说明.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持取值0、1、2、3、4、5.
Atlas A3 训练系列产品: 支持取值0、1、2、3、4、5.
输出说明
attention_out: Tensor类型, 公式中的输出, 数据类型支持float16、bfloat16、int8. 数据格式支持ND. 限制: 当input_layout为BNSD_BSND时, 输入query的shape是BNSD, 输出shape为BSND; 其余情况该参数的shape需要与入参query的shape保持一致.
softmaxLse: Tensor类型, ring attention算法对query乘key的结果, 先取max得到softmax_max. query乘key的结果减去softmax_max, 再取exp, 最后取sum, 得到softmax_sum, 最后对softmax_sum取log, 再加上softmax_max得到的结果. 数据类型支持float32, softmax_lse_flag为True时, 一般情况下, 输出shape为(B, Q_N, Q_S, 1)的Tensor, 当input_layout为TND时, 输出shape为(T,Q_N,1)的Tensor; softmax_lse_flag为False时, 则输出shape为[1]的值为0的Tensor.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
该接口与PyTorch配合使用时, 需要保证CANN相关包与PyTorch相关包的版本匹配.
入参为空的处理: 算子内部需要判断参数query是否为空, 如果是空则直接返回. 参数query不为空Tensor, 参数key、value为空tensor(即S2为0), 则填充全零的对应shape的输出(填充attention_out). attention_out为空Tensor时, 框架会处理.
参数key、value中对应tensor的shape需要完全一致; 非连续场景下key、value的tensorlist中的batch只能为1, 个数等于query的B, N和D需要相等.
int8量化相关入参数量与输入、输出数据格式的综合限制:
输入为int8, 输出为int8的场景: 入参dequant_scale1、quant_scale1、dequant_scale2、quant_scale2需要同时存在, quant_offset2可选, 不传时默认为0.
输入为int8, 输出为float16的场景: 入参dequant_scale1、quant_scale1、dequant_scale2需要同时存在, 若存在入参quant_offset2或quant_scale2(即不为None), 则报错并返回.
输入全为float16或bfloat16, 输出为int8的场景: 入参quant_scale2需存在, quant_offset2可选, 不传时默认为0, 若存在入参dequant_scale1或quant_scale1或dequant_scale2(即不为None), 则报错并返回.
入参quant_offset2和quant_scale2支持per-tensor或per-channel格式, 数据类型支持float32、bfloat16.
antiquant_scale和antiquant_offset参数约束:
支持per-channel、per-tensor和per-token三种模式:
per-channel模式: 两个参数BNSD场景下shape为(2, KV_N, 1, D), BSND场景下shape为(2, KV_N, D), BSH场景下shape为(2, H). 参数数据类型和query数据类型相同, antiquant_mode置0, 当key、value数据类型为int8时支持.
per-tensor模式: 两个参数的shape均为(2,), 数据类型和query数据类型相同, antiquant_mode置0, 当key、value数据类型为int8时支持.
per-token模式: 两个参数的shape均为(2, B, KV_S), 数据类型固定为float32, antiquant_mode置1, 当key、value数据类型为int8时支持.
算子运行在何种模式根据参数的shape进行判断, dim为1时运行per-tensor模式, 否则运行per-channel模式.
支持对称量化和非对称量化:
非对称量化模式下, antiquant_scale和antiquant_offset参数需同时存在.
对称量化模式下, antiquant_offset可以为空(即None); 当antiquant_offset参数为空时, 执行对称量化, 否则执行非对称量化.
query_rope和key_rope参数约束及支持特性:
query_rope和key_rope要求同时配置或同时不配置, 不支持只配置其中一个.
query_rope的数据类型、数据格式与query一致.
key_rope的数据类型、数据格式与key一致.
sparse: Q_S等于1时支持sparse=0且不传mask或sparse=4且传入mask, Q_S大于1时支持sparse=3或sparse=4且传入mask;
sparse不为4时:
query_rope配置时要求query的S为1-16、N为32、64、128, D为512, shape中B、N、S与query一致, D为64.
key_rope配置时要求key的N为1, D为512, key_rope的shape中B、N、S与key一致, D为64.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持key、value的input_layout格式为ND或NZ. 当input_layout为NZ时, 输入参数key和value的格式为[blockNum, KV_N, D/16, blockSize, 16].
Atlas A3 训练系列产品: 支持key、value的input_layout格式为ND或NZ. 当input_layout为NZ时, 输入参数key和value的格式为[blockNum, KV_N, D/16, blockSize, 16].
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: input_layout形状支持BSH、BSND、BNSD, 当数据格式为NZ时input_layout不支持BNSD.
Atlas A3 训练系列产品: input_layout形状支持BSH、BSND、BNSD, 当数据格式为NZ时input_layout不支持BNSD.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 该场景下, 必须开启PageAttention, 此时block_size支持16、128, 其中数据格式为NZ时block_size不支持配置16.
Atlas A3 训练系列产品: 该场景下, 必须开启PageAttention, 此时block_size支持16、128, 其中数据格式为NZ时block_size不支持配置16.
sparse为4时:
query_rope配置时要求query的每batch的S不大于key的每batch的S、N为128, D为512, query_rope的shape中B、N、S与query一致, D为64.
key_rope配置时要求key的S不大于131088, N为1, D为512, key_rope的shape中B、N、S与key一致, D为64.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 仅支持key、value的input_layout格式为ND.
Atlas A3 训练系列产品: 仅支持key、value的input_layout格式为ND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: input_layout形状仅支持BSND.
Atlas A3 训练系列产品: input_layout形状仅支持BSND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持开启PageAttention, 此时block_size支持64、128.
Atlas A3 训练系列产品: 支持开启PageAttention, 此时block_size支持64、128.
TND场景下query、key、value输入的综合限制:
T小于等于65536;
N等于8/16/32/64/128, 且Q_N、K_N、V_N相等;
Q_D、K_D等于192, V_D等于128/192;
数据类型仅支持BFLOAT16;
sparse模式仅支持sparse=0且不传mask, 或sparse=3且传入mask;
当sparse=3时, 要求每个batch单独的actual_seq_lengths < actual_seq_lengths_kv.
sparse模式支持sparse\_mode=4且传入mask;当sparse\_mode=4时,要求preTokens >= -actual\_seq\_qlen、nextTokens >= -actual\_seq\_kvlen、preTokens + nextTokens >= 0;
当Q_S大于1时:
query、key、value输入, 功能使用限制如下:
支持B轴小于等于65536, D轴32byte不对齐时仅支持到128.
支持N轴小于等于256, 支持D轴小于等于512; input_layout为BSH或者BSND时, 要求N*D小于65535.
S支持小于等于20971520(20M). 部分长序列场景下, 如果计算量过大可能会导致PFA算子执行超时(aicore error类型报错, errorStr为timeout or trap error), 此场景下建议做S切分处理(注: 这里计算量会受B、S、N、D等的影响, 值越大计算量越大), 典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152.
B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520.
B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152.
B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.
query、key、value输入类型包含int8时, D轴需要32对齐; 输入类型全为float16、bfloat16时, D轴需16对齐.
actual_seq_lengths_kv: 该参数传入时应为非负数, 在input_layout不同时, 其含义与拦截条件不同: 一般情况下, 该入参为可选入参, 该入参中每个Batch的有效seqlenKv应该不大于key/value中对应Batch的seqlenKv. 当本参数的传入长度为1时, 每个Batch使用相同seqlenKv; 传入长度大于等于Batch时取seqlenKv的前Batch个数. 其他长度不支持. 当key/value的input_layout为TND时, 该入参必须传入, 且该入参元素的数量等于Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKv和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
参数sparse_mode当前仅支持值为0、1、2、3、4的场景, 取其它值时会报错.
sparse_mode=0时, atten_mask如果为None, 或者在左padding场景传入atten_mask, 则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX).
sparse_mode=2、3、4时, atten_mask的shape需要为(S, S)或(1, S, S)或(1, 1, S, S), 其中S的值需要固定为2048, 且需要用户保证传入的atten_mask为下三角, 不传入atten_mask或者传入的shape不正确报错.
sparse_mode=1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值.
kvCache反量化的合成参数场景仅支持int8反量化到float16. 入参key、value的data range与入参antiquant_scale的data range乘积范围在(-1, 1)内, 高性能模式可以保证精度, 否则需要开启高精度模式来保证精度.
page attention场景:
page attention的使能必要条件是block_table存在且有效, 同时key、value是按照block_table中的索引在一片连续内存中排布, 支持key、value数据类型为float16、bfloat16、int8. 在该场景下key、value的input_layout参数无效. block_table中填充的是blockid, 当前不会对blockid的合法性进行校验, 需用户自行保证.
block_size是用户自定义的参数, 该参数的取值会影响page attention的性能, 在使能page attention场景下, block_size最小为128, 最大为512, 且要求是128的倍数. 通常情况下, page attention可以提高吞吐量, 但会带来性能上的下降.
page attention场景下, 当输入kv cache排布格式为(blocknum, blocksize, H), 且KV_N*D超过65535时, 受硬件指令约束, 会被拦截报错. 可通过使能GQA(减小KV_N)或调整kv cache排布格式为(blocknum, KV_N, blocksize, D)解决. 当query的input_layout为BNSD、TND时, kv cache排布支持(blocknum, blocksize, H)和(blocknum, KV_N, blocksize, D)两种格式, 当query的input_layout为BSH、BSND时, kv cache排布只支持(blocknum, blocksize, H)一种格式. blocknum不能小于根据actual_seq_lengths_kv和blockSize计算的每个batch的block数量之和. 且key和value的shape需保证一致.
page attention不支持伪量化场景, 不支持tensorlist场景, 不支持左padding场景.
page attention场景下, 必须传入actual_seq_lengths_kv.
page attention场景下, block_table必须为二维, 第一维长度需等于B, 第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为不同batch中最大actual_seq_lengths_kv对应的block数量).
page atte两种格式和float32/bfloat1ntion场景下, 不支持输入query为int8的场景.
page attention使能场景下, 以下场景输入需满足KV_S>=maxBlockNumPerSeq*blockSize:
传入attenMask时, 如mask shape为 (B, 1, Q_S, KV_S).
传入pseShift时, 如pseShift shape为(B, Q_N, Q_S, KV_S).
query左padding场景:
query左padding场景query的搬运起点计算公式为: Q_S-query_padding_size-actual_seq_lengths. query的搬运终点计算公式为: Q_S-query_padding_size. 其中query的搬运起点不能小于0, 终点不能大于Q_S, 否则结果将不符合预期.
query左padding场景kv_padding_size小于0时将被置为0.
query左padding场景需要与actual_seq_lengths参数一起使能, 否则默认为query右padding场景.
query左padding场景不支持PageAttention, 不能与block_table参数一起使能.
kv左padding场景:
kv左padding场景key和value的搬运起点计算公式为: KV_S-kv_padding_size-actual_seq_lengths_kv. key和value的搬运终点计算公式为: KV_S-kv_padding_size. 其中key和value的搬运起点不能小于0, 终点不能大于KV_S, 否则结果将不符合预期.
kv左padding场景kv_padding_size小于0时将被置为0.
kv左padding场景需要与actual_seq_lengths_kv参数一起使能, 否则默认为kv右padding场景.
kv左padding场景不支持PageAttention, 不能与block_table参数一起使能.
入参quant_scale2和quant_offset2支持per-tensor、per-channel量化, 支持float32、bfloat16类型. 若传入quant_offset2, 需保证其类型和shape信息与quant_scale2一致. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32. per-channel场景下, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于N*D. 当输出layout为BSH时, quant_scale2 shape建议传入(1, 1, H)或(H,); 当输出layout为BNSD时, 建议传入(1, Q_N, 1, D)或(Q_N, D); 当输出为BSND时, 建议传入(1, 1, Q_N, D)或(Q_N, D).
输出为int8, quant_scale2和quant_offset2为per-channel时, 暂不支持左padding、Ring Attention或者D非32Byte对齐的场景.
输出为int8时, 暂不支持sparse为band且preTokens/nextTokens为负数.
pse_shift功能使用限制如下:
支持query数据类型为float16或bfloat16或int8场景下使用该功能.
query、key、value数据类型为float16且pse_shift存在时, 强制走高精度模式, 对应的限制继承自高精度模式的限制.
Q_S需大于等于query的S长度, KV_S需大于等于key的S长度. prefix场景KV_S需大于等于actual_shared_prefix_len与key的S长度之和.
输出为int8, 入参quant_offset2传入非None和非空tensor值, 并且sparse_mode、pre_tokens和next_tokens满足以下条件, 矩阵会存在某几行不参与计算的情况, 导致计算结果误差, 该场景会拦截:
sparse_mode=0, atten_mask如果非None, 每个batch actual_seq_lengths-actual_seq_lengths_kv-pre_tokens>0或next_tokens<0时, 满足拦截条件.
sparse_mode=1或 2, 不会出现满足拦截条件的情况.
sparse_mode=3, 每个batch actual_seq_lengths_kv-actual_seq_lengths<0, 满足拦截条件.
sparse_mode=4, pre_tokens<0或每个batch next_tokens+actual_seq_lengths_kv-actual_seq_lengths<0时, 满足拦截条件.
prefix相关参数约束:
key_shared_prefix和value_shared_prefix要么都为空, 要么都不为空.
key_shared_prefix和value_shared_prefix都不为空时, key_shared_prefix、value_shared_prefix、key、value的维度相同、dtype保持一致.
key_shared_prefix和value_shared_prefix都不为空时, key_shared_prefix的shape第一维batch必须为1, layout为BNSD和BSND情况下N、D轴要与key一致、BSH情况下H要与key一致, value_shared_prefix同理. key_shared_prefix和value_shared_prefix的S应相等.
当actual_shared_prefix_len存在时, actual_shared_prefix_len的shape需要为[1], 值不能大于key_shared_prefix和value_shared_prefix的S.
公共前缀的S加上key或value的S的结果, 要满足原先key或value的S的限制.
prefix不支持PageAttention场景、不支持左padding场景、不支持tensorlist场景.
prefix场景不支持query、key、value数据类型同时为int8.
prefix场景, sparse为0或1时, 如果传入attenmask, 则S2需大于等于actual_shared_prefix_len与key的S长度之和.
prefix场景, 不支持输入qkv全部为int8的场景.
kv伪量化参数分离:
key_antiquant_mode和value_antiquant_mode需要保持一致.
key_antiquant_scale和value_antiquant_scale要么都为空, 要么都不为空; key_antiquant_offset和value_antiquant_offset要么都为空, 要么都不为空.
key_antiquant_scale和value_antiquant_scale都不为空时, 其shape需要保持一致; key_antiquant_offset和value_antiquant_offset都不为空时, 其shape需要保持一致.
仅支持per-token模式, 且该模式下要求两个参数的shape均为(B, KV_S), 数据类型固定为float32.
当伪量化参数和KV分离量化参数同时传入时, 以KV分离量化参数为准.
key_antiquant_scale与value_antiquant_scale非空场景, 要求query的s小于等于16.
key_antiquant_scale与value_antiquant_scale非空场景, 要求query的dtype为bfloat16, key、value的dtype为int8, 输出的dtype为bfloat16.
key_antiquant_scale与value_antiquant_scale非空场景, 不支持tensorlist、左padding、page attention、prefix特性.
当Q_S等于1时:
query、key、value输入, 功能使用限制如下:
支持B轴小于等于65536, 支持N轴小于等于256, 支持S轴小于等于262144, 支持D轴小于等于512.
query、key、value输入类型均为int8的场景暂不支持.
在int4(int32)伪量化场景下, PyTorch入图调用仅支持KV int4拼接成int32输入(建议通过dynamicQuant生成int4格式的数据, 因为dynamicQuant就是一个int32包括8个int4).
在int4(int32)伪量化场景下, 若KV int4拼接成int32输入, 那么KV的N、D或者H是实际值的八分之一(prefix同理). 并且, int4伪量化仅支持D 64对齐(int32支持D 8对齐).
actual_seq_lengths_kv: 该参数应为非负数, 在input_layout不同时, 其含义与拦截条件不同: 一般情况下, 该入参为可选入参, 该入参中每个Batch的有效Sequence Length应该不大于key/value中对应Batch的seqlenKv. 当本参数的传入长度为1时, 每个Batch使用相同seqlenKv; 传入长度大于等于Batch时取seqlenKv的前Batch个数. 其他长度不支持. 当input_layout为TND时, 该入参必须传入, 在非PA场景下, 第b个值表示前b个Batch的S轴累加长度, 其值应递增(大于等于前一个值)排列, 且该入参元素的数量代表总Batch数, 在PA场景下, 其长度等于key/value的Batch值, 代表每个Batch的实际长度, 值不大于KV_S.
page attention场景:
使能必要条件是block_table存在且有效, 同时key、value是按照block_table中的索引在一片连续内存中排布, 在该场景下key、value的input_layout参数无效.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持key、value数据类型为float16、bfloat16、int8.
Atlas A3 训练系列产品: 支持key、value数据类型为float16、bfloat16、int8.
该场景下, block_size是用户自定义的参数, 该参数的取值会影响page attention的性能. block_size需要传入非0值,且最大不超过512, key、value输入类型为float16、bfloat16时需要16对齐, key、value输入类型为int8时需要32对齐, 推荐使用128. 通常情况下, page attention可以提高吞吐量, 但会带来性能上的下降.
参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 不支持Q为bfloat16、float16、key、value为int4(int32)的场景.
Atlas A3 训练系列产品: 不支持Q为bfloat16、float16、key、value为int4(int32)的场景.
page attention场景下, blockTable必须为二维, 第一维长度需等于B, 第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为不同batch中最大actual_seq_lengths_kv对应的block数量).
page attention场景下, 当query的input_layout为BNSD、TND时, kv cache排布支持(blocknum, blocksize, H)和(blocknum, KV_N, blocksize, D)两种格式, 当query的input_layout为BSH、BSND时, kv cache排布只支持(blocknum, blocksize, H)一种格式. blocknum不能小于根据actual_seq_lengths_kv和blockSize计算的每个batch的block数量之和. 且key和value的shape需保证一致.
page attention场景下, kv cache排布为(blocknum, KV_N, blocksize, D)时性能通常优于kv cache排布为(blocknum, blocksize, H)时的性能, 建议优先选择(blocknum, KV_N, blocksize, D)格式.
page attention场景下, 当输入kv cache排布格式为(blocknum, blocksize, H), 且 numKvHeads * headDim 超过64k时, 受硬件指令约束, 会被拦截报错. 可通过使能GQA(减小 numKvHeads)或调整kv cache排布格式为(blocknum, numKvHeads, blocksize, D)解决.
page attention不支持tensorlist场景, 不支持左padding场景.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 不支持Q为BF16/FP16且KV为INT4(INT32)的场景.
Atlas A3 训练系列产品: 不支持Q为BF16/FP16且KV为INT4(INT32)的场景.
page attention场景的参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围.
age attention场景下,使能`atten_mask`,当`sparse_mode`不为2、3、4时,传入的`atten_mask`的最后一维需要大于等于`block_table`的第二维 * `block_size`.
page attention场景下,使能`pse_shift`,传入的`pse_shift`的最后一维需要大于等于`block_table`的第二维 * `block_size`.
page attention场景下,以下场景输入S需要大于等于`block_table`的第二维 * `block_size`:
使能伪量化per-token模式:输入参数`antiqunant_scale`和`antiquant_offset`的shape均为\(2, B, S\).
使能per-token叠加per-head模式:两个参数的shape均为\(B, N, S\),数据类型固定为`float32`。支持`key`、`value`数据类型为`int8`、`int4`\(`int32`\)。
kv左padding场景:
kvCache的搬运起点计算公式为: Smax-kv_padding_size-actual_seq_lengths. kvCache的搬运终点计算公式为: Smax-kv_padding_size. 其中kvCache的搬运起点或终点小于0时, 返回数据结果为全0.
kv_padding_size小于0时将被置为0.
使能需要同时存在actual_seq_lengths参数, 否则默认为kv右padding场景.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: kv左padding场景不支持Q为bfloat16/float16、KV为int4(int32)的场景.
Atlas A3 训练系列产品: kv左padding场景不支持Q为bfloat16/float16、KV为int4(int32)的场景.
kv伪量化参数分离:
除了key_antiquant_mode为0并且value_antiquant_mode为1的场景外, key_antiquant_mode和value_antiquant_mode取值需要保持一致.
key_antiquant_scale和value_antiquant_scale要么都为空, 要么都不为空; key_antiquant_offset和value_antiquant_offset要么都为空, 要么都不为空.
key_antiquant_scale和value_antiquant_scale都不为空时, 除了key_antiquant_mode为0并且value_antiquant_mode为1的场景外, 其shape需要保持一致; key_antiquant_offset和value_antiquant_offset都不为空时, 除了key_antiquant_mode为0并且value_antiquant_mode为1的场景外, 其shape需要保持一致.
int4(int32)伪量化场景不支持后量化.
管理scale/offset的量化模式如下:
注意scale、offset两个参数指key_antiquant_scale、key_antiquant_scale、value_antiquant_offset、value_antiquant_offset.
场景下scale和offset条件
per-channel模式: 两个参数shape支持(1, KV_N, 1, D), (1, KV_N, D), (1, H), 数据类型和query数据类型相同.
per-tensor模式: 两个参数的shape均为(1,), 数据类型和query数据类型相同.
per-token模式: 两个参数的shape均为(1, B, KV_S), 数据类型固定为float32.
per-tensor叠加per-head模式: 两个参数的shape均为(KV_N,), 数据类型和query数据类型相同.
per-token叠加per-head模式: 两个参数的shape均为(B, KV_N, KV_S), 数据类型固定为float32.
per-token叠加使用page attention模式: 两个参数的shape均为(blocknum, blocksize), 数据类型固定为float32.
per-token叠加per head并使用page attention模式: 两个参数的shape均为(blocknum, KV_N, blocksize), 数据类型固定为float32.
key支持per-channel叠加value支持per-token模式: 对于key支持per-channel, 两个参数的shape可支持(1, KV_N, 1, D)、(1, KV_N, D)、(1, H), 且参数数据类型和query数据类型相同. 对于value支持per-token, 两个参数的shape均为(1, B, KV_S)并且数据类型固定为float32.
场景下key和value条件
per-channel模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int4(int32)或int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int4(int32)或int8时支持.
per-tensor模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int8时支持.
per-token模式: key、value数据类型为int4(int32)或int8时支持.
per-tensor叠加per-head模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int8时支持.
per-token叠加per-head模式: key、value数据类型为int4(int32)或int8时支持.
per-token叠加使用page attention模式: key、value数据类型为int8时支持.
per-token叠加per head并使用page attention模式: key、value数据类型为int8时支持.
key支持per-channel叠加value支持per-token模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int4(int32)或int8时支持; 当key和value的数据类型为int8时, 仅支持query和输出的dtype为float16. Atlas A3 训练系列产品: 当key、value数据类型为int4(int32)或int8时支持; 当key和value的数据类型为int8时, 仅支持query和输出的dtype为float16.
支持的产品: Atlas A2 训练系列产品/Atlas 800I A2 推理产品. Atlas A3 训练系列产品
pse_shift功能使用限制如下:
pse_shift数据类型需与query数据类型保持一致. 仅支持D轴对齐, 即D轴可以被16整除.
支持的PyTorch版本
PyTorch 2.1
PyTorch 2.3
PyTorch 2.4
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
actseqlen = [164]
actseqlenkv = [1024]
# 调用FIA算子
out, _ = torch_npu.npu_fused_infer_attention_score(q, k, v,
actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv,
num_heads = 8, input_layout = "BNSD", scale = scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出out类似如下
tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
..
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_fused_infer_attention_score(q, k, v, num_heads = 8, input_layout = "BNSD", scale=scale, pre_tokens=65535, next_tokens=65535)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_fused_infer_attention_score(q, k, v, num_heads = 8, input_layout = "BNSD", scale=scale, pre_tokens=65535, next_tokens=65535)
print("single op output with mask:", single_op[0], single_op[0].shape)
print("graph output with mask:", graph_output[0], graph_output[0].shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
"""
)
_add_torch_npu_docstr(
"_npu_fused_infer_attention_score_get_max_workspace",
"""
功能描述:
算子功能:用于npu_fused_infer_attention_score算子aclgraph tilling下沉场景,获取最大workspace size并创建一个此size大小的tensor。
接口原型:
torch_npu._npu_fused_infer_attention_score_get_max_workspace(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, SymInt[]? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, SymInt[]? actual_shared_prefix_len=None, int num_heads=1, float scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, int key_antiquant_mode=0, int value_antiquant_mode=0, bool softmax_lse_flag=False) -> Tensor
参数说明:
输入与npu_fused_infer_attention_score一致
输出类型为Tensor, 由aclnnFusedInferAttentionScoreV3GetMaxWorkspaceSize返回最大的Size,返回创建的workspace tensor。
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
# 调用FIA算子
out = torch_npu._npu_fused_infer_attention_score_get_max_workspace(q, k, v, num_heads = 8, input_layout = "BNSD", scale = scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出类似如下
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"_npu_fused_infer_attention_score_infer_output",
"""
功能描述:
算子功能:用于npu_fused_infer_attention_score算子aclgraph tilling下沉场景,推算output tensor 并创建一个此size大小的tensor, 实际返回output_tensor 和 softmax_lse_tensor。
接口原型:
torch_npu._npu_fused_infer_attention_score_infer_output(Tensor query, Tensor value, *, str input_layout="BSH", Tensor? quant_scale2=None, Tensor? block_table=None, int num_heads=1, int num_key_value_heads=0, bool softmax_lse_flag=False, Tensor? query_rope=None) -> (Tensor, Tensor)
参数说明:
输入为npu_fused_infer_attention_score的子集
输出类型为(Tensor, Tensor), 由适配层推导,计算返回对应的output_tensor 和 softmax_lse_tensor。
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
# 调用FIA算子
out,softmax_lse = torch_npu._npu_fused_infer_attention_score_infer_output(q, k, v, num_heads = 8, input_layout = "BNSD")
# 执行上述代码的输出类似如下
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"npu_fused_infer_attention_score.out",
"""
功能描述:
算子功能:npu_fused_infer_attention_score.out算子实现,可用于aclgraph tilling下沉场景(需传入workspace tensor),输入参数相比npu_fused_infer_attention_score增加workspace、attention_out、softmax_lse。
计算公式:atten_out = softmax(scale*(query*key)+atten_mask)*value
接口原型:
torch_npu.npu_fused_infer_attention_score.out(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, SymInt[]? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, SymInt[]? actual_shared_prefix_len=None, Tensor? query_rope=None, Tensor? key_rope=None, Tensor? key_rope_antiquant_scale=None, int num_heads=1, float scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, int key_antiquant_mode=0, int value_antiquant_mode=0, bool softmax_lse_flag=False, Tensor? workspace=None, Tensor(a!) attention_out, Tensor(b!) softmax_lse) -> (Tensor(a!), Tensor(b!))
参数说明:
在torch_npu.npu_fused_infer_attention_score的基础上增加下面三个参数:
workspace(可选): 一维Device侧的Input Tensor,数据类型与Query一致;
attention_out(aclTensor*,计算输出): 计算的最终结果Attention output tensor, shape与Query一致;
softmax_lse(aclTensor*,计算输出): 也是一个输出结果,当前预留,暂不支持;
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
workspace = torch.randn(2000000, dtype=torch.float16).npu()
output = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
softmax_lse = torch.randn(1, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)
# 调用FIA算子
out = torch_npu.npu_fused_infer_attention_score.out(q, k, v, workspace=workspace, out=[output, softmax_lse], num_heads = 8, input_layout = "BNSD", scale = scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出output类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"npu_fused_infer_attention_score_v2",
"""
功能描述:
算子功能: 适配增量&全量推理场景的FlashAttention算子, 既可以支持全量计算场景(PromptFlashAttention), 也可支持增量计算场景(IncreFlashAttention). 当Query矩阵的S为1, 进入IncreFlashAttention分支, 其余场景进入PromptFlashAttention分支.
计算公式:
attention_out = softmax(softmax_scale*(query*key)+atten_mask)*value
接口原型:
torch_npu.npu_fused_infer_attention_score_v2(Tensor query, Tensor key, Tensor value, *, Tensor? query_rope=None, Tensor? key_rope=None, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_qlen=None, SymInt[]? actual_seq_kvlen=None, Tensor? block_table=None, Tensor? dequant_scale_query=None, Tensor? dequant_scale_key=None, Tensor? dequant_offset_key=None, Tensor? dequant_scale_value=None, Tensor? dequant_offset_value=None, Tensor? dequant_scale_key_rope=None, Tensor? quant_scale_out=None, Tensor? quant_offset_out=None, Tensor? quant_scale_p=None, Tensor? learnable_sink=None, int num_query_heads=1, int num_key_value_heads=0, float softmax_scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int sparse_mode=0, int block_size=0, int query_quant_mode=0, int key_quant_mode=0, int value_quant_mode=0, int inner_precise=0, bool return_softmax_lse=False, int? query_dtype=None, int? key_dtype=None, int? value_dtype=None, int? query_rope_dtype=None, int? key_rope_dtype=None, int? key_shared_prefix_dtype=None, int? value_shared_prefix_dtype=None, int? dequant_scale_query_dtype=None, int? dequant_scale_key_dtype=None, int? dequant_scale_value_dtype=None, int? dequant_scale_key_rope_dtype=None, int? out_dtype=None) -> (Tensor, Tensor)
参数说明:
query、key、value数据排布格式支持从多种维度解读, 其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸, 且满足D=H/N、T表示所有Batch输入样本序列长度的累加和.
query: Tensor类型, attention结构的Query输入, 数据类型支持float16、bfloat16、int8, 不支持非连续的Tensor, 数据格式支持ND.
key: Tensor类型, attention结构的Key输入, 不支持非连续的Tensor, 数据格式支持ND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8、int4(int32).
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8、int4(int32).
value: Tensor类型, attention结构的Value输入, 不支持非连续的Tensor, 数据格式支持ND.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、int8、int4(int32).
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、int8、int4(int32).
*: 代表其之前的变量是位置相关, 需要按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
query_rope: Tensor类型, 表示MLA(Multi-head Latent Attention)结构中的query的rope信息, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND.
key_rope: Tensor类型, 表示MLA(Multi-head Latent Attention)结构中的key的rope信息, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND.
pse_shift: Tensor类型, 在attention结构内部的位置编码参数, 数据类型支持float16、bfloat16, 数据类型与query的数据类型需满足数据类型推导规则. 不支持非连续的Tensor, 数据格式支持ND. 如不使用该功能时可传入None.
Q_S大于1, 要求在pse_shift为float16类型时, 此时的query为float16或int8类型; 而在pse_shift为bfloat16类型时, 要求此时的query为bfloat16类型. 输入shape类型需为(B, Q_N, Q_S, KV_S)或(1, Q_N, Q_S, KV_S), 其中Q_S为query的shape中的S, KV_S为key和value的shape中的S. 对于pse_shift的KV_S为非32对齐的场景, 建议padding到32字节来提高性能, 多余部分的填充值不做要求.
Q_S为1, 要求在pse_shift为float16类型时, 此时的query为float16类型; 而在pse_shift为bfloat16类型时, 要求此时的query为bfloat16类型. 输入shape类型需为(B, Q_N, 1, KV_S)或(1, Q_N, 1, KV_S), KV_S为key和value的shape中的S. 对于pse_shift的KV_S为非32对齐的场景, 建议padding到32字节来提高性能, 多余部分的填充值不做要求.
atten_mask: Tensor类型, 对QK的结果进行mask, 用于指示是否计算Token间的相关性, 数据类型支持bool、int8和uint8. 不支持非连续的Tensor, 数据格式支持ND. 如果不使用该功能可传入None.
sparse_mode为0、1时
支持shape传入(1,Q_S,KV_S)、(B,1,Q_S,KV_S)、(1,1,Q_S,KV_S)。
当输入input_layout为BSH、BSND、BNSD、BNSD_BSND时,且query、key、value的D相等,并且不传query_rope和key_rope时,Q_S为1可支持传入(B,KV_S),Q_S大于1时可支持传入(Q_S,KV_S)。
如果Q_S、KV_S非16或32对齐,可以向上取到对齐的S。综合约束请见约束声明。
sparse_mode为2、3、4时,shape输入支持(2048,2048)或(1,2048,2048)或(1,1,2048,2048)。
其中Q_S为query的shape中的S, KV_S为key和value的shape中的S, 但如果Q_S、KV_S非16或32对齐, 可以向上取到对齐的S. 综合约束请见约束说明.
actual_seq_qlen: int类型数组, 代表不同Batch中query的有效seqlen, 数据类型支持int64. 如果不指定seqlen可以传入None, 表示和query的shape的s长度相同. 限制: 该入参中每个batch的有效seqlen应该不大于query中对应batch的seqlen, Q_S为1时该参数无效. seqlen的传入长度为1时, 每个Batch使用相同seqlen; 传入长度大于等于Batch时取seqlen的前Batch个数. 其他长度不支持. 当query的input_layout为TND时, 该入参必须传入, 且以该入参元素的数量作为Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
actual_seq_kvlen: int类型数组, 代表不同Batch中key/value的有效seqlenKv, 数据类型支持int64. 如果不指定None, 表示和key/value的shape的S长度相同. 不同O_S值有不同的约束, 具体参见约束说明.
block_table: Tensor类型, 数据类型支持int32. 数据格式支持ND. 表示PageAttention中KV存储使用的block映射表, 如不使用该功能可传入None.
dequant_scale_query: Tensor类型. 数据格式支持ND, query的反量化参数. 仅支持per-token叠加per-head. 如不使用该功能时可传入None, 综合约束请见约束说明.
dequant_scale_key: Tensor类型. 数据格式支持ND, kv伪量化参数分离时表示key的反量化因子. 如不使用该功能时可传入None, 综合约束请见约束说明. 通常支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理scale、per-token叠加per head并使用page attention模式管理scale.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、float32.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、float32.
dequant_offset_key: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示key的反量化偏移. 支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理offset、per-token叠加per head并使用page attention模式管理offset. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明.
dequant_scale_value: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示value的反量化因子. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明. 通常支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理scale、per-token叠加per head并使用page attention模式管理scale.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、bfloat16、float32.
Atlas A3 训练系列产品: 数据类型支持float16、bfloat16、float32.
dequant_offset_value: Tensor类型, 数据类型支持float16、bfloat16、float32. 数据格式支持ND, kv伪量化参数分离时表示value的反量化偏移, 支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head、per-token叠加使用page attention模式管理offset、per-token叠加per head并使用page attention模式管理offset. Q_S大于等于2时仅支持per-token模式, 如不使用该功能时可传入None, 综合约束请见约束说明.
dequant_scale_key_rope: Tensor类型, 预留参数, 暂未使用, 使用默认值即可. 表示MLA(Multi-head Latent Attention)结构中的key Rope对应的反量化因子, 支持per-channel, 数据类型支持float16、bfloat16, 不支持非连续的Tensor, 数据格式支持ND, D维度与key_rope的D维度保持一致. 仅支持Q_S等于1-16, 其余场景该参数无效.
quant_scale_out: Tensor类型, 数据类型支持float32、bfloat16. 数据格式支持ND, 表示输出的量化因子, 支持per-tensor、per-channel. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32 . per-channel格式, 当输出layout为BSH时, 要求quant_scale2所有维度的乘积等于H; 其他layout要求乘积等于Q_N*D(建议输出layout为BSH时, quant_scale2shape传入(1, 1, H)或(H,); 输出为BNSD时, 建议传入(1, Q_N, 1, D)或(Q_N, D); 输出为BSND时, 建议传入(1, 1, Q_N, D)或(Q_N, D)). 如不使用该功能时可传入None, 综合约束请见约束说明.
quant_offset_out: Tensor类型, 数据类型支持float32、bfloat16. 数据格式支持ND, 表示输出的量化偏移, 支持per-tensor、per-channel. 若传入quant_offset_out, 需保证其类型和shape信息与quant_scale_out 一致. 如不使用该功能时可传入None, 综合约束请见约束说明.
quant_scale_p: Tensor类型, 预留参数, 暂未使用, 使用默认值即可.
learnable_sink: Tensor类型, 数据类型支持bfloat16, 数据格式支持ND, shape输入为(Q_N,), 通过可学习的"Sink Token"起到吸收Attention Score的作用, 如果不使用该功能可传入None, 综合约束请见约束说明.
num_query_heads: 整型, 代表query的head个数, 数据类型支持int64, 在BNSD场景下, 需要与shape中的query的N轴shape值相同, 否则执行异常.
num_key_value_heads: 整型, 代表key、value中head个数, 用于支持GQA(Grouped-Query Attention, 分组查询注意力)场景, 数据类型支持int64. 用户不特意指定时可传入默认值0, 表示key/value和query的head个数相等, 需要满足num_query_heads整除num_key_value_heads, num_query_heads与num_key_value_heads的比值不能大于64. 在BSND、BNSD、BNSD_BSND(仅支持Q_S大于1)场景下, 还需要与shape中的key/value的N轴shape值相同, 否则执行异常.
softmax_scale: 浮点型, 公式中d开根号的倒数, 代表缩放系数, 作为计算流中Muls的scalar值, 数据类型支持float. 数据类型与query的数据类型需满足数据类型推导规则. 用户不特意指定时可传入默认值1.0, 即不做缩放. 建议传入1/sqrt(D)(D为Head Dim), 例如当D=128时传入1/math.sqrt(128.0), 以获得正确的注意力计算结果.
pre_tokens: 整型, 用于稀疏计算, 表示attention需要和前几个Token计算关联, 数据类型支持int64. 用户不特意指定时可传入默认值2147483647, Q_S为1时该参数无效.
next_tokens: 整型, 用于稀疏计算, 表示attention需要和后几个Token计算关联. 数据类型支持int64. 用户不特意指定时可传入默认值2147483647, Q_S为1时该参数无效.
input_layout: 字符串类型, 用于标识输入query、key、value的数据排布格式, 用户不特意指定时可传入默认值"BSH". 各layout对应的shape如下: BSH-(B,Q_S,H), BSND-(B,Q_S,Q_N,D), BNSD-(B,Q_N,Q_S,D), BNSD_BSND输入(B,Q_N,Q_S,D)输出(B,Q_S,Q_N,D)仅支持Q_S>1, BSH_NBSD输入(B,Q_S,H)输出(Q_N,B,Q_S,D), BSND_NBSD输入(B,Q_S,Q_N,D)输出(Q_N,B,Q_S,D), BNSD_NBSD输入(B,Q_N,Q_S,D)输出(Q_N,B,Q_S,D)仅支持Q_S为1~16, TND-(T,Q_N,D)其中T为所有Batch的S累加和, TND_NTD输入(T,Q_N,D)输出(Q_N,T,D), NTD_TND输入(Q_N,T,D)输出(T,Q_N,D). 其中H=N*D.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持BSH、BSND、BNSD、BNSD_BSND、TND(不支持tensorlist、pse、page attention、伪量化、全量化、后量化, 综合约束请见约束说明). 当为TND时, 不支持图模式配置Tiling调度优化功能(tiling_schedule_optimize=True).
Atlas A3 训练系列产品: 支持BSH、BSND、BNSD、BNSD_BSND、TND(不支持tensorlist、pse、page attention、伪量化、全量化、后量化, 综合约束请见约束说明). 当为TND时, 不支持图模式配置Tiling调度优化功能(tiling_schedule_optimize=True).
其中BNSD_BSND含义指当输入为BNSD, 输出格式为BSND, 仅支持Q_S大于1.
sparse_mode: 整型, 表示sparse的模式. 数据类型支持int64. Q_S为1且不带rope输入时该参数无效.
sparse_mode为0时, 代表defaultMask模式, 如果atten_mask未传入则不做mask操作, 忽略pre_tokens和next_tokens(内部赋值为INT_MAX); 如果传入, 则需要传入完整的atten_mask矩阵(S1*S2), 表示pre_tokens和next_tokens之间的部分需要计算.
sparse_mode为1时, 代表allMask, 必须传入完整的attenmask矩阵(S1*S2).
sparse_mode为2时, 代表leftUpCausal模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为3时, 代表rightDownCausal模式的mask, 对应以右顶点为划分的下三角场景, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为4时, 代表band模式的mask, 需要传入优化后的atten_mask矩阵(2048*2048).
sparse_mode为5、6、7、8时, 分别代表prefix、global、dilated、block_local, 均暂不支持. 用户不特意指定时可传入默认值0. 综合约束请见约束说明.
block_size: 整型, PageAttention中KV存储每个block中最大的token个数, 默认为0, 数据类型支持int64.
query_quant_mode: 整型, 表示query的伪量化方式。仅支持传入3,代表模式3:代表per-token叠加per-head模式.
key_quant_mode: 整型, 表示key的伪量化方式. Q_S大于等于2时仅支持传入值为1, 用户不特意指定时可传入默认值0, 取值除了key_quant_mode为0并且value_quant_mode为1的场景外, 需要与value_quant_mode一致. 综合约束请见约束说明.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持取值0、1、2、3、4、5.
Atlas A3 训练系列产品: 支持取值0、1、2、3、4、5.
key_quant_mode为0时, 代表per-channel模式(per-channel包含per-tensor).
key_quant_mode为1时, 代表per-token模式.
key_quant_mode为2时, 代表per-tensor叠加per-head模式.
key_quant_mode为3时, 代表per-token叠加per-head模式.
key_quant_mode为4时, 代表per-token叠加使用page attention模式管理scale/offset模式.
key_quant_mode为5时, 代表per-token叠加per head并使用page attention模式管理scale/offset模式.
value_quant_mode: 整型, 表示value的伪量化方式, 模式编号与key_quant_mode一致. Q_S大于等于2时仅支持传入值为1, 用户不特意指定时可传入默认值0, 取值除了key_quant_mode为0并且value_quant_mode为1的场景外, 需要与key_quant_mode一致. 综合约束请见约束说明.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持取值0、1、2、3、4、5.
Atlas A3 训练系列产品: 支持取值0、1、2、3、4、5.
inner_precise: 整型, 一共4种模式: 0、1、2、3. 一共两位bit位, 第0位(bit0)表示高精度或者高性能选择, 第1位(bit1)表示是否做行无效修正. 数据类型支持int64. Q_S>1时, sparse_mode为0或1, 并传入用户自定义mask的情况下, 建议开启行无效; Q_S为1时该参数仅支持innerPrecise为0和1. 综合约束请见约束说明.
inner_precise为0时, 代表开启高精度模式, 且不做行无效修正.
inner_precise为1时, 代表高性能模式, 且不做行无效修正.
inner_precise为2时, 代表开启高精度模式, 且做行无效修正.
inner_precise为3时, 代表高性能模式, 且做行无效修正.
bfloat16和int8不区分高精度和高性能, 行无效修正对float16、bfloat16和int8均生效. 当前0、1为保留配置值, 当计算过程中“参与计算的mask部分”存在某整行全为1的情况时, 精度可能会有损失. 此时可以尝试将该参数配置为2或3来使能行无效功能以提升精度, 但是该配置会导致性能下降.
return_softmax_lse: 布尔型, 表示是否输出softmax_lse, 支持S轴外切(增加输出). true表示输出softmax_lse, false表示不输出; 用户不特意指定时可传入默认值false.
query_dtype: 整型, 表示query的数据类型,预留参数,暂未使用,使用默认值即可.
key_dtype: 整型, 表示key的数据类型,预留参数,暂未使用,使用默认值即可.
value_dtype: 整型, 表示value的数据类型,预留参数,暂未使用,使用默认值即可.
query_rope_dtyp: 整型, 表示query_repo的数据类型,预留参数,暂未使用,使用默认值即可.
key_rope_dtype: 整型, 表示key_rope的数据类型,预留参数,暂未使用,使用默认值即可.
key_shared_prefix_dtype: 整型, 表示key_shared_prefix的数据类型,预留参数,暂未使用,使用默认值即可.
value_shared_prefix_dtype: 整型, 表示value_shared_prefix的数据类型,预留参数,暂未使用,使用默认值即可.
dequant_scale_query_dtype: 整型, 表示dequant_scale_query的数据类型,预留参数,暂未使用,使用默认值即可.
dequant_scale_key_dtype: 整型, 表示dequant_scale_key的数据类型,预留参数,暂未使用,使用默认值即可.
dequant_scale_value_dtype: 整型, 表示dequant_scale_value的数据类型,预留参数,暂未使用,使用默认值即可.
dequant_scale_key_rope_dtype: 整型, 表示dequant_scale_key_rope的数据类型,预留参数,暂未使用,使用默认值即可.
out_dtype: 整型, 表示输出的数据类型. 当输入为int8或float8_e4m3fn时, 可通过该参数指定输出的数据类型(如float8_e5m2). 如不使用该功能可传入None.
输出说明
attention_out: Tensor类型, 公式中的输出, 数据类型支持float16、bfloat16、int8. 数据格式支持ND. 限制:该入参的D维度与value的D保持一致,其余维度需要与入参query的shape保持一致.
softmaxLse: Tensor类型, ring attention算法对query乘key的结果, 先取max得到softmax_max. query乘key的结果减去softmax_max, 再取exp, 最后取sum, 得到softmax_sum, 最后对softmax_sum取log, 再加上softmax_max得到的结果. 数据类型支持float32, return_softmax_lse为True时, 一般情况下, 输出shape为(B, Q_N, Q_S, 1)的Tensor, 当input_layout为TND时, 输出shape为(T,Q_N,1)的Tensor; return_softmax_lse为False时, 则输出shape为[1]的值为0的Tensor.
约束说明:
该接口支持推理场景下使用.
该接口支持图模式.
该接口与PyTorch配合使用时, 需要保证CANN相关包与PyTorch相关包的版本匹配.
入参为空的处理: 算子内部需要判断参数query是否为空, 如果是空则直接返回. 参数query不为空Tensor, 参数key、value为空tensor(即S2为0), 则填充全零的对应shape的输出(填充attention_out). attention_out为空Tensor时, 框架会处理.
参数key、value中对应tensor的shape需要完全一致; 非连续场景下key、value的tensorlist中的batch只能为1, 个数等于query的B, N和D需要相等.
int8量化相关入参数量与输出数据格式的综合限制:
输出为int8的场景: 入参quant_scale_out需要存在, quant_offset_out可选, 不传时默认为0.
输出为float16的场景: 若存在入参quant_offset_out或quant_scale_out(即不为None), 则报错并返回.
入参quant_offset_out和quant_scale_out支持per-tensor或per-channel格式, 数据类型支持float32、bfloat16.
query_rope和key_rope输入时即为MLA场景,参数约束如下:
query_rope的数据类型、数据格式与query一致。
key_rope的数据类型、数据格式与key一致。
query_rope和key_rope要求同时配置或同时不配置,不支持只配置其中一个。
当query_rope和key_rope非空时,支持如下特性:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/Atlas A3 推理系列产品:query的d只支持512、128;
当query的d等于512时:
sparse:Q_S等于1时只支持sparse=0且不传mask,Q_S大于1时只支持sparse=3且传入mask;
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/Atlas A3 推理系列产品约束如下:
query_rope配置时要求query的s为1-16、n为1/2/4/8/16/32/64/128,query_rope的shape中d为64,其余维度与query一致;
key_rope配置时要求key的n为1,d为512,keyRope的shape中d为64,其余维度与key一致;
支持key、value、keyRope的input_layout格式为ND或NZ。当input_layout为NZ时,数据类型为float16或bfloat16时,输入参数key和value的格式为[blockNum, KV_N, D/16, blockSize, 16],数据类型为int8时,输入参数key和value的格式为[blockNum, KV_N, D/32, blockSize, 32];
input_layout形状支持BSH、BSND、BNSD、BNSD_NBSD、BSND_NBSD、BSH_NBSD、TND、TND_NTD,当数据格式为NZ时input_layout不支持BNSD、BNSD_NBSD。
该场景下,必须开启PageAttention,此时block_size支持16、128,其中数据格式为NZ时block_size不支持配置16。
不支持开启左padding、tensorlist、pse、prefix、伪量化、后量化、空Tensor。
支持全量化场景,即输入query/key/value全为int8,query_rope和key_rope为bfloat16,输出为bfloat16的场景:
入参dequant_scale_query、dequant_scale_key、dequant_scale_value需要同时存在,且其数据类型仅支持FP32。
不支持传入quant_scale_out、quant_offset_out、dequant_offset_key、dequant_offset_value(即不为nullptr),否则报错并返回。
query_quant_mode仅支持per-token叠加per-head模式,key_quant_mode和value_quant_mode仅支持per-tensor模式。
支持key、value、keyRope的input_layout格式为NZ。
当query的d等于128时:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/Atlas A3 推理系列产品约束如下:
inputLayout:TND、NTD_TND。
query_rope配置时要求query_rope的shape中d为64,其余维度与query一致。
keyRope配置时要求keyRope的shape中d为64,其余维度与key一致。
不支持左padding、tensorlist、pse、page attention、prefix、伪量化、全量化、后量化、空Tensor。
其余约束同TND、NTD_TND场景下的综合限制保持一致。
TND、TND_NTD、NTD_TND场景下query、key、value输入的综合限制:
T小于等于1M;
sparse模式仅支持sparse=0且不传mask,或sparse=3且传入mask;
actual_seq_qlen和actual_seq_kvlen必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的Sequence Length和,因此后一个元素的值必须大于等于前一个元素的值;
当query的d等于512时:
支持TND、TND_NTD;
必须开启page attention,此时actual_seq_kvlen长度等于key/value的batch值,代表每个batch的实际长度,值不大于KV_S;
支持query每个batch的s为1-16;
要求query的n为1/2/4/8/16/32/64/128,key、value的n为1;
要求query_rope和keyRope不等于空,query_rope和keyRope的d为64;
不支持开启左padding、tensorlist、pse、prefix、伪量化、全量化、后量化、空Tensor。
当query的d不等于512时:
当query_rope和key_rope为空时:TND场景,要求Q_D、K_D、V_D等于128,或者Q_D、K_D等于192,V_D等于128/192;NTD_TND场景,要求Q_D、K_D等于128/192,V_D等于128。当query_rope和key_rope不为空时,要求Q_D、K_D、V_D等于128;
Q_N、K_N、V_N:需要满足K_N、V_N相等,Q_N整除K_N,Q_N与K_N的比值不能大于64;
支持TND、NTD_TND;
数据类型仅支持BFLOAT16;
当sparse=3时,要求每个batch单独的actual_seq_qlen<actual_seq_kvlen;
sparse模式支持sparse\_mode=4且传入mask;当sparse\_mode=4时,要求preTokens >= -actual\_seq\_qlen、nextTokens >= -actual\_seq\_kvlen、preTokens + nextTokens >= 0;
不支持左padding、tensorlist、pse、prefix、伪量化、全量化、后量化、空Tensor;
不支持图模式配置Tiling调度优化(tiling_schedule_optimize=True)、reduce-overhead执行模式(config.mode="reduce-overhead")。
actual_seq_qlen和actual_seq_kvlen的元素个数不大于4096。
GQA伪量化场景下KV为NZ格式时的参数约束如下:
支持per-channel和per-token模式,query数据类型固定为bfloat16,key&value固定为int8;query&key&value的D仅支持128;query Sequence Length仅支持1-16;
input_layout仅支持BSH、BSND、BNSD;
仅支持page_attention场景,blockSize仅支持128或512;
key&value仅支持NZ输入,输入格式为[blockNum, KV_N, D/32, blockSize, 32];
dequant_scale_key和dequant_scale_value的dtype:per-channel模式下,仅支持bfloat16类型;per-token模式下,仅支持float32类型;
dequant_scale_key和dequant_scale_value的shape:per-channel模式下,当layout为BSH时,必须传入[H];layout为BNSD时,必须传入[KV_N,1,D];输出为BSND时,必须传入[KV_N, D];per-token模式下,必须传入[B,KV_S],S需要大于等于blockTable的第二维*blockSize;
仅支持KV分离;
仅支持高性能模式;
当MTP等于0时,支持sparse_mode=0且不传mask;当MTP大于0、小于16时,支持sparse_mode=3且传入优化后的atten_mask矩阵,atten_mask矩阵shape必须传入(2048*2048);
不支持配置dequant_offset_key和dequant_offset_value;
不支持配置query_rope和key_rope;
不支持左padding、tensorlist、pse、prefix、后量化;
num_query_heads与num_key_value_heads支持组合有(10,1)、(64,8)、(80,8)、(128,16)。
learnable_sink的参数约束如下:
仅支持TND、NTD\_TND;
仅支持value的d小于等于128;
当Q_S大于1时:
query、key、value输入, 功能使用限制如下:
支持B轴小于等于65536, D轴32byte不对齐时仅支持到128.
支持N轴小于等于256, 支持D轴小于等于512; input_layout为BSH或者BSND时, 要求N*D小于65535.
S支持小于等于20971520(20M). 部分长序列场景下, 如果计算量过大可能会导致PFA算子执行超时(aicore error类型报错, errorStr为timeout or trap error), 此场景下建议做S切分处理(注: 这里计算量会受B、S、N、D等的影响, 值越大计算量越大), 典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152.
B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520.
B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152.
B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.
query、key、value输入类型包含int8时, D轴需要32对齐; 输入类型全为float16、bfloat16时, D轴需16对齐.
actual_seq_kvlen: 该参数传入时应为非负数, 在input_layout不同时, 其含义与拦截条件不同: 一般情况下, 该入参为可选入参, 该入参中每个Batch的有效seqlenKv应该不大于key/value中对应Batch的seqlenKv. 当本参数的传入长度为1时, 每个Batch使用相同seqlenKv; 传入长度大于等于Batch时取seqlenKv的前Batch个数. 其他长度不支持. 当key/value的input_layout为TND时, 该入参必须传入, 且该入参元素的数量等于Batch值. 该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKv和, 因此后一个元素的值必须大于等于前一个元素的值, 且不能出现负值.
参数sparse_mode当前仅支持值为0、1、2、3、4的场景, 取其它值时会报错.
sparse_mode=0时, atten_mask如果为None, 则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX).
sparse_mode=2、3、4时, atten_mask的shape需要为(S, S)或(1, S, S)或(1, 1, S, S), 其中S的值需要固定为2048, 且需要用户保证传入的atten_mask为下三角, 不传入atten_mask或者传入的shape不正确报错.
sparse_mode=1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值.
page attention场景:
page attention的使能必要条件是block_table存在且有效, 同时key、value是按照block_table中的索引在一片连续内存中排布, 支持key、value数据类型为float16、bfloat16、int8. 在该场景下key、value的input_layout参数无效. block_table中填充的是blockid, 当前不会对blockid的合法性进行校验, 需用户自行保证.
block_size是用户自定义的参数, 该参数的取值会影响page attention的性能, 在使能page attention场景下, block_size最小为128, 最大为512, 且要求是128的倍数. 通常情况下, page attention可以提高吞吐量, 但会带来性能上的下降.
page attention场景下, 当输入kv cache排布格式为(blocknum, blocksize, H), 且KV_N*D超过65535时, 受硬件指令约束, 会被拦截报错. 可通过使能GQA(减小KV_N)或调整kv cache排布格式为(blocknum, KV_N, blocksize, D)解决. 当query的input_layout为BNSD、TND时, kv cache排布支持(blocknum, blocksize, H)和(blocknum, KV_N, blocksize, D)两种格式, 当query的input_layout为BSH、BSND时, kv cache排布只支持(blocknum, blocksize, H)一种格式. blocknum不能小于根据actual_seq_kvlen和blockSize计算的每个batch的block数量之和. 且key和value的shape需保证一致.
page attention不支持伪量化场景, 不支持tensorlist场景.
page attention场景下, 必须传入actual_seq_kvlen.
page attention场景下, block_table必须为二维, 第一维长度需等于B, 第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为不同batch中最大actual_seq_kvlen对应的block数量).
page attention场景下,支持两种格式和float32/bfloat16,不支持输入query为int8的场景。
page attention使能场景下, 以下场景输入需满足KV_S>=maxBlockNumPerSeq*blockSize:
传入attenMask时, 如mask shape为 (B, 1, Q_S, KV_S).
传入pseShift时, 如pseShift shape为(B, Q_N, Q_S, KV_S).
入参quant_scale_out和quant_offset_out支持per-tensor、per-channel量化, 支持float32、bfloat16类型. 若传入quant_offset_out, 需保证其类型和shape信息与quant_scale_out一致. 当输入为bfloat16时, 同时支持float32和bfloat16 , 否则仅支持float32. per-channel场景下, 当输出layout为BSH时, 要求quant_scale_out所有维度的乘积等于H; 其他layout要求乘积等于Q_N*D. 当输出layout为BSH时, quant_scale_out shape建议传入(1, 1, H)或(H,); 当输出layout为BNSD时, 建议传入(1, Q_N, 1, D)或(Q_N, D); 当输出为BSND时, 建议传入(1, 1, Q_N, D)或(Q_N, D).
输出为int8, quant_scale_out和quant_offset_out为per-channel时, 暂不支持Ring Attention或者D非32Byte对齐的场景.
输出为int8时, 暂不支持sparse为band且preTokens/nextTokens为负数.
pse_shift功能使用限制如下:
支持query数据类型为float16或bfloat16或int8场景下使用该功能.
query、key、value数据类型为float16且pse_shift存在时, 强制走高精度模式, 对应的限制继承自高精度模式的限制.
Q_S需大于等于query的S长度, KV_S需大于等于key的S长度.
输出为int8, 入参quant_offset_out传入非None和非空tensor值, 并且sparse_mode、pre_tokens和next_tokens满足以下条件, 矩阵会存在某几行不参与计算的情况, 导致计算结果误差, 该场景会拦截:
sparse_mode=0, atten_mask如果非None, 每个batch actual_seq_qlen-actual_seq_kvlen-pre_tokens>0或next_tokens<0时, 满足拦截条件.
sparse_mode=1或 2, 不会出现满足拦截条件的情况.
sparse_mode=3, 每个batch actual_seq_kvlen-actual_seq_qlen<0, 满足拦截条件.
sparse_mode=4, pre_tokens<0或每个batch next_tokens+actual_seq_kvlen-actual_seq_qlen<0时, 满足拦截条件.
kv伪量化参数分离:
key_quant_mode和value_quant_mode需要保持一致.
dequant_scale_key和dequant_scale_value要么都为空, 要么都不为空; dequant_offset_key和dequant_offset_value要么都为空, 要么都不为空.
dequant_scale_key和dequant_scale_value都不为空时, 其shape需要保持一致; dequant_offset_key和dequant_offset_value都不为空时, 其shape需要保持一致.
仅支持per-token和per-channel模式,per-token模式下要求两个参数的shape均为(B, KV_S),数据类型固定为float32;per-channel模式下要求两个参数的shape为(KV_N, D),(KV_N, D),(H),数据类型固定为bfloat16,H为KV_N*D.
当伪量化参数和KV分离量化参数同时传入时, 以KV分离量化参数为准.
dequant_scale_key与dequant_scale_value非空场景, 要求query的s小于等于16.
dequant_scale_key与dequant_scale_value非空场景, 要求query的dtype为bfloat16, key、value的dtype为int8, 输出的dtype为bfloat16.
dequant_scale_key与dequant_scale_value非空场景, 不支持tensorlist、page attention特性.
当Q_S等于1时:
query、key、value输入, 功能使用限制如下:
支持B轴小于等于65536, 支持N轴小于等于256, 支持S轴小于等于262144, 支持D轴小于等于512.
query、key、value输入类型均为int8的场景暂不支持.
在int4(int32)伪量化场景下, PyTorch入图调用仅支持KV int4拼接成int32输入(建议通过dynamicQuant生成int4格式的数据, 因为dynamicQuant就是一个int32包括8个int4).
在int4(int32)伪量化场景下, 若KV int4拼接成int32输入, 那么KV的N、D或者H是实际值的八分之一. 并且, int4伪量化仅支持D 64对齐(int32支持D 8对齐).
actual_seq_kvlen: 该参数应为非负数, 在input_layout不同时, 其含义与拦截条件不同: 一般情况下, 该入参为可选入参, 该入参中每个Batch的有效Sequence Length应该不大于key/value中对应Batch的seqlenKv. 当本参数的传入长度为1时, 每个Batch使用相同seqlenKv; 传入长度大于等于Batch时取seqlenKv的前Batch个数. 其他长度不支持. 当input_layout为TND时, 该入参必须传入, 在非PA场景下, 第b个值表示前b个Batch的S轴累加长度, 其值应递增(大于等于前一个值)排列, 且该入参元素的数量代表总Batch数, 在PA场景下, 其长度等于key/value的Batch值, 代表每个Batch的实际长度, 值不大于KV_S.
page attention场景:
使能必要条件是block_table存在且有效, 同时key、value是按照block_table中的索引在一片连续内存中排布, 在该场景下key、value的input_layout参数无效.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 支持key、value数据类型为float16、bfloat16、int8.
Atlas A3 训练系列产品: 支持key、value数据类型为float16、bfloat16、int8.
该场景下, block_size是用户自定义的参数, 该参数的取值会影响page attention的性能. key、value输入类型为float16、bfloat16时需要16对齐, key、value输入类型为int8时需要32对齐, 推荐使用128. 通常情况下, page attention可以提高吞吐量, 但会带来性能上的下降.
参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 不支持Q为bfloat16、float16、key、value为int4(int32)的场景.
Atlas A3 训练系列产品: 不支持Q为bfloat16、float16、key、value为int4(int32)的场景.
page attention场景下, blockTable必须为二维, 第一维长度需等于B, 第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为不同batch中最大actual_seq_kvlen对应的block数量).
page attention场景下, 当query的input_layout为BNSD、TND时, kv cache排布支持(blocknum, blocksize, H)和(blocknum, KV_N, blocksize, D)两种格式, 当query的input_layout为BSH、BSND时, kv cache排布只支持(blocknum, blocksize, H)一种格式. blocknum不能小于根据actual_seq_kvlen和blockSize计算的每个batch的block数量之和. 且key和value的shape需保证一致.
page attention场景下, kv cache排布为(blocknum, KV_N, blocksize, D)时性能通常优于kv cache排布为(blocknum, blocksize, H)时的性能, 建议优先选择(blocknum, KV_N, blocksize, D)格式.
page attention使能场景下, 当输入kv cache排布格式为(blocknum, blocksize, H), 且 numKvHeads * headDim 超过64k时, 受硬件指令约束, 会被拦截报错. 可通过使能GQA(减小 numKvHeads)或调整kv cache排布格式为(blocknum, numKvHeads, blocksize, D)解决.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 不支持Q为BF16/FP16且KV为INT4(INT32)的场景.
Atlas A3 训练系列产品: 不支持Q为BF16/FP16且KV为INT4(INT32)的场景.
page attention场景的参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围.
kv伪量化参数分离:
除了key_quant_mode为0并且value_quant_mode为1的场景外, key_quant_mode和value_quant_mode取值需要保持一致.
dequant_scale_key和dequant_scale_value要么都为空, 要么都不为空; dequant_offset_key和dequant_offset_value要么都为空, 要么都不为空.
dequant_scale_key和dequant_scale_value都不为空时, 除了key_quant_mode为0并且value_quant_mode为1的场景外, 其shape需要保持一致; dequant_offset_key和dequant_offset_value都不为空时, 除了key_quant_mode为0并且value_quant_mode为1的场景外, 其shape需要保持一致.
int4(int32)伪量化场景不支持后量化.
管理scale/offset的量化模式如下:
注意scale、offset两个参数指dequant_scale_key、dequant_scale_key、dequant_offset_value、dequant_offset_value.
场景下scale和offset条件
per-channel模式: 两个参数shape支持(1, KV_N, 1, D), (1, KV_N, D), (1, H), 数据类型和query数据类型相同.
per-tensor模式: 两个参数的shape均为(1,), 数据类型和query数据类型相同.
per-token模式: 两个参数的shape均为(1, B, KV_S), 数据类型固定为float32.
per-tensor叠加per-head模式: 两个参数的shape均为(KV_N,), 数据类型和query数据类型相同.
per-token叠加per-head模式: 两个参数的shape均为(B, KV_N, KV_S), 数据类型固定为float32.
per-token叠加使用page attention模式: 两个参数的shape均为(blocknum, blocksize), 数据类型固定为float32.
per-token叠加per head并使用page attention模式: 两个参数的shape均为(blocknum, KV_N, blocksize), 数据类型固定为float32.
key支持per-channel叠加value支持per-token模式: 对于key支持per-channel, 两个参数的shape可支持(1, KV_N, 1, D)、(1, KV_N, D)、(1, H), 且参数数据类型和query数据类型相同. 对于value支持per-token, 两个参数的shape均为(1, B, KV_S)并且数据类型固定为float32.
场景下key和value条件
per-channel模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int4(int32)或int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int4(int32)或int8时支持.
per-tensor模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int8时支持.
per-token模式: key、value数据类型为int4(int32)或int8时支持.
per-tensor叠加per-head模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int8时支持. Atlas A3 训练系列产品: 当key、value数据类型为int8时支持.
per-token叠加per-head模式: key、value数据类型为int4(int32)或int8时支持.
per-token叠加使用page attention模式: key、value数据类型为int8时支持.
per-token叠加per head并使用page attention模式: key、value数据类型为int8时支持.
key支持per-channel叠加value支持per-token模式: Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 当key、value数据类型为int4(int32)或int8时支持; 当key和value的数据类型为int8时, 仅支持query和输出的dtype为float16. Atlas A3 训练系列产品: 当key、value数据类型为int4(int32)或int8时支持; 当key和value的数据类型为int8时, 仅支持query和输出的dtype为float16.
支持的产品: Atlas A2 训练系列产品/Atlas 800I A2 推理产品. Atlas A3 训练系列产品
pse_shift功能使用限制如下:
pse_shift数据类型需与query数据类型保持一致. 仅支持D轴对齐, 即D轴可以被16整除.
支持的PyTorch版本
PyTorch 2.1
PyTorch 2.3
PyTorch 2.4
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
softmax_scale = 1/math.sqrt(128.0)
actseqlen = [164]
actseqlenkv = [1024]
# 调用FIA算子
out, _ = torch_npu.npu_fused_infer_attention_score_v2(q, k, v,
actual_seq_qlen = actseqlen, actual_seq_kvlen = actseqlenkv,
num_query_heads = 8, input_layout = "BNSD", softmax_scale = softmax_scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出out类似如下
tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
..
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16)
图模式调用
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
softmax_scale = 1/math.sqrt(128.0)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_fused_infer_attention_score_v2(q, k, v, num_query_heads = 8, input_layout = "BNSD", softmax_scale=softmax_scale, pre_tokens=65535, next_tokens=65535)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
single_op = torch_npu.npu_fused_infer_attention_score_v2(q, k, v, num_query_heads = 8, input_layout = "BNSD", softmax_scale=softmax_scale, pre_tokens=65535, next_tokens=65535)
print("single op output with mask:", single_op[0], single_op[0].shape)
print("graph output with mask:", graph_output[0], graph_output[0].shape)
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]],
device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
"""
)
_add_torch_npu_docstr(
"_npu_fused_infer_attention_score_v2_get_max_workspace",
"""
功能描述:
算子功能:用于npu_fused_infer_attention_score_v2算子aclgraph tilling下沉场景,获取最大workspace size并创建一个此size大小的tensor。
接口原型:
torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace(Tensor query, Tensor key, Tensor value, *, Tensor? query_rope=None, Tensor? key_rope=None, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_qlen=None, SymInt[]? actual_seq_kvlen=None, Tensor? block_table=None, Tensor? dequant_scale_query=None, Tensor? dequant_scale_key=None, Tensor? dequant_offset_key=None, Tensor? dequant_scale_value=None, Tensor? dequant_offset_value=None, Tensor? dequant_scale_key_rope=None, Tensor? quant_scale_out=None, Tensor? quant_offset_out=None, Tensor? quant_scale_p=None, Tensor? learnable_sink=None, int num_query_heads=1, int num_key_value_heads=0, float softmax_scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int sparse_mode=0, int block_size=0, int query_quant_mode=0, int key_quant_mode=0, int value_quant_mode=0, int inner_precise=0, bool return_softmax_lse=False, int? query_dtype=None, int? key_dtype=None, int? value_dtype=None, int? query_rope_dtype=None, int? key_rope_dtype=None, int? key_shared_prefix_dtype=None, int? value_shared_prefix_dtype=None, int? dequant_scale_query_dtype=None, int? dequant_scale_key_dtype=None, int? dequant_scale_value_dtype=None, int? dequant_scale_key_rope_dtype=None, int? out_dtype=None) -> Tensor
参数说明:
输入与npu_fused_infer_attention_score_v2一致
输出类型为Tensor, 由aclnnFusedInferAttentionScoreV4GetMaxWorkspaceSize返回最大的Size,返回创建的workspace tensor。
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
softmax_scale = 1/math.sqrt(128.0)
# 调用FIA算子
out = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace(q, k, v, num_query_heads = 8, input_layout = "BNSD", softmax_scale = softmax_scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出类似如下
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"_npu_fused_infer_attention_score_v2_infer_output",
"""
功能描述:
算子功能:用于npu_fused_infer_attention_score_v2算子aclgraph tilling下沉场景,推算output tensor 并创建一个此size大小的tensor, 实际返回output_tensor 和 softmax_lse_tensor。
接口原型:
torch_npu._npu_fused_infer_attention_score_v2_infer_output(Tensor query, Tensor value, *, int? query_dtype=None, int? value_dtype=None, str input_layout="BSH", Tensor? quant_scale_out=None, Tensor? block_table=None, int num_query_heads=1, int num_key_value_heads=0, bool return_softmax_lse=False, Tensor? query_rope=None, int? out_dtype=None) -> (Tensor, Tensor)
参数说明:
输入为npu_fused_infer_attention_score_v2的子集
输出类型为(Tensor, Tensor), 由适配层推导,计算返回对应的output_tensor 和 softmax_lse_tensor。
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
# 调用FIA算子
out,softmax_lse = torch_npu._npu_fused_infer_attention_score_v2_infer_output(q, v, num_query_heads = 8, input_layout = "BNSD")
# 执行上述代码的输出类似如下
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
tensor([0., 0., ..., 0., 0., 0.],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"npu_fused_infer_attention_score_v2.out",
"""
功能描述:
算子功能:npu_fused_infer_attention_score_v2.out算子实现,可用于aclgraph tilling下沉场景(需传入workspace tensor),输入参数相比npu_fused_infer_attention_score_v2增加workspace、attention_out、softmax_lse。
计算公式:atten_out = softmax(softmax_scale*(query*key)+atten_mask)*value
接口原型:
torch_npu.npu_fused_infer_attention_score_v2.out(Tensor query, Tensor key, Tensor value, *, Tensor? query_rope=None, Tensor? key_rope=None, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_qlen=None, SymInt[]? actual_seq_kvlen=None, Tensor? block_table=None, Tensor? dequant_scale_query=None, Tensor? dequant_scale_key=None, Tensor? dequant_offset_key=None, Tensor? dequant_scale_value=None, Tensor? dequant_offset_value=None, Tensor? dequant_scale_key_rope=None, Tensor? quant_scale_out=None, Tensor? quant_offset_out=None, Tensor? quant_scale_p=None, Tensor? learnable_sink=None, int num_query_heads=1, int num_key_value_heads=0, float softmax_scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int sparse_mode=0, int block_size=0, int query_quant_mode=0, int key_quant_mode=0, int value_quant_mode=0, int inner_precise=0, bool return_softmax_lse=False, int? query_dtype=None, int? key_dtype=None, int? value_dtype=None, int? query_rope_dtype=None, int? key_rope_dtype=None, int? key_shared_prefix_dtype=None, int? value_shared_prefix_dtype=None, int? dequant_scale_query_dtype=None, int? dequant_scale_key_dtype=None, int? dequant_scale_value_dtype=None, int? dequant_scale_key_rope_dtype=None, int? out_dtype=None, Tensor? workspace=None, Tensor(a!) attention_out, Tensor(b!) softmax_lse) -> (Tensor(a!), Tensor(b!))
参数说明:
在torch_npu.npu_fused_infer_attention_score_v2的基础上增加下面三个参数:
workspace(可选): 一维Device侧的Input Tensor,数据类型与Query一致;
attention_out(aclTensor*,计算输出): 计算的最终结果Attention output tensor, shape与Query一致;
softmax_lse(aclTensor*,计算输出): 也是一个输出结果,当前预留,暂不支持;
约束说明:
当Q_S等于1时:请参考Incre_Flash_Attention限制
当Q_S大于1时:请参考Prompt_Flash_Attention限制
支持的芯片型号:
Atlas A2 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
workspace = torch.randn(2000000, dtype=torch.float16).npu()
output = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
softmax_lse = torch.randn(1, dtype=torch.float16).npu()
softmax_scale = 1/math.sqrt(128.0)
# 调用FIA算子
out = torch_npu.npu_fused_infer_attention_score_v2.out(q, k, v, workspace=workspace, out=[output, softmax_lse], num_query_heads = 8, input_layout = "BNSD", softmax_scale = softmax_scale, pre_tokens=65535, next_tokens=65535)
# 执行上述代码的输出output类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.float16)
# 入图方式
暂不支持入图
"""
)
_add_torch_npu_docstr(
"npu_mla_prolog",
"""
功能描述:
推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入x乘以WeightDq进行下采样和RmsNorm后分成两路,第一路乘以WeightUq和WeightUk经过两次上采样后得到query;第二路乘以WeightQr后经过旋转位置编码(ROPE)得到query_rope;第三路是输入x乘以WeightDkv进行下采样和RmsNorm后传入Cache中得到kvCache;第四路是输入x乘以Wkr后经过旋转位置编码后传入另一个Cache中得到krCache。
接口原型:
torch_npu.npu_mla_prolog(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor cache_index, Tensor kv_cache, Tensor kr_cache, *, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor)
参数说明:
- token_x(Tensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16,数据格式支持ND。
- weight_dq(Tensor):必选参数,表示计算Query的下采样权重矩阵,即公式中W<sup>DQ</sup>。shape支持2维,格式为(He, Hcq),dtype支持bfloat16,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
- weight_uq_qr(Tensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中W<sup>UQ</sup>和W<sup>QR</sup>。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- 当weight_uq_qr为int8类型时,weight_uq_qr是一个per-tensor的量化后的输入,表示当前为部分量化场景。
此时若kv_cache、kr_cache为bfloat16类型,对应kv_cache_out、kr_cache_out为非量化输出,此时dequant_scale_w_uq_qr字段必须传入,smooth_scales_cq字段可选传入。
此时若kv_cache、kr_cache为int8类型,对应kv_cache_out、kr_cache_out为量化输出,此时dequant_scale_w_uq_qr、quant_scale_ckv、quant_scale_ckr字段必须传入,smooth_scales_cq字段可选传入。
- 当weight_uq_qr为bfloat16类型时,表示当前为非量化场景。
此时dequant_scale_w_uq_qr、quant_scale_ckv、quant_scale_ckr、smooth_scales_cq字段不能传入(即为none)。
- weight_uk(Tensor):必选参数,表示计算Key的上采样权重,即公式中W<sup>UK</sup>。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。
- weight_dkv_kr(Tensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中W<sup>DKV</sup>和W<sup>KR</sup>。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16,数据格式支持FRACTAL_NZ。
- rmsnorm_gamma_cq(Tensor):必选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。
- rmsnorm_gamma_ckv(Tensor):必选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。
- rope_sin(Tensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- rope_cos(Tensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- cache_index(Tensor):必选参数,表示用于存储kv_cache和kr_cache的索引。shape支持1维和2维,格式为(T,)和(B, S),dtype支持int64,数据格式支持ND。
- cache_index的取值范围为[0,BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。
- kv_cache(Tensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache(Tensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
- dequant_scale_x(Tensor):预留参数,暂未使用,使用默认值即可。
- dequant_scale_w_dq(Tensor):预留参数,暂未使用,使用默认值即可。
- dequant_scale_w_uq_qr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数维per-channel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。
- dequant_scale_w_dkv_kr(Tensor):预留参数,暂未使用,使用默认值即可。
- quant_scale_ckv(Tensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。
- quant_scale_ckr(Tensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。
- smooth_scales_cq(Tensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- rmsnorm_epsilon_cq(float):可选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- rmsnorm_epsilon_ckv(float):可选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- cache_mode(str):可选参数,表示kvCache的模式,支持"PA_BSND"、"PA_NZ",其用户不特意指定时可传入默认值“PA_BSND”。
输出说明:
- query(Tensor):表示Query的输出Tensor,即公式中q<sup>N</sup>。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16,数据格式支持ND。
- query_rope(Tensor):表示Query位置编码的输出Tensor,即公式中q<sup>R</sup>。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
- kv_cache_out(Tensor):表示Key输出到kv_cache中的Tensor(本质in-place更新),即公式中k<sup>C</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache_out(Tensor):表示Key的位置编码输出到kr_cache中的Tensor(本质in-place更新),即公式中k<sup>R</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持图模式。
- 接口参数中shape格式字段含义:
- B:Batch表示输入样本批量大小,取值范围为0~65536。
- S:Seq-Length表示输入样本序列长度,取值范围为0~16。
- He:Head-Size表示隐藏层的大小,取值为7168。
- Hcq:q低秩矩阵维度,取值为1536。
- N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。
- Hckv:kv低秩矩阵维度,取值为512。
- D:qk不含位置编码维度,取值为128。
- Dr:qk位置编码维度,取值为64。
- Nkv:kv的head数,取值为1。
- BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。
- BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
- T:BS合轴后的大小,取值范围:0~1048576。注:若采用BS合轴,此时token_x、rope_sin、rope_cos均为2维,cache_index为1维,query、query_rope为3维。
- shape约束
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
- 如果Skv取值为0,则query、query_rope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
支持的芯片型号:
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例:
# 单算子调用方式
import math
import torch
import torch_npu
# 生成随机数据, 并发送到npu
B = 8
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 1024
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.rand(B, S).to(torch.int64).npu()
kv_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Hckv, dtype=torch.bfloat16).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla = torch_npu.npu_mla_prolog(
token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast,
rmsnorm_gamma_cq, rmsnorm_gamma_ckv, rope_sin, rope_cos,
cache_index, kv_cache, kr_cache,
rmsnorm_epsilon_cq=rmsnorm_epsilon_cq,
rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode
)
# 执行上述代码的输出类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
B = 8
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 1024
S = 2
Nkv = 1
BlockNum = 32
BlockSize = 128
token_x = torch.rand(B, S, He, dtype=torch.bfloat16).npu()
w_dq = torch.rand(He, Hcq, dtype=torch.bfloat16).npu()
w_uq_qr = torch.rand(Hcq, N * (D + Dr), dtype=torch.bfloat16).npu()
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He, Hckv + Dr, dtype=torch.bfloat16).npu()
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.rand(B, S).to(torch.int64).npu()
kv_cache = torch.rand(BlockNum, BlockSize, Nkv, Hckv, dtype=torch.bfloat16).npu()
kr_cache = torch.rand(BlockNum, BlockSize, Nkv, Dr, dtype=torch.bfloat16).npu()
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self):
return torch_npu.npu_mla_prolog(
token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, cache_index, kv_cache, kr_cache)
def MetaInfershape():
with torch.no_grad():
model = Model()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
graph_output = model()
if __name__ == "__main__":
MetaInfershape()
# 执行上述代码的输出类似如下
single op output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
graph output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]], device='npu:0', dtype=torch.bfloat16)
"""
)
_add_torch_npu_docstr(
"npu_mla_prolog_v2",
"""
该接口中kv_cache和kr_cache进行原地计算,未按in-place算子实现接口,推荐使用`torch_npu.npu_mla_prolog_v3`接口进行替换。
功能描述:
推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分五路,首先对输入x乘以WeightDq进行下采样和RmsNorm后分成两路,第一路乘以WeightUq和WeightUk经过两次上采样后得到query;第二路乘以WeightQr后经过旋转位置编码(ROPE)得到query_rope;第三路是输入x乘以WeightDkv进行下采样和RmsNorm后传入Cache中得到kvCache;第四路是输入x乘以Wkr后经过旋转位置编码后传入另一个Cache中得到krCache;第五路是输出query经过DynamicQuant后得到的量化参数。
接口原型:
torch_npu.npu_mla_prolog_v2(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor cache_index, Tensor kv_cache, Tensor kr_cache, *, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND") -> (Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明:
- token_x(Tensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16和int8,数据格式支持ND。
- weight_dq(Tensor):必选参数,表示计算Query的下采样权重矩阵,即公式中W<sup>DQ</sup>。shape支持2维,格式为(He, Hcq),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
- weight_uq_qr(Tensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中W<sup>UQ</sup>和W<sup>QR</sup>。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- weight_uk(Tensor):必选参数,表示计算Key的上采样权重,即公式中W<sup>UK</sup>。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。
- weight_dkv_kr(Tensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中W<sup>DKV</sup>和W<sup>KR</sup>。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- rmsnorm_gamma_cq(Tensor):必选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。
- rmsnorm_gamma_ckv(Tensor):必选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。
- rope_sin(Tensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- rope_cos(Tensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- cache_index(Tensor):必选参数,表示用于存储kv_cache和kr_cache的索引。shape支持1维和2维,格式为(T)和(B, S),dtype支持int64,数据格式支持ND。
- cache_index的取值范围为[0,BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。
- kv_cache(Tensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache(Tensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
- dequant_scale_x(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为pertoken。其shape支持2维,格式为(T, 1)和(BS, 1),dtype支持float,数据格式支持ND。
- dequant_scale_w_dq(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为perchannel。其shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- dequant_scale_w_uq_qr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数维perchannel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。
- dequant_scale_w_dkv_kr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化算法为perchannel。其shape支持2维,格式为(1, Hckv+Dr),dtype支持float,数据格式支持ND。
- quant_scale_ckv(Tensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。
- quant_scale_ckr(Tensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。
- smooth_scales_cq(Tensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- rmsnorm_epsilon_cq(float):可选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- rmsnorm_epsilon_ckv(float):可选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- cache_mode(str):可选参数,表示kvCache的模式,支持"PA_BSND"、"PA_NZ",其用户不特意指定时可传入默认值“PA_BSND”。
输出说明:
- query(Tensor):表示Query的输出Tensor,即公式中q<sup>N</sup>。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- query_rope(Tensor):表示Query位置编码的输出Tensor,即公式中q<sup>R</sup>。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
- kv_cache_out(Tensor):表示Key输出到kv_cache中的Tensor(本质in-place更新),即公式中k<sup>C</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache_out(Tensor):表示Key的位置编码输出到kr_cache中的Tensor(本质in-place更新),即公式中k<sup>R</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
- dequant_scale_q_nope(Tensor):表示Query的输出Tensor的反量化参数。其shape支持1维和3维,全量化kv_cache量化场景下,其shape为(T, N, 1)和(B*S, N, 1);其他场景下,其shape为(1),dtype支持float,数据格式支持ND。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持图模式。
- 接口参数中shape格式字段含义:
- B:Batch表示输入样本批量大小,取值范围为0~65536。
- S:Seq-Length表示输入样本序列长度,取值范围为0~16。
- He:Head-Size表示隐藏层的大小,取值为7168。
- Hcq:q低秩矩阵维度,取值为1536。
- N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。
- Hckv:kv低秩矩阵维度,取值为512。
- D:qk不含位置编码维度,取值为128。
- Dr:qk位置编码维度,取值为64。
- Nkv:kv的head数,取值为1。
- BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。
- BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
- T:BS合轴后的大小,取值范围:0~1048576。
- shape约束:
- 若token_x的维度采用BS合轴,即(T, He),则rope_sin和rope_cos的shape为(T, Dr),cache_index的shape为(T,),dequant_scale_x的shape为(T, 1),query的shape为(T, N, Hckv),query_rope的shape为(T, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(T, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- 若token_x的维度不采用BS合轴,即(B, S, He),则rope_sin和rope_cos的shape为(B, S, Dr),cache_index的shape为(B, S),dequant_scale_x的shape为(B*S, 1),query的shape为(B, S, N, Hckv),query_rope的shape为(B, S, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(B*S, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope、dequant_scale_q_nope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
- 如果Skv取值为0,则query、query_rope、dequant_scale_q_nope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
支持的芯片型号:
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
B = 32
He=7168
Hcq=1536
Hckv=512
N=32
D=128
Dr=64
Skv=6144
S=2
Nkv=1
block_size=128
block_num=math.ceil(B*Skv/block_size)
BS = B * S
token_x = torch.rand(B, S, He).to(torch.int8).npu()
torch_npu.get_npu_format(token_x)
w_dq = torch.rand(He, Hcq).to(torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
torch_npu.get_npu_format(w_dq_cast)
w_uq_qr = torch.rand(Hcq, N*(D+Dr)).to(torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He,Hckv+Dr).to(torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
sin = torch.rand(B,S,Dr, dtype=torch.bfloat16).npu()
cos = torch.rand(B,S,Dr, dtype=torch.bfloat16).npu()
cache_index = torch.rand(B,S).to(torch.int64).npu()
kv_cache = torch.rand(1, block_numblock_sizeNkvHckv).to(torch.int8).npu()
kv_cache = kv_cache.view(block_num, block_size, Nkv, Hckv)
kr_cache = torch.rand(1, block_numblock_sizeNkvDr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(block_num, block_size, Nkv, Dr)
dequant_scale_x = torch.rand(BS, 1, dtype=torch.float).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float).npu()
dequant_scale_w_uq_qr = torch.rand(1,N*(D+Dr), dtype=torch.float).npu()
dequant_scale_w_dkv_kr = torch.rand(1,Hckv+Dr, dtype=torch.float).npu()
quant_scale_ckv = torch.rand(1,Hckv, dtype=torch.float).npu()
cache_mode = "PA_NZ"
# 调用MlaPrologV2算子
query, query_rope, kvcache, krcache,dequant_scale_q_nope = torch.ops.npu.npu_mla_prolog_v2(token_x, w_dq, w_uq_qr, w_uk,
w_dkv_kr, gamma_cq, gamma_ckv, sin, cos, cache_index, kv_cache, kr_cache, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, cache_mode=cache_mode)
# 执行上述代码的输出类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
B = 32
He=7168
Hcq=1536
Hckv=512
N=32
D=128
Dr=64
Skv=6144
S=1
Nkv=1
block_size=128
block_num=math.ceil(B*Skv/block_size)
BS = B * S
class Model_ds(torch.nn.Module):
def init(self):
super().init()
def forward(self, token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, cache_index, kv_cache, kr_cache, dequant_scale_x,
dequant_scale_w_dq, dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr,
quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, cache_mode = "PA_BSND"):
query, query_rope, kvcache, krcache,dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2(token_x,
w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, cache_index, kv_cache, kr_cache, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=None,
smooth_scales_cq=None, cache_mode = cache_mode)
return query, query_rope, kvcache, krcache, dequant_scale_q_nope
if name=="main":
torch_npu.npu.set_device(0)
token_x = torch.rand(B, S, He).to(torch.int8).npu()
torch_npu.get_npu_format(token_x)
w_dq = torch.rand(He, Hcq).to(torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
torch_npu.get_npu_format(w_dq_cast)
w_uq_qr = torch.rand(Hcq, N*(D+Dr)).to(torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He,Hckv+Dr).to(torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
sin = torch.rand(B,S,Dr, dtype=torch.bfloat16).npu()
cos = torch.rand(B,S,Dr, dtype=torch.bfloat16).npu()
cache_index = torch.rand(B,S).to(torch.int64).npu()
kv_cache = torch.rand(1, block_num*block_size*Nkv*Hckv).to(torch.int8).npu()
kv_cache = kv_cache.view(block_num, block_size, Nkv, Hckv)
kr_cache = torch.rand(1, block_num*block_size*Nkv*Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(block_num, block_size, Nkv, Dr)
dequant_scale_x = torch.rand(BS, 1, dtype=torch.float).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float).npu()
dequant_scale_w_uq_qr = torch.rand(1,N*(D+Dr), dtype=torch.float).npu()
dequant_scale_w_dkv_kr = torch.rand(1,Hckv+Dr, dtype=torch.float).npu()
quant_scale_ckv = torch.rand(1,Hckv, dtype=torch.float).npu()
cache_mode = "PA_NZ" # PA_BSND
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = Model_ds().npu()
# 图模式调用
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
query, query_rope, kvcache, krcache,dequant_scale_q_nope = model(token_x, w_dq, w_uq_qr, w_uk,
w_dkv_kr, gamma_cq, gamma_ckv, sin, cos, cache_index, kv_cache, kr_cache, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=None,
smooth_scales_cq=None, cache_mode=cache_mode)
# 单算子调用
query, query_rope, kvcache, krcache,dequant_scale_q_nope = torch.ops.npu.npu_mla_prolog_v2(token_x, w_dq, w_uq_qr, w_uk,
w_dkv_kr, gamma_cq, gamma_ckv, sin, cos, cache_index, kv_cache, kr_cache, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, cache_mode=cache_mode)
# 执行上述代码的输出类似如下
single op output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
graph output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]], device='npu:0', dtype=torch.bfloat16)
"""
)
_add_torch_npu_docstr(
"npu_mla_prolog_v3",
"""
功能描述:
推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分五路,首先对输入x乘以WeightDq进行下采样和RmsNorm后分成两路,第一路乘以WeightUq和WeightUk经过两次上采样后得到query;第二路乘以WeightQr后经过旋转位置编码(ROPE)得到query_rope;第三路是输入x乘以WeightDkv进行下采样和RmsNorm后传入Cache中得到kvCache;第四路是输入x乘以Wkr后经过旋转位置编码后传入另一个Cache中得到krCache;第五路是输出query经过DynamicQuant后得到的量化参数。
接口原型:
torch_npu.npu_mla_prolog_v3(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor(a!) kv_cache, Tensor(b!) kr_cache, *, Tensor? cache_index=None, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND", float qc_qr_scale=1.0, float kc_scale=1.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明:
- token_x(Tensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16和int8,数据格式支持ND。
- weight_dq(Tensor):必选参数,表示计算Query的下采样权重矩阵,即公式中W<sup>DQ</sup>。shape支持2维,格式为(He, Hcq),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
- weight_uq_qr(Tensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中W<sup>UQ</sup>和W<sup>QR</sup>。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- weight_uk(Tensor):必选参数,表示计算Key的上采样权重,即公式中W<sup>UK</sup>。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。
- weight_dkv_kr(Tensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中W<sup>DKV</sup>和W<sup>KR</sup>。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- rmsnorm_gamma_cq(Tensor):必选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。
- rmsnorm_gamma_ckv(Tensor):必选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。
- rope_sin(Tensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- rope_cos(Tensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- kv_cache(Tensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache(Tensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
- cache_index(Tensor):可选参数,表示用于存储kv_cache和kr_cache的索引。shape支持1维和2维,格式为(T)和(B, S),dtype支持int64,数据格式支持ND。
- cache_index的取值范围为[0,BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。
- dequant_scale_x(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为pertoken。其shape支持2维,格式为(T, 1)和(BS, 1),dtype支持float,数据格式支持ND。
- dequant_scale_w_dq(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为perchannel。其shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- dequant_scale_w_uq_qr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数维perchannel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。
- dequant_scale_w_dkv_kr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化算法为perchannel。其shape支持2维,格式为(1, Hckv+Dr),dtype支持float,数据格式支持ND。
- quant_scale_ckv(Tensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。
- quant_scale_ckr(Tensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。
- smooth_scales_cq(Tensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- actual_seq_len(Tensor):可选预留参数,当前版本暂未使用。
- k_nope_clip_alpha(Tensor):可选参数,表示kv_cache做clip操作时的缩放因子,当前仅在kvcache per-tile量化场景下使用。不支持非连续,数据格式支持ND,数据类型支持float,shape为[1]。
- rmsnorm_epsilon_cq(float):可选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- rmsnorm_epsilon_ckv(float):可选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- cache_mode(str):可选参数,表示kvCache的模式,支持"PA_BSND"、"PA_NZ",其用户不特意指定时可传入默认值“PA_BSND”。
- query_norm_flag(int):可选参数,表示是否输出query_norm,Host侧参数。仅支持bool类型,False表示不输出query_norm,True表示输出query_norm,默认值为0。
- weight_quant_mode(int):可选参数,表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示weight_uq_qr量化,2表示weight_dq、 weight_uk、weight_dkv_kr量化,默认值为0。
- kv_cache_quant_mode(int):可选参数,表示kv_cache的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示per-tensor量化,2表示per-channel量化,3-表示per-tile量化,默认值为0。
- query_quant_mode(int):可选参数,表示query的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示per-token-head量化,默认值为0。
- ckvkr_repo_mode(int):可选参数,表示kv_cache和kr_cache的存储模式,Host侧参数。仅支持int64类型,0表示kv_cache和kr_cache分别存储,1表示kv_cache和kr_cache合并存储,默认值为0。
- quant_scale_repo_mode(int):可选参数,表示量化scale的存储模式,Host侧参数。仅支持int64类型,0表示量化scale和数据分别存储,1表示量化scale和数据合并存储,默认值为0。
- tile_size(int):可选参数,表示per-tile量化时每个tile的大小,仅在kv_cache_quant_mode为3时有效,Host侧参数,默认值为128。
- qc_qr_scale(float):可选参数,表示Query的尺度矫正参数,不传入的时候默认值为1.0。
- kc_scale(float):可选参数,表示Key的尺度矫正参数,不传入的时候默认值为1.0。
输出说明:
- query(Tensor):表示Query的输出Tensor,即公式中q<sup>N</sup>。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- query_rope(Tensor):表示Query位置编码的输出Tensor,即公式中q<sup>R</sup>。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
- dequant_scale_q_nope(Tensor):表示Query的输出Tensor的反量化参数。其shape支持1维和3维,全量化kv_cache量化场景下,其shape为(T, N, 1)和(B*S, N, 1);其他场景下,其shape为(1),dtype支持float,数据格式支持ND。
- query_norm(Tensor):预留输出,默认生成shape为(1,)的零张量,dtype支持bfloat16和int8,数据格式支持ND。
- dequant_scale_q_norm(Tensor):预留输出,默认生成shape为(1,)的零张量,dtype支持float,数据格式支持ND。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持图模式。
- 接口参数中shape格式字段含义:
- B:Batch表示输入样本批量大小,取值范围为0~65536。
- S:Seq-Length表示输入样本序列长度,取值范围为0~16。
- He:Head-Size表示隐藏层的大小,取值为7168、7680或6144。
- Hcq:q低秩矩阵维度,取值为1536。
- N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。
- Hckv:kv低秩矩阵维度,取值为512。
- Dtile:kv_cache per-tile量化时的矩阵维度,取值为656
- D:qk不含位置编码维度,取值为128。
- Dr:qk位置编码维度,取值为64。
- Nkv:kv的head数,取值为1。
- BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。
- BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
- T:BS合轴后的大小,取值范围:0~1048576。
- shape约束:
- 若token_x的维度采用BS合轴,即(T, He),则rope_sin和rope_cos的shape为(T, Dr),cache_index的shape为(T,),dequant_scale_x的shape为(T, 1),query的shape为(T, N, Hckv),query_rope的shape为(T, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(T, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- 若token_x的维度不采用BS合轴,即(B, S, He),则rope_sin和rope_cos的shape为(B, S, Dr),cache_index的shape为(B, S),dequant_scale_x的shape为(B*S, 1),query的shape为(B, S, N, Hckv),query_rope的shape为(B, S, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(B*S, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope、dequant_scale_q_nope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
- 如果Skv取值为0,则query、query_rope、dequant_scale_q_nope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
支持的芯片型号:
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 2
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
quant_scale_ckv = torch.rand(1, Hckv, dtype=torch.float32).npu()
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
# 调用MlaPrologV3算子
query_mla, query_rope_mla, dequant_scale_q_nope_mla, query_norm_mla, dequant_scale_q_norm_mla = torch.ops.npu.mla_prolog_npu_v3(token_x, w_dq_cast,
w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, qc_qr_scale, kc_scale, cache_index, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq)
# 执行上述代码的输出类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
class Model_ds(torch.nn.Module):
def init(self):
super().init()
def forward(self, token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, kv_cache, kr_cache, cache_index, dequant_scale_x,
dequant_scale_w_dq, dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr,
quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, epsilon_cq = 0.00001, epsilon_ckv = 0.00001,
cache_mode = "PA_BSND", qc_qr_scale = 1.0, kc_scale = 1.0):
return torch_npu.npu_mla_prolog_v3(token_x,
w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, kv_cache, kr_cache, cache_index=cache_index, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=None,
smooth_scales_cq=None, epsilon_cq=epsilon_cq, epsilon_ckv=epsilon_ckv, cache_mode=cache_mode,
qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
if name=="main":
torch_npu.npu.set_device(0)
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
quant_scale_ckv = torch.rand(1, Hckv, dtype=torch.float32).npu()
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = Model_ds().npu()
# 图模式调用
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
query, query_rope, dequant_scale_q_nope, query_norm, dequant_scale_q_norm = model(token_x, w_dq, w_uq_qr, w_uk,
w_dkv_kr, gamma_cq, gamma_ckv, sin, cos, kv_cache, kr_cache, cache_index=cache_index, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=None,
smooth_scales_cq=None, epsilon_cq=epsilon_cq, epsilon_ckv=epsilon_ckv, cache_mode=cache_mode, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
# 单算子调用
query_mla, query_rope_mla, dequant_scale_q_nope_mla, query_norm_mla, dequant_scale_q_norm_mla = torch.ops.npu.mla_prolog_npu_v3(token_x, w_dq_cast,
w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, qc_qr_scale, kc_scale, cache_index, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq)
# 执行上述代码的输出类似如下
single op output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
graph output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]], device='npu:0', dtype=torch.bfloat16)
"""
)
_add_torch_npu_docstr(
"npu_mla_prolog_v3_functional",
"""
功能描述:
推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分五路,首先对输入x乘以WeightDq进行下采样和RmsNorm后分成两路,第一路乘以WeightUq和WeightUk经过两次上采样后得到query;第二路乘以WeightQr后经过旋转位置编码(ROPE)得到query_rope;第三路是输入x乘以WeightDkv进行下采样和RmsNorm后传入Cache中得到kvCache;第四路是输入x乘以Wkr后经过旋转位置编码后传入另一个Cache中得到krCache;第五路是输出query经过DynamicQuant后得到的量化参数。
接口原型:
torch_npu.npu_mla_prolog_v3_functional(Tensor token_x, Tensor weight_dq, Tensor weight_uq_qr, Tensor weight_uk, Tensor weight_dkv_kr, Tensor rmsnorm_gamma_cq, Tensor rmsnorm_gamma_ckv, Tensor rope_sin, Tensor rope_cos, Tensor kv_cache, Tensor kr_cache, *, Tensor? cache_index=None, Tensor? dequant_scale_x=None, Tensor? dequant_scale_w_dq=None, Tensor? dequant_scale_w_uq_qr=None, Tensor? dequant_scale_w_dkv_kr=None, Tensor? quant_scale_ckv=None, Tensor? quant_scale_ckr=None, Tensor? smooth_scales_cq=None, float rmsnorm_epsilon_cq=1e-05, float rmsnorm_epsilon_ckv=1e-05, str cache_mode="PA_BSND", float qc_qr_scale=1.0, float kc_scale=1.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明:
- token_x(Tensor):必选参数,对应公式中x。shape支持2维和3维,格式为(T, He)和(B, S, He),dtype支持bfloat16和int8,数据格式支持ND。
- weight_dq(Tensor):必选参数,表示计算Query的下采样权重矩阵,即公式中W<sup>DQ</sup>。shape支持2维,格式为(He, Hcq),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ(可通过torch_npu.npu_format_cast将ND格式转为FRACTAL_NZ格式)。
- weight_uq_qr(Tensor):必选参数,表示计算Query的上采样权重矩阵和Query的位置编码权重矩阵,即公式中W<sup>UQ</sup>和W<sup>QR</sup>。shape支持2维,格式为(Hcq, N*(D+Dr)),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- weight_uk(Tensor):必选参数,表示计算Key的上采样权重,即公式中W<sup>UK</sup>。shape支持3维,格式为(N, D, Hckv),dtype支持bfloat16,数据格式支持ND。
- weight_dkv_kr(Tensor):必选参数,表示计算Key的下采样权重矩阵和Key的位置编码权重矩阵,即公式中W<sup>DKV</sup>和W<sup>KR</sup>。shape支持2维,格式为(He, Hckv+Dr),dtype支持bfloat16和int8,数据格式支持FRACTAL_NZ。
- rmsnorm_gamma_cq(Tensor):必选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hcq,),dtype支持bfloat16,数据格式支持ND。
- rmsnorm_gamma_ckv(Tensor):必选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的_γ_参数。shape支持1维,格式为(Hckv,),dtype支持bfloat16,数据格式支持ND。
- rope_sin(Tensor):必选参数,表示用于计算旋转位置编码的正弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- rope_cos(Tensor):必选参数,表示用于计算旋转位置编码的余弦参数矩阵。shape支持2维和3维,格式为(T, Dr)和(B, S, Dr),dtype支持bfloat16,数据格式支持ND。
- kv_cache(Tensor):必选参数,表示用于cache索引的aclTensor。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache(Tensor):必选参数,表示用于key位置编码的cache。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
- cache_index(Tensor):可选参数,表示用于存储kv_cache和kr_cache的索引。shape支持1维和2维,格式为(T)和(B, S),dtype支持int64,数据格式支持ND。
- cache_index的取值范围为[0,BlockNum*BlockSize),当前不会对cache_index传入值的合法性进行校验,需用户自行保证。
- dequant_scale_x(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为pertoken。其shape支持2维,格式为(T, 1)和(BS, 1),dtype支持float,数据格式支持ND。
- dequant_scale_w_dq(Tensor):可选参数,用于输入token_x为int8类型时,下采样后进行反量化操作时的参数,token_x量化方式为perchannel。其shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- dequant_scale_w_uq_qr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数维perchannel。shape支持2维,格式为(1, N*(D+Dr)),dtype支持float,数据格式支持ND。
- dequant_scale_w_dkv_kr(Tensor):可选参数,用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化算法为perchannel。其shape支持2维,格式为(1, Hckv+Dr),dtype支持float,数据格式支持ND。
- quant_scale_ckv(Tensor):可选参数,用于对输出到kv_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Hckv),dtype支持float,数据格式支持ND。
- quant_scale_ckr(Tensor):可选参数,用于对输出到kr_cache_out中的数据做量化操作时的参数。shape支持2维,格式为(1, Dr),dtype支持float,数据格式支持ND。
- smooth_scales_cq(Tensor):可选参数,用于对RmsNormCq输出做动态量化操作时的参数。shape支持2维,格式为(1, Hcq),dtype支持float,数据格式支持ND。
- actual_seq_len(Tensor):可选预留参数,当前版本暂未使用。
- k_nope_clip_alpha(Tensor):可选参数,表示kv_cache做clip操作时的缩放因子,当前仅在kvcache per-tile量化场景下使用。不支持非连续,数据格式支持ND,数据类型支持float,shape为[1]。
- rmsnorm_epsilon_cq(float):可选参数,表示计算c<sup>Q</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- rmsnorm_epsilon_ckv(float):可选参数,表示计算c<sup>KV</sup>的RmsNorm公式中的ε参数,用户不特意指定时可传入默认值1e-05。
- cache_mode(str):可选参数,表示kvCache的模式,支持"PA_BSND"、"PA_NZ",其用户不特意指定时可传入默认值“PA_BSND”。
- query_norm_flag(int):可选参数,表示是否输出query_norm,Host侧参数。仅支持bool类型,False表示不输出query_norm,True表示输出query_norm,默认值为0。
- weight_quant_mode(int):可选参数,表示weight_dq、weight_uq_qr、weight_uk、weight_dkv_kr的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示weight_uq_qr量化,2表示weight_dq、 weight_uk、weight_dkv_kr量化,默认值为0。
- kv_cache_quant_mode(int):可选参数,表示kv_cache的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示per-tensor量化,2表示per-channel量化,3-表示per-tile量化,默认值为0。
- query_quant_mode(int):可选参数,表示query的量化模式,Host侧参数。仅支持int64类型,0表示非量化,1表示per-token-head量化,默认值为0。
- ckvkr_repo_mode(int):可选参数,表示kv_cache和kr_cache的存储模式,Host侧参数。仅支持int64类型,0表示kv_cache和kr_cache分别存储,1表示kv_cache和kr_cache合并存储,默认值为0。
- quant_scale_repo_mode(int):可选参数,表示量化scale的存储模式,Host侧参数。仅支持int64类型,0表示量化scale和数据分别存储,1表示量化scale和数据合并存储,默认值为0。
- tile_size(int):可选参数,表示per-tile量化时每个tile的大小,仅在kv_cache_quant_mode为3时有效,Host侧参数,默认值为128。
- qc_qr_scale(float):可选参数,表示Query的尺度矫正参数,不传入的时候默认值为1.0。
- kc_scale(float):可选参数,表示Key的尺度矫正参数,不传入的时候默认值为1.0。
输出说明:
- query(Tensor):表示Query的输出Tensor,即公式中q<sup>N</sup>。shape支持3维和4维,格式为(T, N, Hckv)和(B, S, N, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- query_rope(Tensor):表示Query位置编码的输出Tensor,即公式中q<sup>R</sup>。shape支持3维和4维,格式为(T, N, Dr)和(B, S, N, Dr),dtype支持bfloat16,数据格式支持ND。
- dequant_scale_q_nope(Tensor):表示Query的输出Tensor的反量化参数。其shape支持1维和3维,全量化kv_cache量化场景下,其shape为(T, N, 1)和(B*S, N, 1);其他场景下,其shape为(1),dtype支持float,数据格式支持ND。
- query_norm(Tensor):预留输出,默认生成shape为(1,)的零张量,dtype支持bfloat16和int8,数据格式支持ND。
- dequant_scale_q_norm(Tensor):预留输出,默认生成shape为(1,)的零张量,dtype支持float,数据格式支持ND。
- kv_cache_out(Tensor):表示Key输出到kv_cache中的Tensor(本质in-place更新),即公式中k<sup>C</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Hckv),dtype支持bfloat16和int8,数据格式支持ND。
- kr_cache_out(Tensor):表示Key的位置编码输出到kr_cache中的Tensor(本质in-place更新),即公式中k<sup>R</sup>。shape支持4维,格式为(BlockNum, BlockSize, Nkv, Dr),dtype支持bfloat16和int8,数据格式支持ND。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持图模式。
- 接口参数中shape格式字段含义:
- B:Batch表示输入样本批量大小,取值范围为0~65536。
- S:Seq-Length表示输入样本序列长度,取值范围为0~16。
- He:Head-Size表示隐藏层的大小,取值为7168、7680或6144。
- Hcq:q低秩矩阵维度,取值为1536。
- N:Head-Num表示多头数,取值范围为1、2、4、8、16、32、64、128。
- Hckv:kv低秩矩阵维度,取值为512。
- Dtile:kv_cache per-tile量化时的矩阵维度,取值为656
- D:qk不含位置编码维度,取值为128。
- Dr:qk位置编码维度,取值为64。
- Nkv:kv的head数,取值为1。
- BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0。
- BlockSize:PagedAttention场景下的块大小,取值范围为16、128。
- T:BS合轴后的大小,取值范围:0~1048576。
- shape约束:
- 若token_x的维度采用BS合轴,即(T, He),则rope_sin和rope_cos的shape为(T, Dr),cache_index的shape为(T,),dequant_scale_x的shape为(T, 1),query的shape为(T, N, Hckv),query_rope的shape为(T, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(T, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- 若token_x的维度不采用BS合轴,即(B, S, He),则rope_sin和rope_cos的shape为(B, S, Dr),cache_index的shape为(B, S),dequant_scale_x的shape为(B*S, 1),query的shape为(B, S, N, Hckv),query_rope的shape为(B, S, N, Dr)。全量化kv_cache量化场景下,dequant_scale_q_nope的shape为(B*S, N, 1),其他场景下dequant_scale_q_nope的shape为(1)。
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则query、query_rope、dequant_scale_q_nope输出空Tensor,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新。
- 如果Skv取值为0,则query、query_rope、dequant_scale_q_nope正常计算,kv_cache、kr_cache、kv_cache_out、kr_cache_out不更新,即输出空Tensor。
支持的芯片型号:
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例:
# 单算子调用方式
import torch
import torch_npu
import math
# 生成随机数据, 并发送到npu
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 2
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
tile_size = 128
Dtile = (
Hckv
+ quant_scale_repo_mode * (Hckv // tile_size) * 4
+ ckvkr_repo_mode * Dr * 2
)
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Dtile)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 1.0
kc_scale = 1.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
quant_scale_ckv = None
quant_scale_ckr = None
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
actual_seq_len = None
query_norm_flag = True
weight_quant_mode = 2
kv_cache_quant_mode = 3
query_quant_mode = 0
ckvkr_repo_mode = 1
quant_scale_repo_mode = 1
k_nope_clip_alpha = torch.tensor([1], dtype=torch.float32).npu()
# 调用npu_mla_prolog_v3_functional算子
query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _, kv_cache_mla, kr_cache_mla = torch.ops.npu.npu_mla_prolog_v3_functional(token_x, w_dq_cast,
w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, quant_scale_ckr, smooth_scale_cq, actual_seq_len, k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, query_norm_flag, weight_quant_mode, kv_cache_quant_mode, query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale)
# 执行上述代码的输出类似如下
tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
# 入图方式
import torch
import torch_npu
import math
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"
# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl
# 数据生成
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
tile_size = 128
Dtile = (
Hckv
+ quant_scale_repo_mode * (Hckv // tile_size) * 4
+ ckvkr_repo_mode * Dr * 2
)
class Model_ds(torch.nn.Module):
def init(self):
super().init()
def forward(self, token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, kv_cache, kr_cache, cache_index, dequant_scale_x,
dequant_scale_w_dq, dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr,
quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len, k_nope_clip_alpha, epsilon_cq = 0.00001, epsilon_ckv = 0.00001, cache_mode = "PA_BSND", query_norm_flag=False, weight_quant_mode=0, kv_cache_quant_mode=0, query_quant_mode=0, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, qc_qr_scale = 1.0, kc_scale = 1.0):
return torch_npu.npu_mla_prolog_v3_functional(token_x,
w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, kv_cache, kr_cache, cache_index=cache_index, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr, quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=quant_scale_ckr,
smooth_scales_cq=smooth_scales_cq, actual_seq_len=actual_seq_len, k_nope_clip_alpha=k_nope_clip_alpha, epsilon_cq=epsilon_cq, epsilon_ckv=epsilon_ckv,
cache_mode=cache_mode, query_norm_flag=query_norm_flag, weight_quant_mode=weight_quant_mode, kv_cache_quant_mode=kv_cache_quant_mode, query_quant_mode=query_quant_mode, ckvkr_repo_mode=ckvkr_repo_mode, quant_scale_repo_mode=quant_scale_repo_mode, tile_size=tile_size, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
if name=="main":
torch_npu.npu.set_device(0)
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 1.0
kc_scale = 1.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
actual_seq_len = None
quant_scale_ckv = None
quant_scale_ckr = None
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
query_norm_flag = True
weight_quant_mode = 2
kv_cache_quant_mode = 3
query_quant_mode = 0
ckvkr_repo_mode = 1
quant_scale_repo_mode = 1
k_nope_clip_alpha = torch.tensor([1], dtype=torch.float32).npu()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
model = Model_ds().npu()
# 图模式调用
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _, kv_cache_mla, kr_cache_mla = model(token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, gamma_cq, gamma_ckv,
sin, cos, kv_cache, kr_cache, cache_index=cache_index, dequant_scale_x=dequant_scale_x,
dequant_scale_w_dq=dequant_scale_w_dq, dequant_scale_w_uq_qr=dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr=dequant_scale_w_dkv_kr,
quant_scale_ckv=quant_scale_ckv, quant_scale_ckr=quant_scale_ckr, smooth_scales_cq=smooth_scales_cq, actual_seq_len=actual_seq_len, k_nope_clip_alpha=k_nope_clip_alpha, epsilon_cq=rmsnorm_epsilon_cq, epsilon_ckv=rmsnorm_epsilon_ckv, cache_mode=cache_mode, query_norm_flag=query_norm_flag, weight_quant_mode=weight_quant_mode, kv_cache_quant_mode=kv_cache_quant_mode, query_quant_mode=query_quant_mode, ckvkr_repo_mode=ckvkr_repo_mode, quant_scale_repo_mode=quant_scale_repo_mode, tile_size=tile_size, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
# 单算子调用
query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _, kv_cache_mla, kr_cache_mla = torch.ops.npu.npu_mla_prolog_v3_functional(token_x, w_dq_cast,
w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, quant_scale_ckr, smooth_scale_cq, actual_seq_len, k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, query_norm_flag, weight_quant_mode, kv_cache_quant_mode, query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale)
# 执行上述代码的输出类似如下
single op output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]],
device='npu:0', dtype=torch.bfloat16)
graph output: tensor([[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140],
[ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117],
[ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062],
...,
[ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039],
[ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164],
[ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]], device='npu:0', dtype=torch.bfloat16)
"""
)
_add_torch_npu_docstr(
"npu_all_gather_base_mm",
"""
接口原型:
torch_npu.npu_all_gather_base_mm(input, x2, hcom, world_size, *, bias=None, x1_scale=None, x2_scale=None, gather_index=0, bool gather_output=True, comm_turn=0, output_dtype=None, comm_mode=None) -> (Tensor, Tensor)
功能描述
TP切分场景下, 实现allgather和matmul的融合, 实现通信和计算流水并行.
使用该接口时, 请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本, 否则将会引发报错, 比如BUS ERROR等.
参数说明
input: Tensor类型, 数据类型支持float16、bfloat16、int8, 数据格式支持ND, 输入shape支持2维, 形如(m, k)、(k, n), 轴满足matmul算子入参要求, k轴相等, 且k轴取值范围为[256, 65535).
x2: Tensor类型, 数据类型、输入shape维度需要和input保持一致, 数据格式支持ND、NZ。NZ仅在comm_mode为aiv时支持。
hcom: String类型, 通信域handle名, 通过get_hccl_comm_name接口获取.
world_size: int类型, 通信域内的rank总数, 仅支持为2、4、8.
*: 代表其之前的变量是位置相关, 按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
bias: Tensor类型, 可选输入, 数据类型支持float16、bfloat16, 数据格式支持ND格式. 数据类型需要和input保持一致. bias仅支持一维, 且维度大小与output的第1维大小相同. 当前版本暂不支持bias输入为非0的场景.
x1_scale: 可选Tensor类型,mm左矩阵反量化参数。数据类型支持float32,数据格式支持ND格式。数据维度为(m, 1), 支持pertoken量化。
x2_scale: 可选Tensor类型。mm右矩阵反量化参数。数据类型支持float32、int64,数据格式支持ND格式。数据维度为(1, n), 支持perchannel量化。如需传入int64数据类型的,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的x2_scale。
gather_index: int类型, 表示gather操作对象, 0: 对input做gather, 1: 对x2做gather. 默认值0. 当前版本仅支持输入0.
gather_output: bool类型, 表示是否需要gather输出, 默认值true。
comm_turn: int类型, 表示rank间通信切分粒度, 默认值: 0, 表示默认的切分方式. 当前版本仅支持输入0.
output_dtype :可选dtype参数。表示第一个输出的数据类型。仅支持在量化场景且x1_scale和x2_scale均为float32时,可指定输出数据类型为bfloat16或float1,默认值为bfloat16。
comm_mode:可选str参数。表示通信模式,支持ai_cpu、aiv两种模式。ai_cpu模式仅支持基础场景。aiv模式支持基础场景和量化场景。
输出说明
两个输出, 均为Tensor类型: (Tensor, Tensor)
- Tensor:第一个输出Tensor是allgather+matmul的结果。
基础场景时数据类型和input保持一致。
量化场景下,x2_scale为int64数据类型时,输出数据类型为float16。x1_scale和x2_scale均为float32时, 输出数据类型由output_dtype指定,默认为bfloat16。
- Tensor:第二个输出Tensor是allgather的结果。
约束说明
该接口支持训练场景下使用.
该接口支持图模式.
Atlas A2 训练系列产品支持2、4、8卡, 支持hccs链路all mesh组网(每张卡和其它卡两两相连).
Atlas A3 训练系列产品支持2、4、8、16卡, 支持hccs链路double ring组网(多张卡按顺序组成一个圈, 每张卡只和左右卡相连).
input不支持输入转置后的tensor, x2转置后输入, 需要满足shape的第一维大小与x1的最后一维相同, 满足matmul的计算条件.
Atlas A2 训练系列产品: 一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce), 仅支持相同通信域.
支持的PyTorch版本
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例
单算子模式调用
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
tensor_allgather_shape = x1_shape
single_shape = [x1_shape[0] // world_size, x1_shape[1]]
input_ = torch.randn(single_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
output, gather_out = torch_npu.npu_all_gather_base_mm(input_, weight, hcomm_info, world_size)
if __name__ == "__main__":
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_all_gather_base_mm, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
图模式调用
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
class ALLGATHER_MM_GRAPH_Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, weight, hcomm_info, world_size, gather_output):
output, gather_output = torch_npu.npu_all_gather_base_mm(input, weight, hcomm_info, world_size,
gather_output=gather_output)
return output, gather_output
def define_model(model, graph_type):
import torchair
if graph_type == 1: # 传统入图模式, 静态shape+在线编译场景
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=False)
elif graph_type == 2: # ACLNN入图模式, 动态shape+二进制
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=True)
else:
print("Error type")
return model
def get_graph(input, weight, hcomm_info, world_size, gather_output):
model = ALLGATHER_MM_GRAPH_Model()
model = define_model(model, 2)
model_output = model(input, weight, hcomm_info, world_size, gather_output=gather_output)
output_npu = model_output[0]
gather_output_npu = model_output[1]
return output_npu, gather_output_npu
def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
single_shape = [x1_shape[0] // world_size, x1_shape[1]]
input = torch.randn(single_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
is_gather_out = True
output, gather_out = get_graph(input, weight, hcomm_info, world_size, is_gather_out)
print("output:", output)
if __name__ == "__main__":
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_all_gather_base_mm, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
"""
)
_add_torch_npu_docstr(
"npu_group_norm_silu",
"""
接口原型:
torch_npu.npu_group_norm_silu(Tensor input, Tensor weight, Tensor bias, int group, float eps) -> (Tensor, Tensor, Tensor)
功能描述
计算输入input的组归一化结果out、均值meanOut、标准差的倒数rstdOut、以及silu的输出.
参数说明
input: Tensor类型, 必选输入, 源数据张量, 维度需大于一维, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
weight: Tensor类型, 必选输入, 索引张量, 维度为1且元素数量需与输入input的第1维度保持相同, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
bias: Tensor类型, 必选输入, 更新数据张量, 维度为1元素数量需与输入input的第1维度保持相同, 数据格式支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
group: int类型, 必选输入, 表示将输入input的第1维度分为group组.
eps: float类型, 可选参数, 数值稳定性而加到分母上的值, 若保持精度, 则eps需大于0.
输出说明
out: Tensor类型, 数据类型和shape与input相同, 支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
meanOut: Tensor类型, 数据类型与input相同, shape为(N, group)支持ND, 支持非连续的Tensor.
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
rstdOut: Tensor类型, 数据类型与input相同, shape为(N, group).
Atlas 推理系列产品: 数据类型支持float16、float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float16、float、bfloat16.
约束说明
该接口支持图模式.
input、weight、bias、out、meanOut、rstdOut数据类型必须支持的范围之内.
out、meanOut、rstdOut的数据类型与input相同; weight、bias与input可以不同.
input第1维度能整除group.
out的shape与input相同.
meanOut与rstdOut的shape为(N, group), 其中N为input第0维度值.
weight与bias的数据类型必须保持一致, 且数据类型的精度不能低于input的数据类型.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例
单算子调用:
import torch
import numpy as np
import torch_npu
dtype = np.float32
shape_x = [24,320,48,48]
num_groups = 32
shape_c = [320]
eps = 0.00001
x_npu=torch.randn(shape_x,dtype=torch.float32).npu()
gamma_npu=torch.randn(shape_c,dtype=torch.float32).npu()
beta_npu=torch.randn(shape_c,dtype=torch.float32).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps)
x_npu=torch.randn(shape_x,dtype=torch.bfloat16).npu()
gamma_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu()
beta_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps)
x_npu=torch.randn(shape_x,dtype=torch.float16).npu()
gamma_npu=torch.randn(shape_c,dtype=torch.float16).npu()
beta_npu=torch.randn(shape_c,dtype=torch.float16).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps)
"""
)
_add_torch_npu_docstr(
"npu_mm_reduce_scatter_base",
"""
接口原型:
torch_npu.npu_mm_reduce_scatter_base(input, x2, hcom, world_size, *, reduce_op='sum', bias=None, x1_scale=None, x2_scale=None, comm_turn=0, output_dtype=None, comm_mode=None) -> Tensor
功能描述
TP切分场景下, 实现matmul和reduce_scatter的融合, 融合算子内部实现计算和通信流水并行. 支持perchanel, pertoken量化。
使用该接口时, 请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本, 否则将会引发报错, 比如BUS ERROR等.
参数说明
input: Tensor类型, 数据类型支持float16、bfloat16、int8, 数据格式支持ND, 输入shape支持2维.
x2: Tensor类型, 数据类型支持float16、bfloat16、int8, 数据格式支持ND、NZ。NZ仅在comm_mode为aiv时支持。数据类型需要和input保持一致, 输入shape维度和input保持一致.
hcom: String类型, 通信域handle名, 通过get_hccl_comm_name接口获取.
world_size: int类型, 通信域内的rank总数, 仅支持为2、4、8.
*: 代表其之前的变量是位置相关, 按照顺序输入, 必选; 之后的变量是键值对赋值的, 位置无关, 可选(不输入会使用默认值).
reduce_op: String类型, reduce操作类型, 当前仅支持'sum', 默认值: 'sum'.
bias: Tensor类型, 可选输入, 数据类型支持float16、bfloat16, 数据格式支持ND格式. 数据类型需要和input保持一致. bias仅支持一维, 且维度大小与output的第1维大小相同. 当前版本暂不支持bias输入为非0的场景.
x1_scale: Tensor类型,可选参数。mm左矩阵反量化参数。数据类型支持float32,数据格式支持$ND$格式。数据维度为(m, 1), 支持pertoken量化。
x2_scale: Tensor类型,可选参数。mm左矩阵反量化参数。数据类型支持float32、int64,数据格式支持$ND$格式。数据维度为(1, n), 支持perchannel量化。如需传入int64数据类型的,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的x2_scale。
comm_turn:int类型, 可选参数。表示rank间通信切分粒度,默认值为0,表示默认的切分方式。当前版本仅支持输入0。
output_dtype: ScalarType, 可选参数。表示输出数据类型。仅支持在量化场景且x1_scale和x2_scale均为float32时,可指定输出数据类型为bfloat16或float16,默认值为bfloat16。
comm_mode:str类型,可选参数。表示通信模式,支持ai_cpu、aiv两种模式。ai_cpu模式仅支持基础场景。aiv模式支持基础场景和量化场景。
输出说明
shape维度和input保持一致。
基础场景时数据类型和input保持一致。
量化场景下,x2_scale为int64数据类型时,输出数据类型为float16。x1_scale和x2_scale均为float32时, 输出数据类型由output_dtype指定,默认为torch.bfloat16。
约束说明
comm_mode为ai_cpu时:
该接口仅在训练场景下使用.
该接口支持图模式.
输入input、x2必须是2维, 分别为(m, k)、(k, n), 轴满足matmul算子入参要求, k轴相等, 且k轴取值范围为[256, 65535), m轴约束如下:
m轴需要整除world_size.
Atlas A2 训练系列产品支持2、4、8卡, 支持hccs链路all mesh组网(每张卡和其它卡两两相连).
Atlas A3 训练系列产品支持2、4、8、16卡, 支持hccs链路double ring组网(多张卡按顺序组成一个圈, 每张卡只和左右卡相连).
input不支持输入转置后的tensor, x2转置后输入, 需要满足shape的第一维大小与input的最后一维相同, 满足matmul的计算条件.
Atlas A2 训练系列产品: 一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce), 仅支持相同通信域.
comm_mode为aiv时:
- 支持Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
- 支持Atlas A3 训练系列产品/Atlas A3 推理系列产品
支持的PyTorch版本
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例
单算子模式调用
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
input_ = torch.randn(x1_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
output = torch_npu.npu_mm_reduce_scatter_base(input_, weight, hcomm_info, world_size)
if __name__ == "__main__":
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
图模式调用
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
class MM_REDUCESCATTER_GRAPH_Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, weight, hcomm_info, world_size, reduce_op):
output = torch_npu.npu_mm_reduce_scatter_base(input, weight, hcomm_info, world_size,
reduce_op=reduce_op)
return output
def define_model(model, graph_type):
import torchair
if graph_type == 1: # 传统入图模式, 静态shape+在线编译场景
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=False)
elif graph_type == 2: # ACLNN入图模式, 动态shape+二进制
npu_backend = torchair.get_npu_backend(compiler_config=None)
model = torch.compile(model, backend=npu_backend, dynamic=True)
else:
print("Error type")
return model
def get_graph(input, weight, hcomm_info, world_size):
model = MM_REDUCESCATTER_GRAPH_Model()
model = define_model(model, 2)
model_output = model(input, weight, hcomm_info, world_size, reduce_op="sum")
return model_output
def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
input = torch.randn(x1_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
output = get_graph(input, weight, hcomm_info, world_size)
print("output:", output)
if __name__ == "__main__":
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
"""
)
_add_torch_npu_docstr(
"npu_moe_compute_expert_tokens",
"""
接口原型:
torch_npu.npu_moe_compute_expert_tokens(Tensor sorted_expert_for_source_row, int num_expert) -> Tensor
功能描述
算子功能: MoE(Mixture of Experts, 混合专家模型)计算中, 通过二分查找的方式查找每个专家处理的最后一行的位置.
计算公式:
expertTokens_{i}=BinaerSearch(sortedExpertForSourceRow,numExpert)
参数说明
sorted_expert_for_source_row: Tensor类型, 必选参数, 经过专家处理过的结果, 要求是一个1D的Tensor, 数据类型支持int32, 数据格式要求为ND. shape大小需要小于2147483647.
num_expert: int类型, 必选参数, 总专家数.
输出说明
expertTokens: Tensor类型, 公式中的输出, 要求的是一个1D的Tensor, 数据类型与sorted_expert_for_source_row保持一致.
约束说明
该接口支持推理场景下使用.
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 2.0
PyTorch 1.11.0
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例
单算子模式调用
import torch
import torch_npu
sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32)
num_experts = 5
output = torch_npu.npu_moe_compute_expert_tokens(sorted_experts.npu(), num_experts)
图模式调用
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class GMMModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, sorted_experts, num_experts):
return torch_npu.npu_moe_compute_expert_tokens(sorted_experts, num_experts)
def main():
sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32)
num_experts = 5
model = GMMModel().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
custom_output = model(sorted_experts, num_experts)
if __name__ == '__main__':
main()
"""
)
_add_torch_npu_docstr(
"npu_moe_finalize_routing",
"""
接口原型:
torch_npu.npu_moe_finalize_routing(Tensor expanded_permuted_rows, Tensor? skip1, Tensor? skip2, Tensor? bias, Tensor? scales, Tensor expanded_src_to_dst_row, Tensor? export_for_source_row, int? drop_pad_mode=0) -> Tensor
功能描述
算子功能: MoE计算中, 最后处理合并MoE FFN的输出结果.
计算公式:
expertid=exportForSourceRow[i,k]
out(i,j)=skip1_{i,j}+skip2Optional_{i,j}+\sum_{k=0}^{K}(scales_{i,k}*(expandPermutedRows_{expandedSrcToDstRow_{i+k*num_rows},j}+bias_{expertid,j}))
参数说明
expanded_permuted_rows: Tensor类型, 必选参数, 经过专家处理过的结果, 要求是一个2D的Tensor, 数据类型支持float16、bfloat16、float32, 数据格式要求为ND. shape支持(NUM_ROWS * K, H), NUM_ROWS为行数, K为从总的专家E中选出K个专家, H为列数.
skip1: Tensor类型, 可选参数, 求和的输入参数1, 要求是一个2D的Tensor, 数据类型要求与expanded_permuted_rows一致 , shape要求与输出out的shape一致.
skip2: Tensor类型, 可选参数, 求和的输入参数2, 要求是一个2D的Tensor, 数据类型要求与expanded_permuted_rows一致 , shape要求与输出out的shape一致. skip2参数为None时, skip1参数必须也为None.
bias: Tensor类型, 可选参数, 专家的偏差, 要求是一个2D的Tensor, 数据类型要求与expanded_permuted_rows一致. shape支持(E, H), E为总的专家个数, H为列数.
scales: Tensor类型, 可选参数, 专家的权重, 要求是一个2D的Tensor, 数据类型要求与expanded_permuted_rows一致, shape支持(NUM_ROWS, K).
expanded_src_to_dst_row: Tensor类型, 必选参数, 保存每个专家处理结果的索引, 要求是一个1D的Tensor, 数据类型支持int32. shape支持(NUM_ROWS * K), NUM_ROWS为行数, K为从总的专家E中选出K个专家, drop_pad_mode参数为0时, Tensor中的值取值范围是[0, NUM_ROWS * K-1].
export_for_source_row: Tensor类型, 可选参数, 每行处理的专家号, 要求是一个2D的Tensor, 数据类型支持int32. shape支持(NUM_ROWS, K), NUM_ROWS为行数, K为从总的专家E中选出K个专家.
drop_pad_mode: int类型, 可选参数, 表示是否支持丢弃模式, 取值范围为0, 默认值为0.
输出说明
out: Tensor类型, 最后处理合并MoE FFN的输出结果.
约束说明
该接口支持推理场景下使用.
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例
单算子模式调用
import torch
import torch_npu
expert_num = 16
token_len = 10
top_k = 4
num_rows = 50
device =torch.device('npu')
dtype = torch.float32
expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype)
skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype)
skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype)
bias = torch.randn((expert_num, token_len), device=device, dtype=dtype)
scales = torch.randn((num_rows, top_k), device=device, dtype=dtype)
expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32)
expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32)
drop_pad_mode = 0
output = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
图模式调用
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class GMMModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode):
return torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
def main():
expert_num = 16
token_len = 10
top_k = 4
num_rows = 50
device =torch.device('npu')
dtype = torch.float32
expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype)
skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype)
skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype)
bias = torch.randn((expert_num, token_len), device=device, dtype=dtype)
scales = torch.randn((num_rows, top_k), device=device, dtype=dtype)
expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32)
expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32)
drop_pad_mode = 0
model = GMMModel().npu()
model = torch.compile(model, backend=npu_backend, dynamic=False)
custom_output = model(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
if __name__ == '__main__':
main()
"""
)
_add_torch_npu_docstr(
"npu_moe_gating_top_k_softmax",
"""
接口原型:
torch_npu.npu_moe_gating_top_k_softmax(Tensor x, Tensor? finished=None, int k=1) -> (Tensor, Tensor, Tensor)
功能描述
MoE计算中, 对输入x做Softmax计算, 再做topk操作.
参数说明
x: Tensor类型, 必选输入, 表示待计算的输入要求是一个2D/3D的Tensor, 数据类型支持float16、bfloat16、float32, 数据格式要求为ND.
finished: Tensor类型, 可选输入, 表示输入中需要参与计算的行, 要求是一个1D/2D的Tensor, 数据类型支持bool, shape为x[:-1], 数据格式要求为ND.
k: Host侧的int类型, 表示topk的k值, 大小为0<k<=x的-1轴大小, k<=1024.
输出说明
y: Tensor类型, 对x做softmax后取的topk值, 要求是一个2D/3D的Tensor, 数据类型与x需要保持一致, 其非-1轴要求与x的对应轴大小一致, 其-1轴要求其大小同k值. 数据格式要求为ND.
expert_idx: Tensor类型, 对x做softmax后取topk值的索引, 即专家的序号. shape要求与y一致, 数据类型支持int32, 数据格式要求为ND.
row_idx: Tensor类型, 指示每个位置对应的原始行位置, shape要求与y一致, 数据类型支持int32, 数据格式要求为ND.
约束说明
该接口支持推理场景下使用.
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例
单算子模式调用
import torch
import torch_npu
x = torch.rand((3, 3), dtype=torch.float32).to("npu")
finished = torch.randint(2, size=(3,), dtype=torch.bool).to("npu")
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x, finished, k=2)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class MoeGatingTopkSoftmaxModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, finish, k):
res = torch_npu.npu_moe_gating_top_k_softmax(x, finish, k)
return res
x = torch.randn((2, 4, 6),device='npu',dtype=torch.float16).npu()
moe_gating_topk_softmax_model = MoeGatingTopkSoftmaxModel().npu()
moe_gating_topk_softmax_model = torch.compile(moe_gating_topk_softmax_model, backend=npu_backend, dynamic=True)
res = moe_gating_topk_softmax_model(x, None, 2)
print(res)
"""
)
_add_torch_npu_docstr(
"npu_moe_gating_top_k_softmax_v2",
"""
接口原型:
npu_moe_gating_top_k_softmax_v2(Tensor x, *, int k=1, Tensor? finished=None, int? renorm=0, bool? output_softmax=False) -> (Tensor, Tensor, Tensor)
功能描述
MoE计算中,当renorm参数设置为0时,对输入x做softmax操作,再做topk操作;当renorm参数设置为1时,对x做topk操作,后做softmax操作。
参数说明
x: Tensor类型, 必选输入, 表示待计算的输入要求是一个2D/3D的Tensor, 数据类型支持float16、bfloat16、float32, 数据格式要求为ND.
finished: Tensor类型, 可选输入, 表示输入中需要参与计算的行, 要求是一个1D/2D的Tensor, 数据类型支持bool, shape为x[:-1], 数据格式要求为ND.
k: Host侧的int类型, 表示topk的k值, 大小为0<k<=x的-1轴大小, k<=1024.
renorm: int类型,可选输入,表示先计算softmax还是先计算topk。
output_softmax:bool类型,可选输入,表示是否输出softmax的结果,取值true和false。true表示输出softmax的结果,false表示不输出。
输出说明
y: Tensor类型, 对x做softmax后取的topk值, 要求是一个2D/3D的Tensor, 数据类型与x需要保持一致, 其非-1轴要求与x的对应轴大小一致, 其-1轴要求其大小同k值. 数据格式要求为ND.
expert_idx: Tensor类型, 对x做softmax后取topk值的索引, 即专家的序号. shape要求与y一致, 数据类型支持int32, 数据格式要求为ND.
row_idx: Tensor类型, 指示每个位置对应的原始行位置, shape要求与y一致, 数据类型支持int32, 数据格式要求为ND.
约束说明
该接口支持推理场景下使用.
该接口支持图模式(PyTorch 2.1版本).
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例
单算子模式调用
import torch
import torch_npu
x = torch.rand((3, 3), dtype=torch.float32).to("npu")
finished = torch.randint(2, size=(3,), dtype=torch.bool).to("npu")
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
x, finished, k=2)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class MoeGatingTopkSoftmaxModelV2(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, finished, k, renorm, output_softmax):
res = torch_npu.npu_moe_gating_top_k_softmax_v2(x = x, finished = finished, k = k, renorm = renorm, output_softmax = output_softmax)
return res
x = torch.randn((2, 4, 6),device='npu',dtype=torch.float16).npu()
moe_gating_topk_softmax_model_v2 = MoeGatingTopkSoftmaxV2Model().npu()
moe_gating_topk_softmax_model_v2 = torch.compile(moe_gating_topk_softmax_model_v2, backend=npu_backend, dynamic=True)
res = moe_gating_topk_softmax_model_v2(x = x, finished = None, k = 2, renorm = 0, output_softmax = output_softmax = True)
"""
)
_add_torch_npu_docstr(
"npu_moe_gating_top_k",
"""
函数原型
npu_moe_gating_top_k(x, k, bias=None, k_group=1, group_count=1, group_select_mode=0, renorm=0, norm_type=0, out_flag=False, routed_scaling_factor=1.0, eps=1e-20) -> (Tensor, Tensor, Tensor)
## 参数说明
- **x**(`Tensor`):必选参数,表示待计算的输入。要求是一个2D的Tensor,数据类型支持`float16`、`bfloat16`、`float32`,数据格式要求为ND。支持非连续Tensor。最后一维的大小(即专家数)要求不大于`2048`。
- **k**(`int`):必选参数,表示每个token最终筛选得到的专家个数,数据类型为`int64`。要求`1 <= k <= x_shape[-1] / group_count * k_group`。
- <strong>*</strong>:代表其之前的变量是位置相关,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
- **bias**(`Tensor`):可选参数,表示与输入`x`进行计算的bias值。要求是1D的Tensor,要求shape值与`x`的最后一维相等。数据类型支持`float16`、`bfloat16`、`float32`,数据类型需要与`x`保持一致,数据格式要求为ND。支持非连续`Tensor`。
- **k_group**(`int`):可选参数,表示每个token组筛选过程中,选出的专家组个数,数据类型为`int64`,默认值为`1`。要求`1 <= k_group <= group_count`,并且`k_group * x_shape[-1] / group_count`的值要大于等于`k`。
- **group_count**(`int`):可选参数,表示将全部专家划分的组数,数据类型为`int64`,默认值为`1`。要求`group_count > 0,x_shape[-1]`能够被`group_count`整除且整除后的结果大于`2`,并且整除的结果按照`32`个数对齐后乘`group_count`的结果不大于`2048`。
- **group_select_mode**(`int`):可选参数,表示一个专家组的总得分计算方式。默认值为`0`,`0`表示组内取最大值,作为专家组得分;`1`表示取组内Top2的专家进行得分累加,作为专家组得分。
- **renorm**(`int`):可选参数,表示renorm标记,默认值为`0`,表示先进行norm再进行topk计算。当前仅支持`0`。
- **norm_type**(`int`):可选参数,表示norm函数类型,`1`表示使用Sigmoid函数,`0`表示Softmax函数。默认值为`0`。
- **out_flag**(`bool`):可选参数,是否输出norm函数中间结果。默认值为`False`。
- **routed_scaling_factor**(`float`):可选参数,表示计算`yOut`使用的`routed_scaling_factor`系数,默认值为`1.0`。
- **eps**(`float`):可选参数,表示计算`yOut`使用的`eps`系数,默认值为`1e-20`。
返回值说明
y_out(Tensor):表示对x做norm操作和分组排序topk后计算的结果。是一个2D的Tensor,数据类型与`x`一致。
expert_idx_out(Tensor):表示对`x`做norm操作和分组排序topk后的索引,即专家的序号。shape与y_out一致,数据类型为`int32`。
norm_out(Tensor):表示norm计算的输出结果。shape与`x`一致,数据类型为`float32`。
约束说明
该接口支持推理场景下使用。
该接口支持图模式。
调用示例
单算子模式调用
import torch
import torch_npu
import numpy
k = 1
k_group = 4
group_count = 8
group_select_mode = 1
renorm = 0
norm_type = 1
out_flag = False
routed_scaling_factor = 1.0
eps = 1e-20
# 生成随机数据, 并发送到npu
x = numpy.random.uniform(0, 2, (16, 256)).astype(numpy.float16)
bias = numpy.random.uniform(0, 2, (256,)).astype(numpy.float16)
x_tensor = torch.tensor(x).npu()
bias_tensor = torch.tensor(bias).npu()
# 调用MoeGatingTopK算子
y_npu, expert_idx_npu, out_npu = torch_npu.npu_moe_gating_top_k(x_tensor, k, bias=bias_tensor, k_group=k_group, group_count=group_count, group_select_mode=group_select_mode, renorm=renorm, norm_type=norm_type, out_flag=out_flag, routed_scaling_factor=routed_scaling_factor, eps=eps)
图模式调用
# 入图方式
import torch
import torch_npu
import torchair
import numpy
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x_tensor, bias_tensor):
return torch_npu.npu_moe_gating_top_k(x_tensor, k, bias=bias_tensor, k_group=k_group, group_count=group_count, group_select_mode=group_select_mode, renorm=renorm, norm_type=norm_type, out_flag=out_flag, routed_scaling_factor=routed_scaling_factor, eps=eps)
# 实例化模型model
model = Model().npu()
# 从TorchAir获取NPU提供的默认backend
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
# 使用TorchAir的backend去调用compile接口编译模型
model = torch.compile(model, backend=npu_backend)
k = 1
k_group = 4
group_count = 8
group_select_mode = 1
renorm = 0
norm_type = 1
out_flag = False
routed_scaling_factor = 1.0
eps = 1e-20
# 生成随机数据, 并发送到npu
x = numpy.random.uniform(0, 2, (16, 256)).astype(numpy.float16)
bias = numpy.random.uniform(0, 2, (256,)).astype(numpy.float16)
x_tensor = torch.tensor(x).npu()
bias_tensor = torch.tensor(bias).npu()
# 调用MoeGatingTopK算子
y_npu, expert_idx_npu, out_npu = model(x_tensor, bias_tensor)
"""
)
_add_torch_npu_docstr(
"npu_moe_gating_top_k_backward",
"""
函数原型
npu_moe_gating_top_k_backward(x_norm, grad_y, expertIdx, *, renorm=0, norm_type=0, routed_scaling_factor=1.0, eps=1e-20) -> Tensor
参数说明
- **x_norm**(`Tensor`):必选参数,前向算子 sigmoid/softmax 归一化后的得分,对应前向输出 `normOut`。要求是一个 2D 的 Tensor,维度为 [M, N],数据类型支持 `float32`,数据格式要求为 ND。支持非连续 Tensor。最后一维的大小(即专家数 N)要求大于等于 `2`,并小于等于 `2048`。
- **grad_y**(`Tensor`):必选参数,前向算子输出 `yOut` 的上游梯度。要求是一个 2D 的 Tensor,维度为 [M, K],数据类型支持 `float16`、`bfloat16`、`float32`,数据格式要求为 ND。不支持非连续 Tensor。K的范围要求大于等于 `1`,并小于等于 N。
- **expertIdx**(`Tensor`):必选参数,前向算子输出 `expertIdxOut`,对应 top-K 专家的索引。shape 要求与 `grad_y` 一致,数据类型支持 `int32`,数据格式要求为 ND。不支持非连续 Tensor。
- **renorm**(`int`):可选参数,前向算子在 softmax 模式下的 renorm 标记。`0` 表示不做renorm,`1` 表示需要做 renorm。预留参数,当前仅支持 sigmoid 模式。
- **norm_type**(`int`):可选参数,norm 函数类型,`1` 表示使用 Sigmoid 函数,`0` 表示 Softmax 函数。默认值为 `0`。当前仅支持 `1`(sigmoid 模式)。
- **routed_scaling_factor**(`float`):可选参数,前向计算 `yOut` 使用的缩放系数,默认值为 `1.0`。
- **eps**(`float`):可选参数,前向归一化使用的防除零常数,默认值为 `1e-20`。
返回值说明
- **grad_x**(`Tensor`):前向算子输入 `x` 的梯度。要求是一个 2D 的 Tensor,维度为 [M, N],数据类型与 `grad_y` 需要保持一致,shape 与 `x_norm` 需要一致,数据格式要求为 ND。不支持非连续 Tensor。
约束说明
- 该接口支持训练场景下使用。
调用示例
单算子模式调用
import torch
import torch_npu
# 示例参数:M=2048, N=192, K=10,sigmoid 模式,PanGuV3 典型 shape
M, N, K = 2048, 192, 10
norm_type = 1
routed_scaling_factor = 1.0
eps = 1e-20
# 构造输入(模拟前向算子的保存值)
x_norm = torch.rand(M, N, dtype=torch.float32).npu() # 前向 normOut
grad_y = torch.randn(M, K, dtype=torch.bfloat16).npu() # 上游梯度
expert_idx = torch.randint(0, N, (M, K), dtype=torch.int32).npu() # 前向 expertIdxOut
# 调用反向算子
grad_x = torch_npu.npu_moe_gating_top_k_backward(
x_norm, grad_y, expert_idx,
norm_type=norm_type,
routed_scaling_factor=routed_scaling_factor,
eps=eps)
print(f"grad_x shape: {grad_x.shape}") # [2048, 192]
print(f"grad_x dtype: {grad_x.dtype}") # bfloat16(与 grad_y 一致)
通过正向算子自动反向调用
import torch
import torch_npu
M, N, K = 2048, 192, 10
# 前向传播
x = torch.randn(M, N, dtype=torch.bfloat16, requires_grad=False).npu().requires_grad_(True)
y, expert_idx, x_norm = torch_npu.npu_moe_gating_top_k(
x, K, norm_type=1, routed_scaling_factor=1.0, eps=1e-20, out_flag=True)
# 模拟上游梯度
grad_y = torch.randn_like(y)
# 反向传播
y.backward(grad_y)
grad_x = x.grad.detach().clone()
print(f"grad_x shape: {grad_x.shape}") # [2048, 192]
"""
)
_add_torch_npu_docstr(
"npu_moe_init_routing",
"""
接口原型:
torch_npu.npu_moe_init_routing(Tensor x, Tensor row_idx, Tensor expert_idx, int active_num) -> (Tensor, Tensor, Tensor)
功能描述
算子功能: MoE的routing计算, 根据torch_npu.npu_moe_gating_top_k_softmax的计算结果做routing处理.
计算公式为:
expanded_expert_idx, sorted_rowIdx=keyValueSort(expert_idx,row_idx)
expanded_row_idx[sorted_row_idx[i]]=i
expanded_x[i]=x[sorted_row_idx[i]%num_rows]
参数说明
x: Tensor类型, 必选输入, MOE的输入即token特征输入, 要求为一个2D的Tensor, shape为 (NUM_ROWS, H). 数据类型支持float16、bfloat16、float32, 数据格式要求为ND. shape大小需要小于2^24.
row_idx: Tensor类型, 必选输入, 指示每个位置对应的原始行位置, shape要求与expert_idx一致. 数据类型支持int32, 数据格式要求为ND.
expert_idx: Tensor类型, 必选输入, torch_npu.npu_moe_gating_top_k_softmax的输出每一行特征对应的K个处理专家, 要求是一个2D的shape (NUM_ROWS, K), 数据类型支持int32, 数据格式要求为ND.
active_num: int类型, 表示总的最大处理row数, 输出expanded_x只有这么多行是有效的.
输出说明
expanded_x: Tensor类型, 根据expert_idx进行扩展过的特征, 要求是一个2D的Tensor, shape (min(NUM_ROWS, activeNum) * k, H). 数据类型同x, 数据格式要求为ND.
expanded_row_idx: Tensor类型, expanded_x和x的映射关系, 要求是一个1D的Tensor, Shape为(NUM_ROWS*K, ), 数据类型支持int32, 数据格式要求为ND.
expanded_expert_idx: Tensor类型, 输出expert_idx排序后的结果.
约束说明
该接口支持推理场景下使用.
该接口支持图模式.
支持的PyTorch版本
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例
单算子模式调用
import torch
import torch_npu
x = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2],[0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).to("npu")
row_idx = torch.tensor([[0, 3], [1, 4], [2, 5]], dtype=torch.int32).to("npu")
expert_idx = torch.tensor([[1, 2], [0, 1], [0, 2]], dtype=torch.int32).to("npu")
active_num = 3
expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(x, row_idx, expert_idx, active_num)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device=torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class MoeInitRoutingModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, row_idx, expert_idx, active_num):
expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(x, row_idx, expert_idx, active_num=active_num)
return expanded_x, expanded_row_idx, expanded_expert_idx
x = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2],[0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).to("npu")
row_idx = torch.tensor([[0, 3], [1, 4], [2, 5]], dtype=torch.int32).to("npu")
expert_idx = torch.tensor([[1, 2], [0, 1], [0, 2]], dtype=torch.int32).to("npu")
active_num = 3
moe_init_routing_model = MoeInitRoutingModel().npu()
moe_init_routing_model = torch.compile(moe_init_routing_model, backend=npu_backend, dynamic=True)
expanded_x, expanded_row_idx, expanded_expert_idx = moe_init_routing_model(x, row_idx, expert_idx, active_num=active_num)
print(expanded_x)
print(expanded_row_idx)
print(expanded_expert_idx)
"""
)
_add_torch_npu_docstr(
"npu_moe_init_routing_v2",
"""
算子功能:MoE(Mixture of Expert)的routing计算,根据3.28 torch_npu.npu_moe_gating_top_k_softmax的计算结果做routing处理,支持不量化和动态量化模式。
接口原型
torch_npu.npu_moe_init_routing_v2(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, bool expert_tokens_num_flag=False, int quant_mode=0, int[2] active_expert_range=[], int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)
参数说明
x:Tensor类型,表示MoE的输入即token特征输入,要求为2D的Tensor,shape为(NUM_ROWS, H),H代表每个Token的长度。数据类型支持float16、bfloat16、float32、int8,数据格式要求为ND。
expert_idx:Tensor类型,表示torch_npu.npu_moe_gating_top_k_softmax输出每一行特征对应的K个处理专家,要求是2D的Tensor,shape为(NUM_ROWS, K),且专家id不能超过专家数。数据类型支持int32,数据格式要求为ND。
scale:Tensor类型,可选参数,用于计算量化结果的参数。数据类型支持float32,数据格式要求为ND。如果不输入表示计算时不使用scale,且输出expanded_scale中的值未定义。
非量化场景下,如果输入则要求为1D的Tensor,shape为(NUM_ROWS,)。
动态quant场景下,如果输入则要求为2D的Tensor,shape为(expert_end-expert_start, H)。
offset:Tensor类型,可选参数,用于计算量化结果的偏移值。数据类型支持float32,数据格式要求为ND。
在非量化场景下不输入。
动态quant场景下不输入。
active_num:int类型,表示总的最大处理row数,输出expanded_x只有这么多行是有效的,当前入参校验需大于等于0。当前未使用,校验需等于NUM_ROWS*K。
expert_capacity:int类型,表示每个专家能够处理的tokens数,取值范围大于等于0。当前未使用,仅校验非空。
expert_num:int类型,表示专家数,要求大于0。expert_tokens_num_type为key_value模式时,取值范围为(0, 5120];其他模式取值范围为(0, 10240]。
drop_pad_mode:int类型,表示是否为drop_pad场景,取值为0和1。0表示dropless场景,该场景下不校验expert_capacity。1表示drop_pad场景。当前仅支持0。
expert_tokens_num_type:int类型,取值为0、1和2。0表示cumsum模式;1表示count模式,即输出的值为各个专家处理的token数量的累计值;2表示key_value模式,即输出的值为专家和对应专家处理token数量的累计值 。当前仅支持1和2。
expert_tokens_num_flag:bool类型,表示是否输出expert_token_cumsum_or_count,默认False表示不输出。当前仅支持True。
quant_mode:int类型,表示量化模式,支持取值为0、1、-1。0表示静态量化,-1表示不量化场景;1表示动态quant场景。当前仅支持-1和1。x数据类型为int8时仅支持-1,不可再量化。
active_expert_range:int类型长度为2的数组,表示活跃expert的范围。数组内值为[expert_start, expert_end],表示活跃的expert范围在[expert_start, expert_end)区间内,左闭右开。要求数组内的值大于等于0,并且expert_end不大于expert_num。
row_idx_type:int类型,表示输出expanded_row_idx使用的索引类型,支持取值0和1,默认值0。0表示gather类型的索引;1表示scatter类型的索引。
输出说明
expanded_x:Tensor类型,根据expert_idx进行扩展过的特征,要求是2D的Tensor,shape为(NUM_ROWS*K, H)。非量化场景下数据类型同x;量化场景下数据类型支持int8。数据格式要求为ND。前available_idx_num*H个元素为有效数据,其余由row_idx_type决定。其中available_idx_num为expert_idx中在active_expert_range范围的元素的个数。量化场景下,当x的数据类型为int8时,输出值未定义。
expanded_row_idx:Tensor类型,expanded_x和x的映射关系, 要求是1D的Tensor,shape为(NUM_ROWS*K,),数据类型支持int32,数据格式要求为ND。当row_idx_type为0时,有效元素与无效元素共存,其中无效元素由-1填充;当row_idx_type为1时,前available_idx_num个元素有效,其余无效元素值未定义。其中available_idx_num为expert_idx中在active_expert_range范围的元素的个数。量化场景下,当x的数据类型为int8时,输出值未定义。
expert_token_cumsum_or_count:Tensor类型,数据类型支持int64,数据格式要求为ND。
在expert_tokens_num_type为1的场景下,要求是1D的Tensor,表示active_expert_range范围内每个expert对应的处理token的总数,shape为(expert_end-expert_start,)。
在expert_tokens_num_type为2的场景下,要求是2D的Tensor,shape为(expert_num, 2),表示active_expert_range范围内的每个expert的expert_idx及其对应处理的token总数。有效元素对是指expert_idx在active_expert_range范围内,且处理的token数不为0的元素对,这些有效元素对按原顺序存放在Tensor头部。如果有效元素对的数量少于expert_num,其后会跟一对元素对(0,0)以表示有效元素对的结束。
expanded_scale:Tensor类型,数据类型支持float32,数据格式要求为ND。令available_idx_num为expert_idx中在active_expert_range范围的元素的个数。
非量化场景下,即quant_mode为-1,shape为(NUM_ROWS*K,)。当scale未输入时,输出值未定义;当scale输入时,输出表示一个1D的Tensor,前available_idx_num个元素为有效数据,其余为无效数据。
动态量化场景下,即quant_mode为1,输出量化计算过程中scale的中间值,shape为(NUM_ROWS*K,)。输出表示一个1D的Tensor,前available_idx_num个元素为有效数据,其余为无效数据,若x的输入类型为int8,输出值未定义。
约束说明
该接口支持推理场景下使用。
该接口支持图模式。
不支持静态量化模式。
该接口在部分产品型号下,支持两种性能模板。进入两种性能模板需要分别额外满足以下条件,不满足条件则进入通用模板:
支持性能模板的产品型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
性能模板的约束条件:
进入低时延性能模板需要同时满足以下条件:
x、expert_idx、scale输入Shape要求分别为:(1, 7168)、(1, 8)、(256, 7168)
x数据类型要求:bfloat16
属性要求:active_expert_range=[0, 256]、 quant_mode=1、expert_tokens_num_type=2、expert_num=256
进入大batch性能模板需要同时满足以下条件:
NUM_ROWS范围为[384, 8192]
K=8
expert_num=256
expert_end-expert_start<=32
quant_mode=-1
row_idx_type=1
expert_tokens_num_type=1
支持的PyTorch版本
PyTorch 2.6
PyTorch 2.5
PyTorch 2.3
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例
单算子模式调用
import torch
import torch_npu
bs = 1
h = 613
k = 475
active_num = 475
expert_capacity = -1
expert_num = 226
drop_pad_mode = 0
expert_tokens_num_type = 1
expert_tokens_num_flag = True
quant_mode = -1
active_expert_range = [23, 35]
row_idx_type = 0
x = torch.randn((bs, h), dtype=torch.float32).npu()
expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()
scale = torch.randn((bs,), dtype=torch.float32).npu()
offset = None
expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
x, expert_idx, scale=scale, offset=offset,
active_num=active_num, expert_capacity=expert_capacity, expert_num=expert_num, drop_pad_mode=drop_pad_mode,
expert_tokens_num_type=expert_tokens_num_type, expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type)
图模式调用
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
class MoeInitRoutingV2Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, expert_idx, *, scale=None, offset=None, active_num=-1, expert_capacity=-1,
expert_num=-1, drop_pad_mode=0, expert_tokens_num_type=0, expert_tokens_num_flag=False,
quant_mode=0, active_expert_range=0, row_idx_type=0):
return torch.ops.npu.npu_moe_init_routing_v2(x, expert_idx, scale=scale, offset=offset,
active_num=active_num, expert_capacity=expert_capacity, expert_num=expert_num, drop_pad_mode=drop_pad_mode,
expert_tokens_num_type=expert_tokens_num_type, expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type)
def main():
bs = 1
h = 613
k = 475
active_num = 475
expert_capacity = -1
expert_num = 226
drop_pad_mode = 0
expert_tokens_num_type = 1
expert_tokens_num_flag = True
quant_mode = -1
active_expert_range = [23, 35]
row_idx_type = 0
x = torch.randn((bs, h), dtype=torch.float32).npu()
expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()
scale = torch.randn((bs,), dtype=torch.float32).npu()
offset = None
moe_init_routing_v2_model = MoeInitRoutingV2Model().npu()
moe_init_routing_v2_model = torch.compile(moe_init_routing_v2_model, backend=npu_backend, dynamic=False)
expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = moe_init_routing_v2_model(x,
expert_idx, scale=scale, offset=offset, active_num=active_num,
expert_capacity=expert_capacity, expert_num=expert_num, drop_pad_mode=drop_pad_mode,
expert_tokens_num_type=expert_tokens_num_type, expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type)
if __name__ == '__main__':
main()
"""
)
_add_torch_npu_docstr(
"npu_moe_token_permute_with_routing_map",
"""
接口原型:
torch_npu.npu_moe_token_permute_with_routing_map(Tensor tokens, Tensor routing_map, *, Tensor? probs=None, int? num_out_tokens=None, bool drop_and_pad=False) -> (Tensor, Tensor, Tensor)
算子功能:
MoE的permute计算,将token和expert的标签作为routingMap传入,根据routing_map将tokens和可选probs广播后排序
接口说明:
输入:
tokens(Tensor,计算输入):输入toke,要求为一个维度为2D的Tensor,shape为 (tokens_num, hidden_size),数据类型支持BFLOAT16,FLOAT16,FLOAT,数据格式要求为ND。支持非连续的Tensor。
routing_map(Tensor ,计算输入):表token到expert的映射关系,要求shape为一个2D的(tokens_num,experts_num),数据类型支持INT8、BOOL。当数据类型为INT8,取值支持0、1,当数据类型为bool,取值支持true、false,数据格式要求为ND。支持非连续的Tensor。非droppad模式要求每行中包含topK个true 或 1。
probs(Tensor,计算输入):可选输入probs,关键字参数,默认值为None,要求元素个数与routing_map相同,当probs为None时,可选输出permute_probs_out_optional为空,数据类型同tokens。数据格式要求为ND。支持非连续的Tensor。
num_out_tokens(int64_t,计算输入):可选输入,默认值为token_num, 用于计算topK 和capacity 的有效输出token数。
drop_and_pad(bool,计算输入):可选输入,默认值为False,表示是否开启drop_and_pad模式。
输出:
permuted_tokens_out(Tensor,第一个输出):根据indices进行扩展并排序筛选过的tokens,要求是一个2D的Tensor,shape为(outToken , hidden_size)。数据类型同tokens,数据格式要求为ND。支持非连续的Tensor。
permute_probs_out_optional(Tensor,第二个输出):根据indices进行排序并筛选过的probs,Shape为(outToken),数据类型同probsOptional,数据格式要求为ND。支持非连续的Tensor。
sorted_indices_out(Tensor,第三个输出):permute_tokens和tokens的映射关系, 要求是一个1D的Tensor,Shape为(outToken),数据类型支持INT32,数据格式要求为ND。支持非连续的Tensor。
输入约束:
1. 由于float无损转int限制,tokens_num和experts_num要求小于16777215。
2. 由于UB限制,routing_map 中 每行为1或true的个数固定且小于512,num_out_tokens/num_tokens小于512。
3. drop_and_pad为False场景,num_out_tokens / num_tokens需和routing_map中每行1或True的个数一致。
调用示例:
import torch,torch_npu
x = torch.randn((3, 4), dtype=torch.float).npu()
rounting_map = torch.tensor(
[[True, True], [True, True], [True, True]], dtype=torch.bool).npu()
numtoken = 6
pad_mode = False
permuted_tokens_out, _, sorted_indices_out = torch_npu.npu_moe_token_permute_with_routing_map(x, rounting_map, num_out_tokens=numtoken, drop_and_pad=pad_mode)
"""
)
_add_torch_npu_docstr(
"npu_prefetch",
"""
接口原型:
torch_npu.npu_prefetch(Tensor input, Tensor? dependency, int max_size, int offset=0) -> ()
功能描述
提供网络weight预取功能,用于在计算执行前将指定的权重数据预先加载到L2 Cache中,减少算子访问这些权重时的访存等待时间。例如,在MatMul等算子之前进行预取,算子执行时可直接从低时延的L2 Cache中读取权重,进而提升算子数据访问与计算效率。实际性能收益取决于用户采用的并行方式和配置。
参数说明
input: Tensor类型, 表示需要预取的权重, 不做数据处理, 与数据类型和数据格式无关; 输入不能含有为None.
dependency: Tensor类型, 表示开始预取的节点, 单算子下不生效可为None, 图模式下不可为None; 不做数据处理, 与数据类型和数据格式无关.
max_size: int类型, 取值需大于0, 表示权重预取的最大size, 超过预取权重的size时, 会设置为权重的最大size. 数据类型为int32、int64.
offset: int类型, 默认值0, 取值大于等于0, 表示权重预取内存地址偏移, 不允许超过权重地址范围. 数据类型为int32、int64.
输出说明
无
约束说明
该接口支持图模式.
支持的PyTorch版本
Pytorch 2.5
PyTorch 2.4
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
调用示例:
单算子多流并发调用
import torch
import torch_npu
s_cmo = torch.npu.Stream()
x = torch.randn(10000, 10000, dtype=torch.float32).npu()
y = torch.randn(10000, 1, dtype=torch.float32).npu()
add = torch.add(x, 1)
with torch.npu.stream(s_cmo):
torch_npu.npu_prefetch(y, None, 10000000)
abs = torch.abs(add)
mul = torch.matmul(abs, abs)
out = torch.matmul(mul, y)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = 'pbtxt'
npu_backend = tng.get_npu_backend(compiler_config=config)
x = torch.randn(10000, 10000, dtype=torch.float32).npu()
y = torch.randn(10000, 1, dtype=torch.float32).npu()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self,x,y):
add = torch.add(x, 1)
torch_npu.npu_prefetch(y, add, 10000000)
abs = torch.abs(add)
mul = torch.matmul(abs, abs)
out = torch.matmul(mul, y)
return out
npu_model = Model().npu()
model = torch.compile(npu_model, backend=npu_backend, dynamic=False, fullgraph=True)
output = model(x,y)
"""
)
_add_torch_npu_docstr(
"npu_quantize",
"""
接口原型:
torch_npu.npu_quantize(Tensor input, Tensor scales, Tensor? zero_points, int dtype, int axis=1, bool div_mode=True) -> Tensor
功能描述
算子功能: 对输入的张量进行量化处理.
计算公式:
如果div_mode为True: result=(input/scales)+zero_points
如果div_mode为False: result=(input*scales)+zero_points
参数说明
input: Tensor类型, 需要进行量化的源数据张量, 数据格式支持ND、NZ, 支持非连续的Tensor. div_mode为False且dtype为torch.quint4x2时, 最后一维需要能被8整除.
Atlas 推理系列产品: 数据类型支持float、float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float、float16、bfloat16.
scales: Tensor类型, 对input进行scales的张量, 必选输入:
div_mode为True时
Atlas 推理系列产品: 数据类型支持float.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float、bfloat16.
div_mode为False时, 数据格式支持ND, 支持非连续的Tensor. 支持1维或多维(1维时, 对应轴的大小需要与input中第axis维相等或等于1; 多维时, scales的shape需要与input的shape维度相等, 除axis指定的维度, 其他维度为1, axis指定的维度必须和input对应的维度相等或等于1).
Atlas 推理系列产品: 数据类型支持float、float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float、float16、bfloat16.
zero_points: Tensor类型, 对input进行offset的张量, 可选输入.
div_mode为True时
Atlas 推理系列产品: 数据类型支持int8、uint8、int32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、uint8、int32、bfloat16.
div_mode为False时, 数据格式支持ND, 支持非连续的Tensor. 支持1维或多维(1维时, 对应轴的大小需要与input中第axis维相等或等于1; 多维时, scales的shape需要与input维度相等, 除axis指定的维度, 其他维度为1, axis指定的维度必须和input对应的维度相等). zero_points的shape和dtype需要和scales一致.
Atlas 推理系列产品: 数据类型支持float、float16.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持float、float16、bfloat16.
dtype: int类型, 指定输出参数的类型.
div_mode为True时,
Atlas 推理系列产品: 类型支持torch.qint8、torch.quint8、torch.int32.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 类型支持torch.qint8、torch.quint8、torch.int32.
div_mode为False时, 类型支持torch.qint8、torch.quint4x2. 如果dtype为torch.quint4x2时, 输出tensor类型为int32, 由8个int4拼接.
axis: int类型, 量化的elemwise轴, 其他的轴做broadcast, 默认值为1.
div_mode为False时, axis取值范围是[-2, +∞)且指定的轴不能超过输入input的维度数. 如果axis=-2, 代表量化的elemwise轴是输入input的倒数第二根轴; 如果axis大于-2, 量化的elemwise轴是输入的最后一根轴.
div_mode: 布尔类型, 表示计算scales模式. 当div_mode为True时, 表示用除法计算scales; div_mode为False时, 表示用乘法计算scales, 默认值为True.
输出说明
y: Tensor类型, 公式中的输出, 输出大小与input一致. 数据类型由参数dtype指定, 如果参数dtype为torch.quint4x2, 输出的dtype是torch.int32, shape的最后一维是输入shape最后一维的1/8, shape其他维度和输入一致.
约束说明
该接口支持推理场景下使用.
该接口支持图模式.
input数据格式为NZ时, input输入shape支持3维, 形如(e, k, n), scales输入shape支持1维, zero_points输入为None, dtype为quint4x2.
div_mode为False时:
支持Atlas A2 训练系列产品/Atlas 800I A2 推理产品.
当dtype为torch.quint4x2或者axis为-2时, 不支持Atlas 推理系列产品.
支持的PyTorch版本
PyTorch 2.4
PyTorch 2.3
PyTorch 2.1
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
x = torch.randn((2, 3, 12), dtype=torch.float).npu()
scale = torch.tensor(([3] * 12),dtype=torch.float).npu()
out = torch_npu.npu_quantize(x, scale, None, torch.qint8, -1, False)
print(out)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
x = torch.randn((2, 3, 12), dtype=torch.float16).npu()
scale = torch.tensor(([3] * 12),dtype=torch.float16).npu()
axis =1
div_mode = False
class Network(torch.nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, x, scale,zero_points, dst_type,div_mode):
return torch_npu.npu_quantize(x, scale, zero_points=zero_points, dtype=dst_type, div_mode=div_mode)
model = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
config.debug.graph_dump.type = 'pbtxt'
model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)
output_data = model(x, scale,None,dst_type=torch.qint8, div_mode=div_mode)
print(output_data)
"""
)
_add_torch_npu_docstr(
"npu_kronecker_quant",
"""
接口原型:
npu_kronecker_quant(Tensor x, Tensor kronecker_p1, Tensor kronecker_p2, float? clip_ratio=None, ScalarType? dst_dtype=None) -> (Tensor out, Tensor quant_scale)
功能描述
为矩阵x依次进行两次小矩阵乘法,然后针对矩阵乘的结果进行量化处理。
参数说明
x: Device侧的Tensor类型,表示输入;数据类型支持FLOAT16、BFLOAT16类型;shape为[K, M, N],其中N必须为8的整数倍。
kronecker_p1: Device侧的Tensor类型,表示输入;数据类型支持FLOAT16、BFLOAT16类型,数据类型与x一致;shape为[M, M],M与x第一维相同。
kronecker_p2: Device侧的Tensor类型,表示输入;数据类型支持FLOAT16、BFLOAT16类型,数据类型与x一致;shape为[N, N],N与x第二维相同。
clip_ratio: float类型,可选参数,数据范围为(0, 1],默认值为1。
dst_dtype:ScalarType类型,可选参数,输入值允许为torch.int32,默认值为torch.int32。
dst_type_max: float类型,可选参数,数据范围为0、[6, 12],默认值为0。
输出说明
out:Device侧的Tensor类型,表示量化输出;数据类型支持INT32;shape为[K, M, N/8],第零维和第一维与x一致,第二维是x的1/8。
quant_scale: Device侧的Tensor类型,表示量化缩放系数;数据类型支持FLOAT32;shape为[K],K与x第零维相同。
约束说明
输入数据类型仅支持float16和bfloat16,x、kronecker_p1和kronecker_p2数据类型要保持一致。
支持的型号
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
x = torch.rand((16, 3, 64), dtype=torch.bfloat16).npu()
p1 = torch.rand((3, 3), dtype=torch.bfloat16).npu()
p2 = torch.rand((64, 64), dtype=torch.bfloat16).npu()
out, quant_scale = torch_npu.npu_kronecker_quant(x, p1, p2, 0.7848)
"""
)
_add_torch_npu_docstr(
"scatter_update",
"""
接口原型:
torch_npu.scatter_update(Tensor data, Tensor indices, Tensor updates, int axis) -> Tensor
功能描述
将tensor updates中的值按指定的轴axis和索引indices更新tensor data中的值, 并将结果保存到输出tensor, data本身的数据不变.
参数说明
data: Tensor类型, data只支持2-8维, 且维度大小需要与updates一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
indices: Tensor类型, 数据类型支持int32、int64; 目前仅支持一维跟二维; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
updates: Tensor类型, updates的维度大小需要与data一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
axis: 整型, 用来scatter的维度, 数据类型为int64.
输出说明
out: Tensor类型, 计算输出, out只支持2-8维, 且维度大小需要与data一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
约束说明
data与updates的秩一致.
不支持索引越界, 索引越界不校验.
支持的PyTorch版本
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号
Atlas 训练系列产品
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用:
import torch
import torch_npu
import numpy as np
data = torch.tensor([[[[1,1,1,1,1,1,1,1],[2,2,2,2,2,2,2,2]]]], dtype=torch.float32).npu()
indices = torch.tensor ([1],dtype=torch.int64).npu()
updates = torch.tensor([[[[3,3,3,3,3,3,3,3]]]] , dtype=torch.float32).npu()
out = torch_npu.scatter_update(data, indices, updates, axis=-2)
"""
)
_add_torch_npu_docstr(
"scatter_update_",
"""
接口原型:
torch_npu.scatter_update_(Tensor(a!) data, Tensor indices, Tensor updates, int axis) -> Tensor(a!)
功能描述
将tensor updates中的值按指定的轴axis和索引indices更新tensor data中的值, 并将结果保存到输出tensor, data本身的数据被改变.
参数说明
data: Tensor类型, data只支持2-8维, 且维度大小需要与updates一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
indices: Tensor类型, 数据类型支持int32、int64; 目前仅支持一维跟二维; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
updates: Tensor类型, updates的维度大小需要与data一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
axis: 整型, 用来scatter的维度, 数据类型为int64.
输出说明
out: Tensor类型, 计算输出, 复用输入地址; out只支持2-8维, 且维度大小需要与data一致; 支持非连续的tensor; 数据格式支持ND; 不支持空Tensor.
Atlas A2 训练系列产品/Atlas 800I A2 推理产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas A3 训练系列产品: 数据类型支持int8、float16、float32、bfloat16、int32.
Atlas 训练系列产品: 数据类型支持int8、float16、float32、int32.
约束说明
data与updates的秩一致.
不支持索引越界, 索引越界不校验.
支持的PyTorch版本
PyTorch 2.3
PyTorch 2.2
PyTorch 2.1
PyTorch 1.11.0
支持的型号
Atlas 训练系列产品
Atlas A2 训练系列产品/Atlas 800I A2 推理产品
Atlas A3 训练系列产品
调用示例:
单算子模式调用:
import torch
import torch_npu
import numpy as np
data = torch.tensor([[[[1,1,1,1,1,1,1,1],[2,2,2,2,2,2,2,2]]]], dtype=torch.float32).npu()
indices = torch.tensor ([1],dtype=torch.int64).npu()
updates = torch.tensor([[[[3,3,3,3,3,3,3,3]]]] , dtype=torch.float32).npu()
out = torch_npu.scatter_update_(data, indices, updates, axis=-2)
"""
)
_add_torch_npu_docstr(
"npu_group_norm_swish",
"""
接口原型:
npu_group_norm_swish(Tensor input, int num_groups, Tensor weight, Tensor bias, float? eps=1e-5, float? swish_scale=1.0) -> (Tensor, Tensor, Tensor)
功能描述
计算输入input的组归一化结果,并进行swish计算。
参数说明
input: Device侧的Tensor类型,计算输入;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;input只支持2-8维;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
num_groups:int类型, 计算输入;表示将input的第1维分为num_groups组,inpu的第1维必须能被num_groups整除。
weight: Device侧的Tensor类型,计算输入;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型,并且与input一致;input只支持1维,且第0维大小与input的第1维大小相同;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
bias: Device侧的Tensor类型,计算输入;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型,并且与input一致;input只支持1维,且第0维大小与input的第1维大小相同;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
eps: Float类型,可选;用于防止产生除0操作;默认值为1e-5。
swish_scale: Float类型,可选; 用于计算swish;默认值为1.0。
输出说明
out:Device侧的Tensor类型,计算输出;表示将输入组归一化的结果;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型。
mean: Device侧的Tensor类型,计算输出;表述将输入分组后的均值;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型,。
rstd: Device侧的Tensor类型,计算输出;表述将输入分组后的标准差的倒数;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型。
约束说明
BFLOAT16数据类型仅支持如下产品型号:Atlas A2训练系列产品/Atlas 800I A2推理产品
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例:
import torch
import torch_npu
input = torch.randn(3, 4, 6, dtype=torch.float32).npu()
weight = torch.randn(input.size(1), dtype=torch.float32).npu()
bias = torch.randn(input.size(1), dtype=torch.float32).npu()
num_groups = input.size(1)
swish_scale = 1.0
eps = 1e-5
out = torch_npu.npu_group_norm_swish(input, num_groups, weight, bias, eps=eps, swish_scale=swish_scale)
"""
)
_add_torch_npu_docstr(
"npu_dequant_swiglu_quant",
"""
功能描述:
- 对输入张量 x 进行反量化、Swiglu 激活计算及量化,输出量化后的结果 y 和量化 scale。
- 支持静态量化和动态量化两种模式;支持 per-token 激活 scale、bias、平滑量化系数、分组计算和 Swiglu 变种计算。
接口原型:
torch_npu.npu_dequant_swiglu_quant(
Tensor x,
*,
Tensor weight_scale=None,
Tensor? activation_scale=None,
Tensor? bias=None,
Tensor? quant_scale=None,
Tensor? quant_offset=None,
Tensor? group_index=None,
bool activate_left=False,
int quant_mode=0,
int swiglu_mode=0,
float clamp_limit=7.0,
float glu_alpha=1.702,
float glu_bias=1.0
) -> (Tensor y, Tensor scale)
>Tensor中shape使用的变量说明:
>- TokensNum:表示传输的Tokens数,取值≥0。
>- H:表示嵌入向量的长度,取值\>0。
>- groupNum:表示group\_index输入的长度,取值\>0。
- x:Tensor类型,表示目标张量。要求是2D的Tensor,shape为\[TokensNum, 2H\],尾轴为偶数。数据类型支持int32和bfloat16,数据格式为ND。
- weight\_scale:Tensor类型,可选参数,表示权重量化对应的反量化系数。要求是2D的Tensor,shape为\[groupNum, 2H\],数据类型支持float32,数据格式为ND。当x为int32时,要求该参数非None,表示需要做反量化。
- activation\_scale:Tensor类型,可选参数,表示per-token权重量化对应的反量化系数。要求是1D的Tensor,shape为\[TokensNum\],数据类型支持float32,数据格式为ND。当x为int32时,要求该参数非None,表示需要做反量化。
- bias:Tensor类型,可选参数,表示x的偏置变量。数据类型支持int32,数据格式为ND。group\_index场景下(非None),该参数不生效为None。
- quant\_scale:Tensor类型,可选参数,表示smooth量化系数。要求是2D的Tensor,shape为\[groupNum, H\],数据类型支持float32、float16和bfloat16,数据格式为ND。
- quant\_offset:Tensor类型,可选参数,表示量化中的偏移项。数据类型支持float32、float16和bfloat16,数据格式为ND。group\_index场景下(非None),该参数不生效为None。
- group\_index:Tensor类型,可选参数,当前只支持count模式,表示该模式下指定分组的Tokens数(要求非负整数)。要求是1D的Tensor,数据类型支持int64,数据格式ND。
- activate_left(bool):可选参数,是否进行左激活,默认 False。
- 取True时,out=swish\(split\[x, -1, 2\]\[0\]\)\*split\[x, -1, 2\]\[1\]
- 取False时,out=swish\(split\[x, -1, 2\]\[1\]\)\*split\[x, -1, 2\]\[0\]
- quant_mode(int):可选参数,量化模式,0 表示静态量化,1 表示动态量化。group_index 场景下必须取 1。
- swiglu_mode(int):可选参数,swiglu 计算模式,0 表示传统 swiglu,1 表示变种 swiglu(支持 clamp、alpha、bias)。
- clamp_limit(float):可选参数,swiglu 输入门限,默认 7.0。
- glu_alpha(float):可选参数,glu 激活函数系数,默认 1.702。
- glu_bias(float):可选参数,swiglu 计算中的偏差,默认 1.0。
输出说明:
- y:Tensor类型,表示量化后的输出tensor。要求是2D的Tensor,shape=\[TokensNum, H\],数据类型支持int8,数据格式为ND。
- scale:Tensor类型,表示量化的scale参数。要求是1D的Tensor,shape=\[TokensNum\],数据类型支持float32,数据格式为ND。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持图模式。
- group\_index场景下(非None)约束说明:
- group\_index只支持count模式,需要网络保证group\_index输入的求和不超过x的TokensNum维度,否则会出现越界访问。
- H轴有维度大小限制:H≤10496同时64对齐场景;规格不满足场景会进行校验。
- 输出y和scale超过group\_index总和的部分未进行清理处理,该部分内存为垃圾数据,可能会存在inf/nan异常值,网络使用的时候需要注意影响。
- 当 x 为 int32 时,必须提供 weight_scale。
- 当 x 为 float16 或 bfloat16 时,weight_scale、activation_scale、bias 必须为 None。
- x 的最后一维长度必须为偶数。
- 当激活维度不是 x 的最后一维时,group_index 必须为 None。
- 当 group_index 非 None 时,仅支持动态量化(quant_mode=1),且 bias、quant_offset 必须为 None。
- y 的类型仅支持 int8。
- clamp_limit、glu_alpha、glu_bias 仅在 swiglu_mode=1 时生效。
支持的芯片型号:
- <term>Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件</term>
- <term>Atlas A3 训练系列产品/Atlas A3 推理系列产品</term>
调用示例:
- 单算子模式调用
import os
import shutil
import unittest
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUDequantSwigluQuant(TestCase):
def test_npu_dequant_swiglu_quant(self, device="npu"):
tokens_num = 4608
hidden_size = 2048
x = torch.randint(-10, 10, (tokens_num, hidden_size), dtype=torch.int32)
weight_scale = torch.randn((1, hidden_size), dtype=torch.float32)
activation_scale = torch.randn((tokens_num, 1), dtype=torch.float32)
quant_scale = torch.randn((1, hidden_size // 2), dtype=torch.float32)
group_index = torch.tensor([tokens_num], dtype=torch.int64)
bias = None
y, scale = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activation_scale.npu(),
bias=None,
quant_scale=quant_scale.npu(),
quant_offset=None,
group_index=group_index.npu(),
activate_left=True,
quant_mode=1,
swiglu_mode=1,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0
)
if __name__ == "__main__":
run_tests()
- 图模式调用
```python
import os
import shutil
import unittest
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
from torchair.configs.compiler_config import CompilerConfig
import torchair as tng
config = CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
npu_backend = tng.get_npu_backend(compiler_config=config)
class TestNPUDequantSwigluQuant(TestCase):
def test_npu_dequant_swiglu_quant(self, device="npu"):
tokens_num = 4608
hidden_size = 2048
x = torch.randint(-10, 10, (tokens_num, hidden_size), dtype=torch.int32)
weight_scale = torch.randn((1, hidden_size), dtype=torch.float32)
activation_scale = torch.randn((tokens_num, 1), dtype=torch.float32)
quant_scale = torch.randn((1, hidden_size // 2), dtype=torch.float32)
group_index = torch.tensor([tokens_num], dtype=torch.int64)
bias = None
y, scale = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activation_scale.npu(),
bias=None,
quant_scale=quant_scale.npu(),
quant_offset=None,
group_index=group_index.npu(),
activate_left=True,
quant_mode=1,
swiglu_mode=1,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0
)
if __name__ == "__main__":
run_tests()
"""
)
_add_torch_npu_docstr(
"npu_clipped_swiglu",
"""
接口原型:
torch_npu.npu_clipped_swiglu(x, *, group_index=None, dim=-1, alpha=1.702, limit=7.0, bias=1.0, interleaved=True) -> Tensor
功能描述:
新增带截断的Swish门控线性单元激活函数,实现x的变体SwiGlu计算。
计算公式:
(1)将x基于输入参数dim进行合轴,合轴后维度为[pre, cut, after]。其中cut轴为合轴之后需要切分为两个张量的轴,切分方式分为前后切分或者奇偶切分;pre,after 可以等于1。此外,由于after轴的元素为连续存放,且计算操作为逐元素的,因此将cut轴与after轴合并,得到x的维度为[pre, cut * after]。
(2)根据输入参数group_index, 对x的pre轴进行过滤处理,公式如下:sum = Sum(group_index), x = x[ : sum, : ]。其中sum表示group_index的所有元素之和。当不输入 group_index 时,跳过该步骤。
(3)根据输入参数interleaved,对x进行切分,公式如下:当 interleaved 为 true 时,表示奇偶切分:A = x[ : , : : 2], B = x[ : , 1 : : 2]
当 interleaved 为 false 时,表示前后切分:h = x.shape[1] // 2, A = x[ : , : h], B = x[ : , h : ]
(4)根据输入参数 alpha、limit、bias 进行变体SwiGlu计算,公式如下:A = A.clamp(min=None, max=limit), B = B.clamp(min=-limit, max=limit)
y_glu = A * sigmoid(alpha * A)
y = y_glu * (B + bias)
(5)重塑输出张量y的维度数量与合轴前的x的维度数量一致,第dim轴上的大小为x的一半,其他维度与x相同。
参数说明:
x:Tensor类型,必选参数,表示目标张量。数据类型支持float16、bfloat16、float32,不支持非连续的Tensor,数据格式为ND,x的维数必须大于1维,第dim轴为偶数。
group_index:Tensor类型,可选参数,表示对x进行分组的情况。要求为1维张量,第i个元素代表第i组需要处理的x合轴后的token数量,数据类型支持int64,数据格式ND。
dim: int类型,可选参数,表示需要对x进行切分的维度序号,取值范围为[-x.dim(), x.dim()-1],默认 -1。
alpha:float类型,可选参数,表示glu激活函数系数,默认 1.702。
limit:float类型,可选参数,表示变体swiglu输入门限,默认 7.0。
bias:float类型,可选参数,表示变体swiglu计算中的偏差,默认 1.0。
interleaved: bool类型,可选参数,表示输入x是否按奇偶方式切分,true表示为奇偶方式切分,false表示为前后方式切分,默认为true。
输出说明:
y:Tensor类型,表示激活函数的输出,数据类型同输入x,在维度上,第dim维是输入x的1/2,其余维度与输入x相同,数据格式为ND。
约束说明:
无
支持的型号:
Atlas A2 推理系列产品
Atlas A3 推理系列产品
调用示例:
单算子调用
import torch
import torch_npu
tokens_num = 4608
hidden_size = 2048
x = torch.randint(-10, 10, (tokens_num, hidden_size), dtype=torch.float32)
group_index = torch.randint(1, 101, (4, ), dtype=torch.int64)
y = torch_npu.npu_clipped_swiglu(
x.npu(),
group_index=group_index.npu(),
dim=-1,
alpha=1.702,
limit=7.0,
bias=1.0,
interleaved=True
)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device = torch.device(f'npu:0')
torch_npu.npu.set_device(device)
class ClippedSwigluModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, group_index, dim, alpha, limit, bias, interleaved):
y = torch_npu.npu_clipped_swiglu(
x.npu(),
group_index=group_index.npu(),
dim=dim,
alpha=alpha,
limit=limit,
bias=bias,
interleaved=interleaved
)
return y
tokens_num = 4608
hidden_size = 2048
x = torch.randint(-10, 10, (tokens_num, hidden_size), dtype=torch.float32)
group_index = torch.randint(1, 101, (4, ), dtype=torch.int64)
clipped_swiglu_model = ClippedSwigluModel().npu()
clipped_swiglu_model = torch.compile(clipped_swiglu_model, backend=npu_backend, dynamic=True)
y = clipped_swiglu_model(x, group_index, -1, 1.702, 7.0, 1.0, True)
"""
)
_add_torch_npu_docstr(
"npu_fused_causal_conv1d",
"""
接口原型:
torch_npu.npu_fused_causal_conv1d(x, weight, conv_states, *, query_start_loc=None, cache_indices=None,
initial_state_mode=None, bias=None, num_accepted_tokens=None,
activation="None", pad_slot_id=-1, run_mode=0,
residual_connection=0, max_query_len=-1,
num_computed_tokens=None, block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None, initial_state_idx=None,
block_size=128, conv_mode="default") -> Tensor
功能描述:
对序列执行因果一维卷积(非原地版本)。支持 prefill/decode/PD混部场景,支持 APC 模式。
沿序列维度使用缓存数据对各序列头部进行 padding,卷积完成后将序列尾部数据原地更新到 conv_states 中。
输出 y 为新分配的 Tensor。
参数说明:
x:Tensor类型,必选参数,输入序列。数据类型支持float16、bfloat16,数据格式为ND。shape为2维[cu_seq_len, dim]或3维[batch, seq_len, dim]。dim需满足dim % 128 == 0。
weight:Tensor类型,必选参数,因果1维卷积核。shape为2维[K, dim],K固定为3。
conv_states:Tensor类型,必选参数,缓存状态张量。shape为3维[..., state_len, dim],计算完成后原地更新。
query_start_loc:Tensor类型,可选参数,序列起始位置索引。shape为1维[batch+1],x为2维时不可省略。
cache_indices:Tensor类型,可选参数,缓存索引。APC未开启时为1维[batch];APC开启时为2维[batch, max_num_blocks]。
initial_state_mode:Tensor类型,可选参数,初始状态标志。shape为1维[batch],暂不支持。
bias:Tensor类型,可选参数,卷积偏置。shape为1维[dim],暂不支持。
num_accepted_tokens:Tensor类型,可选参数,投机解码场景下各batch实际接受的token个数。shape为1维[batch]。
activation:str类型,可选参数,激活函数类型,支持"None"、"silu"、"swish",默认值为"None",暂不支持。
pad_slot_id:int类型,可选参数,用于跳过不需要参与计算的batch,默认值为-1。
run_mode:int类型,可选参数,历史遗留接口,默认值为0,暂不支持。
residual_connection:int类型,可选参数,是否做残差连接。0:不做;1:做。默认值为0。
max_query_len:int类型,可选参数,所有batch seq_len的最大值,默认值为-1。
num_computed_tokens:Tensor类型,可选参数,当前batch已处理的token总数。shape为1维[batch]。Pangu模式下不能为None。
block_idx_first_scheduled_token:Tensor类型,可选参数,APC开启时当前batch首token对应的block索引。shape为1维[batch]。
block_idx_last_scheduled_token:Tensor类型,可选参数,APC开启时当前batch末token对应的block索引。shape为1维[batch]。
initial_state_idx:Tensor类型,可选参数,APC开启时初始索引块的索引。shape为1维[batch]。
block_size:int类型,可选参数,APC block大小,支持128/256,默认值为128。
conv_mode:str类型,可选参数,卷积模式。"default":正常卷积计算;"pangu":盘古模型下卷积计算前k-1个token置零。默认值为"default"。
输出说明:
y:Tensor类型,卷积计算结果。residual_connection=1时输出为卷积结果+输入x。shape与x一致。
conv_states将被原地更新。
约束说明:
x为2维时query_start_loc不可省略。
APC开启时(cache_indices为2维):block_size不能为0;必须提供block_idx_first_scheduled_token、block_idx_last_scheduled_token、initial_state_idx。
dim必须满足dim % 128 == 0。
conv_mode为"pangu"时num_computed_tokens不能为None。
支持的型号:
昇腾950 AI处理器
调用示例:
import torch
import torch_npu
K, dim, dtype = 3, 128, torch.bfloat16
weight = torch.randn(K, dim, dtype=dtype).npu()
cu_seq_len = sum([5, 3, 7, 4])
x = torch.randn(cu_seq_len, dim, dtype=dtype).npu()
query_start_loc = torch.tensor([0, 5, 8, 15, 19], dtype=torch.int32).npu()
conv_states = torch.randn(8, K - 1, dim, dtype=dtype).npu()
cache_indices = torch.tensor([0, 3, 1, 5], dtype=torch.int32).npu()
out = torch_npu.npu_fused_causal_conv1d(
x, weight, conv_states,
query_start_loc=query_start_loc,
cache_indices=cache_indices,
residual_connection=1,
pad_slot_id=-1,
)
"""
)
_add_torch_npu_docstr(
"npu_masked_causal_conv1d",
"""
接口原型:
torch_npu.npu_masked_causal_conv1d(input, weight, *, mask=None) -> Tensor
功能描述:
带掩码的因果一维深度可分离卷积(Masked Causal Depthwise Conv1d)。对输入张量`input`沿序列维度执行窗口大小为3的因果卷积(仅使用当前及历史时刻的输入),并可选地通过`mask`对输出进行掩码置零。
参数说明:
input:Tensor类型,必选参数,表示卷积输入张量。数据类型支持float16、bfloat16,不支持空Tensor,数据格式为ND。shape为3维[S, B, H],S为序列长度,B为批量大小,H为特征维度,H须为64的倍数。
weight:Tensor类型,必选参数,表示卷积权重张量。数据类型与input相同,不支持空Tensor,数据格式为ND。shape为2维[W, H],W为卷积核宽度,固定为3;H为特征维度,须与input的H一致。
mask:Tensor类型,可选参数,表示卷积输出的掩码。数据类型支持bool,数据格式为ND。shape为2维[B, S],True表示有效位置,False表示需要置零的位置。默认值为None,表示不进行掩码操作。
输出说明:
output:Tensor类型,因果卷积的输出结果,shape和数据类型与input一致,数据格式为ND。
约束说明:
该接口支持推理和训练场景下使用。
该接口支持单算子模式和图模式。
input和weight的数据类型必须一致。
weight的W维度目前只支持3。
H须为64的倍数。
不支持非连续的input、weight、mask张量。
支持的型号:
昇腾950 AI处理器
调用示例:
import torch
import torch_npu
S, B, H, W = 2048, 4, 768, 3
input = torch.randn(S, B, H, dtype=torch.bfloat16).npu()
weight = torch.randn(W, H, dtype=torch.bfloat16).npu()
mask = torch.rand(B, S).npu() > 0.3
output = torch_npu.npu_masked_causal_conv1d(input, weight, mask=mask)
"""
)
_add_torch_npu_docstr(
"npu_masked_causal_conv1d_backward",
"""
接口原型:
torch_npu.npu_masked_causal_conv1d_backward(grad_output, input, weight, *, mask=None) -> (Tensor, Tensor)
功能描述:
torch_npu.npu_masked_causal_conv1d的反向算子。计算带掩码的因果一维深度可分离卷积对输入`input`和权重`weight`的梯度。该接口通常由PyTorch自动微分机制自动调用,也可手动调用用于自定义梯度计算。
参数说明:
grad_output:Tensor类型,必选参数,表示前向输出的梯度张量。数据类型支持float16、bfloat16,不支持空Tensor,数据格式为ND。shape为3维[S, B, H]。
input:Tensor类型,必选参数,表示前向的卷积输入张量。数据类型与grad_output相同,不支持空Tensor,数据格式为ND。shape须与grad_output一致。
weight:Tensor类型,必选参数,表示前向的卷积权重张量。数据类型与grad_output相同,不支持空Tensor,数据格式为ND。shape为2维[W, H],W固定为3。
mask:Tensor类型,可选参数,表示前向卷积输出的掩码,与前向调用时传入的mask一致。数据类型支持bool,数据格式为ND。shape为2维[B, S]。默认值为None。
输出说明:
grad_input:Tensor类型,对前向输入input的梯度,shape和数据类型与grad_output一致,数据格式为ND。
grad_weight:Tensor类型,对前向权重weight的梯度,shape为[W, H],W=3,数据类型与grad_output一致,数据格式为ND。
约束说明:
该接口支持训练场景下使用。
该接口支持单算子模式和图模式。
grad_output、input、weight的数据类型必须一致。
weight的W维度目前只支持3。
不支持非连续的grad_output、input、weight、mask张量。
该接口已通过derivatives.yaml与npu_masked_causal_conv1d绑定,PyTorch自动微分时会自动调用,通常无需手动调用。
支持的型号:
昇腾950 AI处理器
调用示例:
import torch
import torch_npu
S, B, H, W = 2048, 4, 768, 3
grad_output = torch.randn(S, B, H, dtype=torch.bfloat16).npu()
input = torch.randn(S, B, H, dtype=torch.bfloat16).npu()
weight = torch.randn(W, H, dtype=torch.bfloat16).npu()
mask = torch.rand(B, S).npu() > 0.3
grad_input, grad_weight = torch_npu.npu_masked_causal_conv1d_backward(
grad_output, input, weight, mask=mask
)
"""
)
_add_torch_npu_docstr(
"npu_cross_entropy_loss",
"""
接口原型:
torch_npu.npu_cross_entropy_loss(Tensor input, Tensor target, Tensor? weight=None, str reduction="mean", int ignore_index=-100, float label_smoothing=0.0, float lse_square_scale_for_zloss=0.0, bool return_zloss=False) -> (Tensor, Tensor, Tensor, Tensor)
功能描述:
将原生CrossEntropyLoss中的log_softmax和nll_loss融合,降低计算时使用的内存。接口允许计算zloss。
参数说明:
input: Device侧的Tensor类型,表示输入;数据类型支持FLOAT16、FLOAT32、BFLOAT16类型;shape为[N, C],N为批处理大小,C为标签数,必须大于0。
target: Device侧的Tensor类型,表示标签;数据类型支持INT64类型;shape为[N],与input第零维相同,取值范围[0, C)。
weight: Device侧的Tensor类型,表示每个类别指定的缩放权重,可选;数据类型支持FLOAT32类型;shape为[C],与input第一维相同,取值范围(0, 1],不指定值时默认全一。
reduction: str类型,表示loss的归约方式;支持范围["mean", "sum", "none"],默认为"mean"。
ignore_index: int类型,指定忽略的标签;数值必须小于C,当小于0时视为无忽略标签;默认值为-100。
label_smoothing: float类型,表示计算loss时的平滑量;取值范围[0.0, 1.0);默认值为0.0。
lse_square_scale_for_zloss: float类型,表示计算zloss所需要的scale;取值范围[0.0, 1.0);默认值为0.0;当前暂不支持。
return_zloss: bool类型,控制是否返回zloss;设置为True将返回zloss,设置为False时不返回zloss;默认值为False;当前暂不支持。
输出说明:
loss:Device侧的Tensor类型,表示输出损失;数据类型与input相同;reduction为"none"时shape为[N],与input第零维一致,否则shape为[1]。
log_prob: Device侧的Tensor类型,输出给反向计算的输出;数据类型与input相同;shape为[N, C],与input一致。
zloss: Device侧的Tensor类型,表示辅助损失;数据类型与input相同;shape与loss一致;当return_zloss为True时输出zloss,否则将返回空tensor;当前暂不支持。
lse_for_zloss: Device侧的Tensor类型,zloss场景输出给反向计算的输出;数据类型与input相同;shape为[N],与input第零维一致;lse_square_scale_for_zloss不为0.0时将返回该输出,否则将返回空tensor;当前暂不支持。
约束说明:
输入shape中N取值范围(0, 200000]。
当input.requires_grad=True时,sum/none模式下不支持修改label_smoothing的默认值;mean模式下不支持修改所有含默认值的入参的值,包括weight,reduction,ignor_index,label_smoothing,lse_square_scale_for_zloss和return_zloss。
属性lse_square_scale_for_zloss与return_zloss暂未使能。
输出zloss与lse_for_zloss暂未使能。
输出中仅loss和zloss支持梯度计算。
支持的型号:
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例:
import torch
import torch_npu
N = 4096
C = 8080
input = torch.randn(N, C).npu()
target = torch.arange(0, N).npu()
loss, log_prob, _, _ = torch_npu.npu_cross_entropy_loss(input, target)
"""
)
_add_torch_npu_docstr(
"npu_chunk_gated_delta_rule",
"""
功能描述:
Chunked Gated Delta Rule(CGDR)是GDR的chunk版实现,它通过将输入序列切块,实现了一定的并行效果,在长上下文场景其计算效率相对Recurrent Gated Delta Rule更高,适用于prefill阶段。
接口原型:
npu_chunk_gated_delta_rule(Tensor query, Tensor key, Tensor value, *, Tensor? beta=None, Tensor? initial_state=None, Tensor? actual_seq_lengths=None, float? scale=None, Tensor? g=None) -> (Tensor, Tensor)
参数说明:
令 $B$ 表示batch size,$L_i$ 表示第i个序列的长度,$T=\sum_i^B L_i$ 表示累积序列长度。$N_k$ 表示key的头数,$N_v$ 表示value的头数,$D_k$ 表示key向量的维度,$D_v$ 表示value向量的维度。
- query (Tensor):必选输入,对应公式中的q,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- key (Tensor):必选输入,对应公式中的k,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nk, Dk)。
- value (Tensor):必选输入,对应公式中的v,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv, Dv)。
- beta (Tensor):必选输入,对应公式中的β,数据类型支持bfloat16,数据格式支持ND,shape为(T, Nv)。
- initial_state (Tensor):必选输入&输出,对应公式中的状态矩阵S,数据类型支持bfloat16,数据格式支持ND,shape为(BlockNum, Nv, Dv, Dk)。
- actual_seq_lengths (Tensor):必选输入,各batch的输入序列长度。数据类型支持int32,数据格式支持ND,shape为(B,)。
- g (Tensor):必选输入,衰减系数,对应公式中的α=e^g。默认为None,表示全0。数据类型支持float32,数据格式支持ND,shape为(T, Nv)。
- scale_value (Scalar):必选输入,query的缩放因子,对应公式中的 $1/\sqrt{d_k}$。数据类型支持float32。
输出说明:
注意力计算结果。输出的数据类型为bfloat16,数据格式为ND,形状为(T, Nv, Dv)。
约束说明:
- 该接口支持推理场景下使用。
- 该接口支持静态图模式。
- 输入shape大小需满足约束:$L_i \le 8$,$N_k \le 64$,$N_v \le 64$,$D_k = 128$,$D_v = 128$,$N_v$是$N_k$整数倍。
支持的PyTorch版本:
PyTorch 2.1 及更高版本
支持的型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例:
单算子模式调用
import torch
import torch_npu
# 构造输入
B, seqlen, nk, nv, dk, dv = (2, 100, 4, 8, 128, 128)
actual_seq_lengths = (torch.ones(B) * seqlen).to("npu").to(torch.int32)
T = int(torch.sum(actual_seq_lengths))
state = torch.rand((B, nv, dv, dk), dtype=torch.bfloat16).npu()
query = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
key = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
value = torch.rand((T, nv, dv), dtype=torch.bfloat16).npu()
g = torch.rand((T, nv), dtype=torch.float32).npu() * (-1.0)
beta = torch.rand((T, nv), dtype=torch.bfloat16).npu()
query = torch.nn.functional.normalize(query, p=2, dim=-1)
key = torch.nn.functional.normalize(key, p=2, dim=-1)
scale = dk ** -0.5
# 调用算子
o, final_state = torch_npu.npu_chunk_gated_delta_rule(
query, key, value,
beta=beta,
initial_state=state,
actual_seq_lengths=actual_seq_lengths,
scale=scale,
g=g)
print(o.shape, final_state.shape)
"""
)
_add_torch_npu_docstr(
"npu_gemma_rms_norm",
"""
接口原型:
npu_gemma_rms_norm(Tensor input, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor)
功能描述
通过对数据的root mean square(RMS)进行归一化,避免均值的使用
参数说明
input: Device侧的Tensor类型,表示输入的需要归一化的数据。shape支持1-8维度,数据格式支持ND。数据类型支持FLOAT32、FLOAT16、BFLOAT16。
gamma: Device侧的Tensor类型,表示数据缩放因子;shape支持1-8维度,数据格式支持ND。shape需要满足gamma_shape = input_shape\[n:\], n < input_shape.dims()。数据类型支持FLOAT32、FLOAT16、BFLOAT16,与input数据类型保持一致。
epsilon: float数据类型,用于防止除0错误。默认值1e-06。
输出说明
y:Device侧的Tensor类型,表示归一化后的输出数据。shape支持1-8维度,数据格式支持ND。数据类型支持FLOAT32、FLOAT16、BFLOAT16,与输入input数据类型保持一致。
rstd: Device侧的Tensor类型,输入input数据的标准差;shape支持1-8维度,数据格式支持ND。数据类型支持FLOAT32、FLOAT16、BFLOAT16,与输入input数据类型保持一致。shape与输入input的shape前几维保持一致,前几维指输入input的维度减去输入gamma的维度,表示不需要norm的维度。
约束说明
不支持空进空出
不支持非连续tensor
支持的型号
Atlas A2训练系列产品/Atlas 800I A2中的推理产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
input_x = torch.randn([20, 10, 64], dtype=torch.float32).npu()
input_gamma = torch.randn([64], dtype=torch.float32).npu()
y, rstd = torch_npu.npu_gemma_rms_norm(input_x, input_gamma)
"""
)
_add_torch_npu_docstr(
"npu_add_rms_norm_dynamic_quant",
"""
接口原型:
npu_add_rms_norm_dynamic_quant(Tensor x1, Tensor x2, Tensor gamma, *, Tensor? smooth_scale1=None, Tensor? smooth_scale2=None, Tensor? beta=None, float epsilon=1e-6, bool[2] output_mask=[], ScalarType? y_dtype=None) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述
将RmsNorm前的Add操作、RmsNorm归一化和最多两路DynamicQuant量化进行融合,减少数据搬入搬出操作。
参数说明
x1: Tensor类型,标准化输入张量。shape支持2-8维,数据类型支持FLOAT16、BFLOAT16,格式支持ND。不支持空Tensor。
x2: Tensor类型,标准化输入张量。shape支持2-8维,数据类型支持FLOAT16、BFLOAT16,格式支持ND。不支持空Tensor。
gamma: Tensor类型,归一化权重张量。shape为1维,需与x1最后一维一致,数据类型与x1一致。不支持空Tensor。
smooth_scale1: Tensor类型,第一路量化的smooth_scale张量。可选,shape和数据类型与gamma一致。不支持空Tensor。
smooth_scale2: Tensor类型,第二路量化的smooth_scale张量。可选,shape和数据类型与gamma一致,需与smooth_scale1配合使用。不支持空Tensor。
beta: Tensor类型,归一化偏置项。可选,shape和数据类型与gamma一致。不支持空Tensor
epsilon: double类型,防止除0错误,默认值为1e-6.
output_mask: 数组,表示输出的掩码,数据类型支持BOOL,支持空指针,或长度为2的数组
y_dtype: ScalarType类型,y1/y2的量化输出数据类型。None或torch.int8表示INT8(默认),shape与x1一致;torch.quint4x2表示INT4,输出为torch.int32(8个int4打包为1个int32),shape最后一维为x1最后一维的1/8。
输出说明
y1: Tensor类型,第一路量化输出。y_dtype=torch.int8时dtype为int8、shape与x1一致;y_dtype=torch.quint4x2时dtype为int32、shape最后一维为x1最后一维/8。支持非连续的Tensor,不支持空Tensor。
y2: Tensor类型,第二路量化输出。dtype、shape与y1一致。支持非连续的Tensor,不支持空Tensor。若未输入smooth_scale2,此输出无实际意义。
x_out: Tensor类型,x1与x2之和。shape、数据类型与x1一致。
scale1: Tensor类型,第一路量化scale输出。shape为x1除最后一维后的shape,数据类型为float32,数据格式支持ND,支持非连续的Tensor,不支持空Tensor
scale2: Tensor类型,第二路量化scale输出,shape同scale1,数据类型为float32, 数据类型支持ND,支持非连续的Tensor,不支持空Tensor。若未输入smooth_scale2,此输出无实际意义。
约束说明
所有输入输出Tensor的数据格式推荐使用ND格式,其他数据格式会由框架默认转换成ND格式进行处理。
当outputMaskOptional不为空时,参数smoothScale1Optional有值时,则outputMaskOptional[0]必须为True。参数smoothScale2Optional有值时,则outputMaskOptional[1]必须为True。
当outputMaskOptional不为空时,outputMaskOptional[0]与outputMaskOptional[1]不能同时为False。
当outputMaskOptional为空时,参数smoothScale2Optional有值时,参数smoothScale1Optional也必须有值。
如果y2Out为有效输出时,shape需要与y1Out保持一致;如果y2Out为无效输出时,shape为[1]。
支持的型号
Atlas A3训练系列产品/Atlas A3推理系列产品
Atlas A2训练系列产品/Atlas 800I A2推理产品/A200I A2 Box异构组件
调用示例:
import torch
import torch_npu
x1 = torch.randn(3, 4, 8, dtype=torch.float16, device='npu')
x2 = torch.randn(3, 4, 8, dtype=torch.float16, device='npu')
gamma = torch.ones(8, dtype=torch.float16, device='npu')
beta = torch.zeros(8, dtype=torch.float16, device='npu')
smooth_scale1 = torch.ones(8, dtype=torch.float16, device='npu')
smooth_scale2 = torch.ones(8, dtype=torch.float16, device='npu')
epsilon = 1e-6
output_mask = [True, True]
y1_npu, y2_npu, x_out_npu, s1_npu, s2_npu = torch_npu.npu_add_rms_norm_dynamic_quant(
x1_n, x2_n, gamma_n,
smooth_scale1=s1_n,
smooth_scale2=s2_n,
beta=beta_n,
epsilon=eps_f,
output_mask=output_mask,
)
"""
)
_add_torch_npu_docstr(
"npu_add_rms_norm_cast",
"""
接口原型:
npu_add_rms_norm_cast(Tensor x1, Tensor x2, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor, Tensor, Tensor)
功能描述
add_rms_norm和cast的融合算子,对add_rms_norm计算后的输出做指定类型的cast操作,减少搬入搬出。
参数说明
x1:Device侧的Tensor类型,需要归一化的原始数据输入。shape支持1-8维。数据类型支持BFLOAT16、FLOAT16,数据格式支持ND。不支持空tensor。
x2:Device侧的Tensor类型,需要归一化的原始数据输入。shape支持1-8维,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。shape、数据格式、数据类型均需要与入参x1保持一致。不支持空tensor。
gamma:Device侧的Tensor类型,数据缩放因子。shape支持1-8维,数据格式支持ND,数据类型支持FLOAT16、BFLOAT16。shape需要满足gamma_shape = x_shape\[n:\], n < x_shape.dims()。数据类型、数据格式需要与入参x1保持一致。不支持空tensor。
epsilon:float数据类型,用于防止除0错误,数据类型为DOUBLE,默认值为1e-6。
输出说明
y1:Device侧的Tensor类型,归一化后经过类型转换的输出数据。shape支持1-8维,数据格式支持ND,数据类型支持FLOAT32。shape、数据格式需要与入参x1保持一致。不支持空tensor。
y2:Device侧的Tensor类型,归一化后的输出数据。shape支持1-8维,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。shape、数据格式、数据类型均需要与入参x1保持一致。不支持空tensor。
rstd:Device侧的Tensor类型,x的标准差。数据类型支持FLOAT32,shape支持1-8维。shape与入参x1的shape前几维保持一致,前几维指x1的维度减去gamma的维度,表示不需要norm的维度。数据格式支持ND,需要与入参x1的数据格式保持一致。不支持空tensor。
x:Device侧的Tensor类型,归一化的数据和。shape支持1-8维,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。shape、数据格式、数据类型均需要与入参x1保持一致。不支持空tensor。
支持的型号
Atlas A2训练系列产品/Atlas 800I A2中的推理产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
input_x1 = torch.randn([20, 10, 64], dtype=torch.float16).npu()
input_x2 = torch.randn([20, 10, 64], dtype=torch.float16).npu()
input_gamma = torch.randn([64], dtype=torch.float16).npu()
y1, y2, rstd, x = torch_npu.npu_add_rms_norm_cast(input_x1, input_x2, input_gamma)
"""
)
_add_torch_npu_docstr(
"npu_add_rms_norm_quant",
"""
接口原型:
func: npu_add_rms_norm_quant(Tensor x1, Tensor x2, Tensor gamma, Tensor scales1,
Tensor? zero_points1, Tensor? beta, Tensor? scales2=None, Tensor? zero_points2=None, *, int axis=-1,
float epsilon=1e-06, bool div_mode=True) -> (Tensor, Tensor, Tensor)
功能描述
add_rms_norm_quant算子将rms_norm前的add算子以及之后的quantize算子融合,减少搬入搬出。新增偏置项beta参数
参数说明
x1:Device侧的Tensor类型,表示标准化过程中的源数据张量。shape支持1-8维。数据类型支持BFLOAT16、FLOAT16,数据格式支持ND,支持非连续的Tensor。不支持空tensor。
x2:Device侧的Tensor类型,表示标准化过程中的源数据张量。shape支持1-8维。数据类型支持BFLOAT16、FLOAT16,数据格式支持ND,支持非连续的Tensor。数据格式、数据类型均需要与入参x1保持一致。不支持空tensor。
gamma:Device侧的Tensor类型,表示标准化过程中的权重张量。shape支持1-8维,shape需要与x1需要Norm的维度一致。数据类型支持BFLOAT16、FLOAT16,数据格式支持ND,支持非连续的Tensor。数据类型需要与入参x1保持一致。不支持空tensor。
scales1:Device侧的Tensor类型,表示量化过程中得到y1进行的scales张量。shape需要与gamma保持一致。数据类型支持FLOAT32、BFLOAT16,数据格式支持ND,支持非连续的Tensor。当参数divMode的值为True时,该参数的值不能为0。
zero_points1:Device侧的Tensor类型,表示量化过程中得到y1进行的offset张量。可选参数。shape需要与gamma保持一致。数据类型支持INT32、BFLOAT16,数据格式支持ND,支持非连续的Tensor。
beta:Device侧的Tensor类型,表示标准化过程中的偏置项。可选参数。shape支持1-8维,shape需要与gamma的shape保持一致。数据类型支持BFLOAT16、FLOAT16,数据类型需要与gamma保持一致,数据格式支持ND,支持非连续的Tensor。
scales2:Device侧的Tensor类型,表示量化过程中得到y2进行的scales张量。可选参数。shape需要与gamma保持一致。数据类型支持FLOAT32、BFLOAT16,数据类型需要与scales1保持一致。数据格式支持ND,支持非连续的Tensor。当参数divMode的值为True时,该参数的值不能为0。
zero_points2:Device侧的Tensor类型,表示量化过程中得到y2进行的offset张量。可选参数。shape需要与gamma保持一致。数据类型支持INT32、BFLOAT16,数据类型需要与zero_points1保持一致。数据格式支持ND,支持非连续的Tensor。
axis:Host侧的整型,表示需要进行量化的elewise轴,其他的轴做broadcast,指定的轴不能超过输入x的维度数。数据类型为int64_t,当前仅支持-1,传其他值均不生效。
epsilon:用于防止除0错误,数据类型为double。建议传较小的正数。默认值为1e-6。
输出说明
y1:Device侧的Tensor类型,表示量化后的输出数据。shape支持1-8维度,shape需要与输入x1/x2一致。数据类型支持INT8,数据格式支持ND,支持非连续的Tensor。
y2:Device侧的Tensor类型,表示量化后的输出数据。shape支持1-8维度,shape需要与输入x1/x2一致,数据类型支持INT8,数据格式支持ND,支持非连续的Tensor。
x_out:Device侧的Tensor类型,表示x1和x2的和。shape支持1-8维度,shape需要与输入x1/x2一致。数据类型支持BFLOAT16、FLOAT16,需要与输入x1、x2一致。数据格式支持ND,支持非连续的Tensor。
支持的型号
Atlas 推理系列产品
Atlas A2训练系列产品/Atlas 800I A2中的推理产品/A200I A2 Box异构组件/Atlas A3训练系列产品/Atlas A3推理系列产品
调用示例:
import torch
import torch_npu
x_shape = [16, 32]
quant_shape = [32, ]
x1 = torch.randn(x_shape, dtype=torch.float16).npu()
x2 = torch.randn(x_shape, dtype=torch.float16).npu()
gamma = torch.randn(quant_shape, dtype=torch.float16).npu()
beta = torch.randn(quant_shape, dtype=torch.float16).npu()
scales1 = torch.randn(quant_shape, dtype=torch.float32).npu()
zero_points1 = torch.randint(-10, 10, quant_shape, dtype=torch.int32).npu()
y1, _, x_out = torch_npu.npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1, beta)
"""
)
_add_torch_npu_docstr(
"npu_advance_step_flashattn",
"""
接口原型:
npu_advance_step_flashattn(Tensor(a!) input_tokens, Tensor sampled_token_ids, Tensor(b!) input_positions, Tensor(c!) seq_lens, Tensor(d!) slot_mapping, Tensor block_tables, int num_seqs, int num_queries, int block_size, *, Tensor? spec_token=None, Tensor? accepted_num=None) -> ()
功能描述
在npu上实现vLLM库中advance_step_flashattn的功能,在每个生成步骤中原地更新input_tokens,input_positions,seq_lens和slot_mapping,增加可选入参,用于支持投机推理的计算。
参数说明
input_tokens: Device侧的Tensor类型,输入/输出张量,用于更新vLLM模型中的token值;数据类型支持int64类型;如果是非投机场景,shape为[num_seqs,],如果是投机场景,shape为[num_seqs, 1 + spec_num];Shape第一维长度与num_seqs相同,不支持空tensor,必须为大于0的正整数;
sampled_token_ids: Device侧的Tensor类型,输入张量,用于储存token_id;数据类型支持INT64类型;如果是非投机场景,shape为[num_queries, 1],第二维长度是1;如果是投机场景,shape为[num_seqs, 1 + spec_num];Shape第一维长度与num_queries相同,不支持空tensor,必须为大于0的正整数。
input_positions: Device侧的Tensor类型,输入/输出张量,用于记录token的index;数据类型支持INT64类型;如果是非投机场景,shape为[num_queries, 1],第二维长度是1;如果是投机场景,shape为[num_seqs, 1 + spec_num];第一维长度与num_seqs相同,不支持空tensor,必须为大于0的正整数。
seq_lens: Device侧的Tensor类型,输入/输出张量,用于记录不同block_idx下seq的长度;数据类型支持INT64类型;非投机场景与投机场景shape均为[num_seqs,],第一维长度与num_seqs相同,不支持空tensor,必须为大于0的正整数。
slot_mapping: Device侧的Tensor类型,输入/输出张量,用于将token值在序列中的位置映射到物理位置;数据类型支持INT64类型;非投机场景与投机场景shape均为[num_seqs,],第一维长度与num_seqs相同,不支持空tensor,必须为大于0的正整数。
block_tables: Device侧的Tensor类型,输入/输出张量,用于记录不同block_idx下block的大小;数据类型支持INT64类型;非投机场景与投机场景shape均为二维,第一维长度与num_seqs相同,第二维长度需要大于seq_lens_cpu中最大值除以block_size的整数部分,不支持空tensor,必须为大于0的正整数。
num_seqs: int类型,记录输入的seq数量;非投机场景与投机场景相同,必须为大于0的正整数。
num_queries: int类型,记录输入的query数量;投机场景下与num_seqs相同,必须为大于0的正整数。
block_size: int类型,记录每个block的大小;非投机场景与投机场景相同,必须为大于0的正整数。
spec_token: 可选参数,Device侧的Tensor类型,输入张量,用于记录投机场景下当前的token的idx。数据类型支持INT64类型;spec_token为空时,则为非投机场景,默认为None;spec_token不为空时,则为投机场景,shape为[num_seqs, spec_num];不支持空tensor,必须为大于0的正整数。
accepted_num: 可选参数,Device侧的Tensor类型,输入张量,用于记录投机场景下每个request接受的投机的数量。数据类型支持INT64类型;accepted_num为空时,则为非投机场景,默认为None;accepted_num不为空时,则为投机场景,shape为[num_seqs,];不支持空tensor,必须为大于0的正整数。
输出说明
此接口将原地更新input_tokens,input_positions,seq_lens和slot_mapping的值,无返回值。
约束说明
1. 输入input_tokens,input_positions,seq_lens,slot_mapping和block_tables的第一维长度与num_seqs相同
2. 投机场景下,输入input_tokens的第二维长度为1 + spec_num
3. 输入sampled_token_ids的第一维长度与num_queries相同,非投机场景下第二维长度为1,投机场景下第二维长度为1 + spec_num
4. 输入block_tables的shape的第二维长度大于seq_lens_cpu中最大值除以block_size的整数部分
5. 非投机场景下,输入num_seqs必须大于输入num_queries;投机场景下,num_queries下与num_seqs相同
支持的型号
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
非投机场景:
import numpy as np
import torch
import torch_npu
num_seqs = 16
num_queries = 8
block_size = 8
input_tokens = np.random.randint(10, size=(num_seqs,))
sampled_token_ids = np.random.randint(10, size=(num_queries,1))
input_positions = np.random.randint(10, size=(num_seqs,))
seq_lens = np.random.randint(10, size=(num_seqs,))
slot_mapping = np.random.randint(10, size=(num_seqs,))
input_tokens = torch.tensor(input_tokens, dtype=torch.int64, device="npu")
sampled_token_ids = torch.tensor(sampled_token_ids, dtype=torch.int64, device="npu")
input_positions = torch.tensor(input_positions, dtype=torch.int64, device="npu")
seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device="npu")
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64, device="npu")
max_seq_len = seq_lens.max().item()
block_tables = np.random.randint(10, size=(num_seqs, max_seq_len // block_size + 1))
block_tables = torch.tensor(block_tables, dtype=torch.int64, device="npu")
torch_npu.npu_advance_step_flashattn(input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, num_seqs, num_queries, block_size)
投机场景:
import numpy as np
import torch
import torch_npu
num_seqs = 16
num_queries = 16
block_size = 8
spec_num = 2
input_tokens = np.random.randint(10, size=(num_seqs*(1 + spec_num),))
sampled_token_ids = np.random.randint(10, size=(num_seqs, 1 + spec_num))
input_positions = np.random.randint(10, size=(num_seqs*(1 + spec_num),))
seq_lens = np.random.randint(10, size=(num_seqs*(1 + spec_num),))
slot_mapping = np.random.randint(10, size=(num_seqs*(1 + spec_num),))
spec_token = np.random.randint(10, size=(num_seqs, spec_num))
accepted_num = np.random.randint(10, size=(num_seqs,))
input_tokens = torch.tensor(input_tokens, dtype=torch.int64, device="npu")
sampled_token_ids = torch.tensor(sampled_token_ids, dtype=torch.int64, device="npu")
input_positions = torch.tensor(input_positions, dtype=torch.int64, device="npu")
seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device="npu")
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64, device="npu")
spec_token = torch.tensor(spec_token, dtype=torch.int64, device="npu")
accepted_num = torch.tensor(accepted_num, dtype=torch.int64, device="npu")
max_seq_len = seq_lens.max().item()
block_tables = np.random.randint(10, size=(num_seqs, max_seq_len // block_size + 1))
block_tables = torch.tensor(block_tables, dtype=torch.int64, device="npu")
torch_npu.npu_advance_step_flashattn(input_tokens, sampled_token_ids, input_positions,
seq_lens, slot_mapping, block_tables, num_seqs,
num_queries, block_size, spec_token=spec_token, accepted_num=accepted_num)
"""
)
_add_torch_npu_docstr(
"empty_with_swapped_memory",
"""
接口原型:
torch_npu.empty_with_swapped_memory(size, dtype, device) -> Tensor
功能描述
申请一个device信息为NPU、实际内存在host侧的特殊tensor。
参数说明
size (ListInt) - 定义输出张量shape的整数序列。可以是参数数量(可变值),也可以是列表或元组等集合。
dtype (torch.dtype, 可选,默认值为None) - 返回张量所需数据类型。如果值为None,请使用全局默认值(请参见torch.set_default_tensor_type()).
device (torch.device, 可选,默认值为None) - 返回张量的所需设备。
输出说明
此接口将返回一个device信息为NPU、实际内存在host侧的特殊tensor。
约束说明
1. 当前申请出来的特殊tensor仅支持如下算子:
torch.fill_
torch.zero_
torch_npu.npu_apply_adam_w
torch_npu.npu_hans_encode
torch_npu.npu_hans_decode
2. 支持版本
PyTorch 2.1,PyTorch 2.5及更高版本
支持的型号
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
swapped_tensor = torch_npu.empty_with_swapped_memory([12, 12], dtype=torch.float32, device=torch.device("npu:0"))
swapped_tensor.zero_()
"""
)
_add_torch_npu_docstr(
"npu_alltoallv_gmm",
"""
接口原型:
npu_alltoallv_gmm(Tensor gmm_x, Tensor gmm_weight, str hcom, int ep_world_size, int[] send_counts, int[] recv_counts, *, Tensor? send_counts_tensor=None, Tensor? recv_counts_tensor=None, Tensor? mm_x=None, Tensor? mm_weight=None, bool trans_gmm_weight=False, bool trans_mm_weight=False, bool permute_out_flag=False) -> (Tensor, Tensor, Tensor)
功能描述
alltoallv和grouped matmul的融合算子,对alltoallv通信后的输出做grouped matmul操作,通信时间和计算时间进行掩盖。
参数说明
gmmX: device侧Tensor,表示输入,数据类型支持float16,bfloat16。该输入进行AllToAllv通信,仅支持二维, 数据格式支持ND,通信后结果作为GrouedMatMul计算的左矩阵
gmmWeight:device侧Tensor,表示输入,数据类型支持float16, bfloat16,类型需与gmmX保持一致,仅支持三维, 数据格式支持ND,GrouedMatMul计算的右矩阵
hcom:char*类型,计算输入,专家并行的通信域名。字符串长度需大于0,小于128。
ep_world_size:int类型,计算输入,ep通信域size,支持8/16/32/64。
sendCounts:int[],计算输入,支持int数据类型,通信发送的数据量。
recvCounts:int[],计算输入,支持int数据类型,通信接收的数据量。
send_counts_tensor:device侧Tensor,表示输入,暂不支持。
recv_counts_tensor:device侧Tensor,表示输入,暂不支持。
mm_x:device侧Tensor,表示输入,数据类型支持float16,bfloat16,共享专家的左矩阵。
mm_weight:device侧Tensor,表示输入,数据类型支持float16,bfloat16,共享专家的右矩阵。
transGmmWeight:为True:表明gmm的右矩阵要转置,为False时表明gmm右矩阵不转置,默认为false
transMmWeight:为True:表明mm的右矩阵要转置,为False时表明mm右矩阵不转置,默认为false
permute_out_flag:为True:表明permute结果输出,为False时表明permute结果不输出,默认为false
输出说明
gmmY:device侧Tensor, 计算输出,数据类型支持float16, bfloat16。最终计算结果,数据类型与输入gmmX保持一致
mmY:device侧Tensor, 数据类型支持float16, bfloat16,共享专家matmul的输出,仅当传入mmX与mmWeight才输出,数据类型与mmX保持一致。
permute_out:device侧Tensor, 数据类型支持float16, bfloat16,alltoallv输出的中间结果,permute_out_flag为True表明permute结果输出,为False时表明permute结果不输出。
支持的型号
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_npu_alltoallv_gmm(rank, world_size, master_ip, master_port, gmm_x, gmm_w, send_counts, recv_counts, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
input = torch.randn(gmm_x, dtype=dtype).npu()
weight = torch.randn(gmm_w, dtype=dtype).npu()
gmmYOut, _, _ = torch_npu.npu_alltoallv_gmm(gmm_x=input,
gmm_weight=weight,
send_counts_tensor=None,
recv_counts_tensor=None,
mm_x=None,
mm_weight=None,
group=hcom_info,
ep_world_size=world_size,
send_counts=send_counts,
recv_counts=recv_counts,
trans_gmm_weight=False,
trans_mm_weight=False,
permute_out_flag=True)
def generate_matrix(self, e, ep_world_size, bsk, name="alltoallv_gmm", max_iter=10000):
import hashlib
hash_bytes = hashlib.sha256(name.encode()).digest()
seed = int.from_bytes(hash_bytes[:4], byteorder='big')
np.random.seed(seed)
row_size = ep_world_size
col_size = e * ep_world_size
matrix = []
avg = bsk // col_size
tail_num = bsk % col_size
matrix = np.full((row_size, col_size), avg)
matrix[:, -1] += tail_num
return matrix
if __name__ == "__main__":
worksize = 8
e = 4
master_ip = '127.0.0.1'
master_port = '50001'
BS = 128
K = 8
x1_shape = [BS*K, 2048]
x2_shape = [2048, 2048]
send_counts = self.generate_matrix(e, worksize, BS*K)
recv_counts = np.hstack(np.split(mc2_send_counts.reshape(-1, e), epWorldSize, axis=0))
dtype = torch.float16
mp.spawn(run_npu_alltoallv_gmm, args=(worksize, master_ip, master_port, gmm_x, gmm_weight, send_counts, recv_counts, dtype), nprocs=worksize)
"""
)
_add_torch_npu_docstr(
"npu_alltoallv_quant_gmm",
"""
函数原型
torch_npu.npu_alltoallv_quant_gmm(gmm_x, gmm_weight, gmm_x_scale, gmm_weight_scale, hcom, ep_world_size, send_counts, recv_counts, gmm_y_dtype, *, send_counts_tensor=None, recv_counts_tensor=None, mm_x=None, mm_weight=None, mm_x_scale=None, mm_weight_scale=None, gmm_x_quant_mode=None, gmm_weight_quant_mode=None, mm_x_quant_mode=None, mm_weight_quant_mode=None, permute_out_flag=False, group_size=None, gmm_x_dtype=None, gmm_weight_dtype=None, gmm_x_scale_dtype=None, gmm_weight_scale_dtype=None, mm_x_dtype=None, mm_weight_dtype=None, mm_x_scale_dtype=None, mm_weight_scale_dtype=None, mm_y_dtype=None) -> (Tensor, Tensor, Tensor)
参数说明
gmm_x(Tensor):必选输入,表示本卡通信前路由专家原始左矩阵的输入。数据类型支持hifloat8。支持2维,shape为(BSK,H1),数据格式支持ND。
gmm_weight(Tensor):必选输入,表示本卡路由专家GroupedMatmul的右矩阵输入。数据类型支持hifloat8。支持3维,shape为(e,H1,N1),数据格式支持ND。
gmm_x_scale(Tensor):必选输入,表示路由专家左矩阵gmm_x的量化系数。数据类型支持float32。pertensor量化场景下支持1维,shape为(1,)。数据格式支持ND。
gmm_weight_scale(Tensor):必选输入,表示路由专家右矩阵gmm_weight的量化系数。数据类型支持float32。pertensor量化场景下支持1维,shape为(1,)。数据格式支持ND。
hcom(str):必选输入。Host侧标识列组的字符串,即通信域名称,通过get_hccl_comm_name接口获取。
ep_world_size(int):必选输入。通信域内的rank总数。支持范围为2、4、8、16、32、64、128、256。
send_counts(List(int)):必选输入。长度为e*ep_world_size的整数列表,表示本卡发送给每个目标卡的token数。假设目标卡号为i(0<=i<ep_world_size),发送专家号为j(0<=j<e),send_counts[i][j]表示本卡发送给第i张卡第j个专家的token数。约束:长度必须等于e*ep_world_size,且元素均为非负整数。
recv_counts(List(int)):必选输入。长度为e*ep_world_size的整数列表,表示本卡从每个目标卡接收的token数。假设目标卡号为i(0<=i<ep_world_size),接收专家号为j(0<=j<e),recv_counts[i][j]表示本卡接收第i张卡第j个专家的token数。约束:长度必须等于e*ep_world_size,且元素均为非负整数。
gmm_y_dtype(int):必选输入。表示路由专家GroupedMatmul计算输出张量gmm_y的数据类型(例如:torch.float16)。数据类型支持float16、bfloat16。
send_counts_tensor(Tensor):可选输入,当前仅支持输入None。
recv_counts_tensor(Tensor):可选输入,当前仅支持输入None。
mm_x(Tensor):可选输入,默认None。表示共享专家Matmul的左矩阵输入,仅在启用共享专家时输入。数据类型支持hifloat8。支持2维,shape为 (BS,H2)。数据格式支持ND。
mm_weight(Tensor):可选输入,默认None。表示共享专家Matmul的右矩阵输入,仅在启用共享专家时输入。数据类型支持hifloat8。支持2维,shape为(H2,N2)。数据格式支持ND。
mm_x_scale(Tensor):可选输入,默认None。表示共享专家左矩阵mm_x的量化系数。数据类型支持float32。pertensor量化场景下支持1维,shape为(1,)。数据格式支持ND。
mm_weight_scale(Tensor): 可选输入,默认None。表示共享专家右矩阵mm_weight的量化系数。数据类型支持float32。pertensor量化场景下支持1维,shape为(1,)。数据格式支持ND。
gmm_x_quant_mode(int):可选输入,表示路由专家左矩阵的量化模式。当前仅支持1,表示pertensor量化。
gmm_weight_quant_mode(int):可选输入,表示路由专家右矩阵的量化模式。当前仅支持1,表示pertensor量化。
mm_x_quant_mode(int):可选输入,表示共享专家左矩阵的量化模式。当前仅支持1,表示pertensor量化。
mm_weight_quant_mode(int):可选输入,表示共享专家右矩阵的量化模式。当前仅支持1,表示pertensor量化。
permute_out_flag(bool):可选输入,默认False。是否返回通信后重排的路由专家矩阵(即permute_out)。若为True,则返回值中包含该张量。
group_size(List(int)):可选输入,当前仅支持None,预留参数
gmm_x_dtype(int):可选输入,默认None。表示路由专家左矩阵gmm_x的实际数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
gmm_weight_dtype(int):可选输入,默认None。表示路由专家右矩阵gmm_weight的实际数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
gmm_x_scale_dtype(int):可选输入,默认None。表示路由专家左矩阵量化系数gmm_x_scale的实际数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
gmm_weight_scale_dtype(int):可选输入,默认None。表示路由专家右矩阵量化系数gmm_weight_scale的实际数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
mm_x_dtype(int):可选输入,表示共享专家左矩阵mm_x的数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
mm_weight_dtype(int):可选输入,表示共享专家右矩阵mm_weight的数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
mm_x_scale_dtype(int):可选输入,表示共享专家左矩阵量化系数mm_x_scale 的数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
mm_weight_scale_dtype(int):可选输入,表示共享专家右矩阵量化系数mm_weight_scale的数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。
mm_y_dtype(int):可选输入,表示共享专家输出张量mm_y的数据类型。对于PyTroch原生不支持的数据类型(如float8_e8m0)需要指定该参数取值。数据类型支持float16、bfloat16。
返回值说明
gmm_y(Tensor):表示路由专家GroupedMatmul计算的输出,数据类型为gmm_y_dtype指定的类型,支持float16、bfloat16。支持2维,shape为(A,N1)。数据格式支持ND。
mm_y(Tensor):表示共享专家MatMul的输出,数据类型为mm_y_dtype指定的类型,支持float16、bfloat16。支持2维,shape为(BS,N2)。仅当传入mm_x与mm_weight才输出。数据格式支持ND。
permute_out(Tensor):计算输出,Permute之后的输出,数据类型与gmm_x保持一致。支持2维,shape为(A,N1)。数据格式支持ND。
约束说明
该接口支持训练、推理场景下使用。
该接口仅支持单算子模式调用。
参数说明中Shape涉及的变量说明:
BS表示batch sequence size。
K表示选取的topK专家个数。当存在共享专家计算时,K需要满足取值范围[2,8]。
BSK=sum(send_counts),表示本卡Alltoallv通信中发送给其他卡的总token数,取值范围(0,52428800)。
H1表示本卡路由专家的hidden size,取值范围(0,65536)。
H2表示本卡共享专家的hidden size,取值范围(0,12288]。
N1表示路由专家输出维度,取值范围(0,65536)。
N2表示共享专家输出维度,取值范围(0,65536)。
e表示通信后每张卡上的专家数量,取值范围(0,32]。e*ep_world_size<=256。
A表示路由专家计算输出的总token数。A=sum(recv_counts)。EP通信域内所有卡上的A累加和等于所有卡上的BSK累加和。
第i张卡发送到第j张卡数据量为send_counts[j]与第j张卡接收数据量为recv_counts[i]必须相等。
gmm_x_quant_mode、gmm_weight_quant_mode、mm_x_quant_mode、mm_weight_quant_mode值与量化模式关系如下:
0:非量化
1:pertensor
2:perchannel
3:pertoken
4:pergroup
5:perblock
6:mx量化
7:pertoken动态量化
当前仅支持取1,表示pertensor量化场景。
调用示例
单算子模式调用:
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
def generate_counts(ep_world_size, e, total_tokens, seed=None):
np.random.seed(seed if seed is not None else 42)
per_rank_total = total_tokens
base = per_rank_total // (ep_world_size * e)
remainder = per_rank_total % (ep_world_size * e)
send_counts = [base] * (ep_world_size * e)
for i in range(remainder):
send_counts[-1 - i] += 1
recv_counts = send_counts.copy()
return send_counts, recv_counts
def run_npu_alltoallv_quant_gmm(rank, world_size, master_ip, master_port):
torch_npu.npu.set_device(rank)
init_method = f"tcp://{master_ip}:{master_port}"
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > "2.0.1":
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
BS = 128
K = 2
e = 2
H1, N1 = 256, 256
H2, N2 = 256, 128
total_tokens = BS * K
send_counts, recv_counts = generate_counts(world_size, e, total_tokens, seed=rank)
gmm_x = torch.randint(0, 255, (total_tokens, H1), dtype=torch.uint8).npu()
gmm_weight = torch.randint(0, 255, (e, H1, N1), dtype=torch.uint8).npu()
gmm_x_scale = torch.tensor([0.5], dtype=None).npu()
gmm_weight_scale = torch.tensor([0.3], dtype=None).npu()
mm_x = torch.randint(0, 255, (BS, H2), dtype=torch.uint8).npu()
mm_weight = torch.randint(0, 255, (H2, N2), dtype=torch.uint8).npu()
mm_x_scale = torch.tensor([0.4], dtype=None).npu()
mm_weight_scale = torch.tensor([0.2], dtype=None).npu()
quant_mode = 1
out_dtype = torch.float16
gmm_y, mm_y, permute_out = torch_npu.npu_alltoallv_quant_gmm(
gmm_x=gmm_x,
gmm_weight=gmm_weight,
gmm_x_scale=gmm_x_scale,
gmm_weight_scale=gmm_weight_scale,
hcom=hcom_info,
ep_world_size=world_size,
send_counts=send_counts,
recv_counts=recv_counts,
gmm_y_dtype=out_dtype,
mm_x=mm_x,
mm_weight=mm_weight,
mm_x_scale=mm_x_scale,
mm_weight_scale=mm_weight_scale,
gmm_x_quant_mode=quant_mode,
gmm_weight_quant_mode=quant_mode,
mm_x_quant_mode=quant_mode,
mm_weight_quant_mode=quant_mode,
permute_out_flag=True,
gmm_x_dtype=torch_npu.hifloat8,
gmm_weight_dtype=torch_npu.hifloat8,
gmm_x_scale_dtype=None,
gmm_weight_scale_dtype=None,
mm_x_dtype=torch_npu.hifloat8,
mm_weight_dtype=torch_npu.hifloat8,
mm_x_scale_dtype=None,
mm_weight_scale_dtype=None,
mm_y_dtype=out_dtype,
send_counts_tensor=None,
recv_counts_tensor=None,
group_size=None
)
if __name__ == "__main__":
world_size = 2
master_ip = "127.0.0.1"
master_port = "50001"
mp.spawn(run_npu_alltoallv_quant_gmm, args=(world_size, master_ip, master_port), nprocs=world_size, join=True)
"""
)
_add_torch_npu_docstr(
"npu_swiglu_quant",
"""
torch_npu.npu_swiglu_quant(Tensor x, *, Tensor? smooth_scales=None, Tensor? offsets=None, Tensor? group_index=None, bool activate_left=False, int quant_mode=0, int group_list_type=0, ScalarType? dst_type=None) -> (Tensor, Tensor)
功能描述
在swiglu激活函数后添加quant操作。
参数说明
x (Tensor):必选参数,表示目标张量。数据类型支持float16、bfloat16、float32,支持非连续的Tensor,数据格式为ND,x的维数必须大于1维,尾轴为偶数且长度不超过8192,当dst_type传入值为29(输出为int4量化)时,x的最后一维需要为4的倍数。
smooth_scales (Tensor):可选参数,表示smooth量化系数。数据类型支持float32,支持非连续的Tensor,数据格式为ND。shape支持[G, N],[G, ]。
offsets (Tensor):可选参数,表示量化中的偏移项,该参数在动态量化场景下不生效,传入None即可。静态量化场景下:数据类型支持FLOAT,支持非连续的Tensor,数据格式为ND。per_channel模式下shape支持[G, N],per_tensor模式下shape支持[G, ],且数据类型和shape需要与smooth_scales保持一致。
group_index (Tensor):可选参数,当前支持cumsum和count两种模式,要求为1维张量,数据类型支持int32,数据格式ND,shape支持[G, ],group_index内元素要求为非递减,且最大值不得超过输入x的除最后一维之外的所有维度大小之积。
activate_left (bool):可选参数,Swiglu流程中是否进行左激活,默认False。
quant_mode (int):可选参数,表示量化类型,默认值为0。0表示静态量化,1表示动态量化。
group_list_type (int):可选参数,表示group_index类型,默认值为0。0表示cumsum模式,1表示count模式。
dst_type (ScalarType): 可选参数,指定量化输出的类型, 传None时当做torch.int8处理。
输出说明
out (Tensor):表示量化后的输出tensor。数据类型支持int8和int4,支持非连续的Tensor,数据格式为ND。
scale (Tensor):表示量化的scale参数,计算输出scale的shape与计算输入x相比,无最后一维,其余维度与计算输入x保持一致,数据类型支持float32,数据格式为ND。
支持的型号
A2训练、推理系列产品
A3训练、推理系列产品
调用示例
import os
import shutil
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUSwigluQuant(TestCase):
def test_npu_swiglu_quant(self, device="npu"):
batch_size = 4608
hidden_size = 2048
x_shape = (batch_size, hidden_size)
input_data = np.random.randn(*x_shape).astype(np.float32)
quant_mode = 1
group_list_type = 0
dst_type = torch.int8
activate_left = False
offsets = None
num_groups = 8
group_sizes = batch_size // num_groups
group_index = [(i + 1) * group_sizes for i in range(num_groups)]
smooth_scales = np.random.randn(num_groups, hidden_size // 2).astype(np.float32)
device = "npu"
npu_x = torch.from_numpy(input_data).to(device)
npu_group_index = torch.from_numpy(np.array(group_index)).to(device)
npu_smooth_scales = torch.from_numpy(smooth_scales).to(device)
result = torch_npu.npu_swiglu_quant(
npu_x,
smooth_scales=npu_smooth_scales,
offsets=offsets,
group_index=npu_group_index,
activate_left=False,
quant_mode=quant_mode,
group_list_type=group_list_type,
dst_type=dst_type
)
if __name__ == "__main__":
run_tests()
"""
)
_add_torch_npu_docstr(
"npu_grouped_matmul_swiglu_quant",
"""
torch_npu.npu_grouped_matmul_swiglu_quant(Tensor x, Tensor weight, Tensor group_list, Tensor weight_scale, Tensor x_scale, *, Tensor? bias=None, Tensor? offset=None) -> (Tensor, Tensor, Tensor)
功能描述
aclnnGroupedMatmulV4、aclnnDynamicDequant、aclnnSwigluQuant融合, deepseek模型使用,对比小算子做性能优化。
参数说明
x(Tensor):输入,左矩阵,公式中的X,Device侧的aclTensor。shape支持2维,数据类型支持INT8,数据格式支持ND,支持非连续的Tensor。
weight(Tensor):输入,权重矩阵,公式中的W,Device侧的aclTensor。shape支持5维,数据类型支持INT8,数据格式支持FRACTAL_NZ,支持非连续的Tensor,需注意该接口会将weight的数据格式强制视为FRACTAL_NZ格式。
group_list (Tensor):输入,指示每个分组参与计算的Token个数,公式中的grouplist,Device侧的aclTensor。shape支持1维,长度需与weight的首轴维度相等,数据类型支持INT64,数据格式支持ND,支持非连续的Tensor。
weight_scale (Tensor):输入,右矩阵的量化因子,公式中的w_scale,Device侧的aclTensor。shape支持2维,首轴长度需与weight的首轴维度相等,尾轴长度需要与weight还原为ND格式的尾轴相同,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND,支持非连续的Tensor。
x_scale (Tensor):输入,左矩阵的量化因子,公式中的x_scale,Device侧的aclTensor。shape支持1维,长度需与x的首轴维度相等,数据类型支持FLOAT,数据格式支持ND,支持非连续的Tensor。
bias(可选,暂不支持,Tensor):输入,矩阵乘计算的偏移值,公式中的bias,shape支持2维,数据类型支持INT32,预留输入,暂不支持。
offset(可选,暂不支持,Tensor):输入,per-channel非对称反量化的偏移,公式中的offset,shape支持2维,数据类型支持Float,预留输入,暂不支持。
输出说明
group_list指导了输入和输出中的有效值范围,该数值由前置算子得到,动态变化。应根据group_list,对结果中脏数据做截断处理,即有效数据截至到group_list[-1],即:output[:groupList[-1],:],output_scale[:groupList[-1]]
output(Tensor):输出的量化结果,公式中的Q,Device侧的aclTensor。数据类型支持INT8,shape支持2维,Device侧的aclTensor。数据格式支持ND,支持非连续的Tensor。
output_scale(Tensor):输出的量化因子,公式中的Q_scale,Device侧的aclTensor。数据类型支持FLOAT,shape支持1维,Device侧的aclTensor。数据格式支持ND,支持非连续的Tensor。
output_offset(预留输出,暂不支持,Tensor):输出的非对称量化的偏移,公式中的Q_offset,Device侧的aclTensor,shape支持1维,数据类型支持FLOAT。
支持的型号
A2训练、推理系列产品
A3训练、推理系列产品
调用示例
import torch
import torch_npu
import numpy as np
def generate_non_decreasing_sequence(length, upper_limit):
# 生成随机增量
random_increments = torch.randint(1, 128, (length,), dtype=torch.int64) # 避免零增量
# 累加生成非递减序列
sequence = torch.cumsum(random_increments, dim=0)
# 确保最后一个元素不超过上限
if sequence[-1] > upper_limit:
# 线性缩放以确保总和不超过上限
scale_factor = upper_limit / sequence[-1].item()
sequence = (sequence * scale_factor).to(torch.int64)
for i in range(1, length):
if sequence[i] <= sequence[i-1]:
sequence[i] = sequence[i-1] + 1
return sequence
def gen_input_data(E=16, M=512, K=7168, N=4096):
x = torch.randint(-128, 127, (M, K), dtype=torch.int8).npu()
weight = torch.randint(-128, 127, (E, K, N), dtype=torch.int8).npu()
weight_npu = torch_npu.npu_format_cast(weight.npu(), 29)
weight_scale = torch.randn(E, N, dtype=torch.float32).npu()
x_scale = torch.randn(M, dtype=torch.float32).npu()
group_list = generate_non_decreasing_sequence(E, M).npu()
output, output_scale, output_offset = torch_npu.npu_grouped_matmul_swiglu_quant(
x, weight_npu, group_list, weight_scale, x_scale,
bias=None,
offset=None
)
return output, output_scale, output_offset
def main():
output, output_scale, output_offset = gen_input_data()
if __name__ == "__main__":
main()
"""
)
_add_torch_npu_docstr(
"npu_grouped_matmul_swiglu_quant_v2",
"""
torch_npu.npu_grouped_matmul_swiglu_quant_v2(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale, Tensor group_list, *, Tensor? smooth_scale=None, Tensor[]? weight_assist_matrix=None, Tensor? bias=None, int? dequant_mode=0, int? dequant_dtype=0, int? quant_mode=0, int? quant_dtype=0, int? group_list_type=0, int[]? tuning_config=None) -> (Tensor, Tensor)
功能描述
`npu_grouped_matmul_swiglu_quant_v2`是一种融合分组矩阵乘法(GroupedMatmul)、SwiGLu混合激活函数、量化(quant)的计算方法。该方法适用于需要对矩阵乘法结果进行SwiGlu激活函数激活的场景,融合算子在底层能够对部分过程并行,达到性能优化的效果。支持 A8W8、A8W4、A4W4;A4W4 场景下 smooth_scale 必填(与 aclnnGroupedMatmulSwigluQuantV2 一致)。
参数说明
x(Tensor):必选输入,矩阵乘法的左矩阵。shape支持2维[m,k],数据类型支持`int8`,数据格式支持ND,支持非连续的Tensor。
weight(TensorList):必选输入,权重矩阵(矩阵乘法右矩阵),shape支持3维[e,k,n]、5维(FRACTAL_NZ),数据类型支持`int8`、`int32`,数据格式支持FRACTAL_NZ(通过接口npu_format_cast,可实现格式转换),支持非连续的Tensor。
weight_scale(TensorList):必选输入,右矩阵的量化因子。`weight`数据类型为`int8`时,`weight_scale`的shape支持2维,`weight`数据类型为`int32`时,`weight_scale`的shape支持2维和3维。数据类型支持`float32`,数据格式支持ND,支持非连续的Tensor。
x_scale(Tensor):必选输入,左矩阵的量化因子。shape支持1维[m],数据类型支持`float32`,数据格式支持ND,支持非连续的Tensor。
group_list(Tensor):必选输入,指示每个分组参与计算的Token个数。shape支持1维[e],数据类型支持`int64`,数据格式支持ND,支持非连续的Tensor。
smooth_scale(Tensor):可选输入,平滑缩放因子,数据类型为`float32`。A4W4 场景下必填,形状 (E, N/2) 或 (E,);其他场景传 None。
weight_assist_matrix(TensorList):可选输入,右矩阵的辅助矩阵,数据类型支持`float32`。仅 A8W4 场景使用,其他场景传 None。
bias(Tensor):可选输入,矩阵乘计算的偏移值,公式中的bias,shape支持2维,数据类型支持`int32`,当前仅支持传入默认值None。
dequant_mode(int):可选输入,表示反量化模式。`weight`数据类型为`int8`时仅支持0,`weight`数据类型为`int32`时支持0和1。0:左pertoken,右perchannel;1:左pertoken,右pergroup。
dequant_dtype(int):可选输入,表示反量化类型,当前仅支持传入默认值0。
quant_dtype(int):可选输入,参数表示量化后低比特数据类型。0:`int8`;1:`float8_e8m0`;2:`float8_e5m2`;3:`float8_e4m3`,当前仅支持传入默认值0。
quant_mode(int):可选输入,参数表示SwiGLU后的量化模式。0:pertoken;1:perchannel,当前仅支持传入默认值0。
group_list_type(int):可选输入,参数表示grouplist的输入类型。0:cumsum;1:count,默认0。
tuning_config(List[int]):可选输入,默认设置为None。
输出说明
output(Tensor):输出的量化结果,数据类型支持`int8`,shape支持2维[m,n]。数据格式支持ND,支持非连续的Tensor。
output_scale(Tensor):输出的量化因子,数据类型支持`float`,shape支持1维[m]。数据格式支持ND,支持非连续的Tensor。
支持的型号
Ascend 950PR/950DT、A2训练与推理系列产品、A3训练与推理系列产品
调用示例
import torch
import torch_npu
import numpy as np
def test_grouped_matmul_swiglu_quant_v2(E=16, M=512, K=7168, N=4096):
x = torch.randint(-128, 127, (M, K), dtype=torch.int8).npu()
weight = torch.randint(-128, 127, (E, K, N), dtype=torch.int8).npu()
weight_npu = torch_npu.npu_format_cast(weight, 29)
weight_scale = torch.randn(E, N, dtype=torch.float32).npu()
x_scale = torch.randn(M, dtype=torch.float32).npu()
group_list = torch.tensor([128, 128], dtype=torch.int64).npu()
output, output_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2(
x, [weight_npu], [weight_scale], x_scale, group_list, bias=None
)
return output, output_scale
def main():
output, output_scale = test_grouped_matmul_swiglu_quant_v2()
if __name__ == "__main__":
main()
"""
)
_add_torch_npu_docstr(
"npu_gmm_alltoallv",
"""
接口原型:
npu_gmm_alltoallv(Tensor gmm_x, Tensor gmm_weight, str hcom, int ep_world_size, int[] send_counts, int[] recv_counts, *, Tensor? send_counts_tensor=None, Tensor? recv_counts_tensor=None, Tensor? mm_x=None, Tensor? mm_weight=None, bool trans_gmm_weight=False, bool trans_mm_weight=False) -> (Tensor, Tensor)
功能描述
grouped matmul和alltoallv的融合算子,对grouped matmul计算后的结果进行alltoallv通信的输出做操作,通信时间和计算时间进行掩盖。
参数说明
gmm_x: device侧Tensor,表示输入,数据类型支持float16,bfloat16。该输入进行AllToAllv通信,仅支持二维, 数据格式支持ND,通信后结果作为GrouedMatMul计算的左矩阵
gmm_weight:device侧Tensor,表示输入,数据类型支持float16, bfloat16,类型需与gmmX保持一致,仅支持三维, 数据格式支持ND,GrouedMatMul计算的右矩阵
hcom:char*类型,计算输入,专家并行的通信域名。字符串长度需大于0,小于128。
ep_world_size:int类型,计算输入,ep通信域size,支持8/16/32/64。
send_counts:int[],计算输入,支持int数据类型,通信发送的数据量。
recv_counts:int[],计算输入,支持int数据类型,通信接收的数据量。
send_counts_tensor:device侧Tensor,表示输入,暂不支持。
recv_counts_tensor:device侧Tensor,表示输入,暂不支持。
mm_x:device侧Tensor,表示输入,数据类型支持float16,bfloat16,共享专家的左矩阵。
mm_weight:device侧Tensor,表示输入,数据类型支持float16,bfloat16,共享专家的右矩阵。
trans_gmm_weight:为True:表明gmm的右矩阵要转置,为False时表明gmm右矩阵不转置,默认为false。
trans_mm_weight:为True:表明mm的右矩阵要转置,为False时表明mm右矩阵不转置,默认为false。
输出说明
y:device侧Tensor, 计算输出,数据类型支持float16, bfloat16。最终计算结果,数据类型与输入gmm_X保持一致
mm_y:device侧Tensor, 数据类型支持float16, bfloat16,共享专家matmul的输出,仅当传入mm_x与mm_weight才输出,数据类型与mm_x保持一致。
支持的型号
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_npu_gmm_alltoallv(rank, world_size, master_ip, master_port, gmm_x, gmm_w, send_counts, recv_counts, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
input = torch.randn(gmm_x, dtype=dtype).npu()
weight = torch.randn(gmm_w, dtype=dtype).npu()
y, _= torch_npu.npu_gmm_alltoallv(gmm_x=input,
gmm_weight=weight,
send_counts_tensor=None,
recv_counts_tensor=None,
mm_x=None,
mm_weight=None,
group=hcom_info,
ep_world_size=world_size,
send_counts=send_counts,
recv_counts=recv_counts,
trans_gmm_weight=False,
trans_mm_weight=False)
def generate_matrix(self, e, ep_world_size, bsk, name="alltoallv_gmm", max_iter=10000):
import hashlib
hash_bytes = hashlib.sha256(name.encode()).digest()
seed = int.from_bytes(hash_bytes[:4], byteorder='big')
np.random.seed(seed)
row_size = ep_world_size
col_size = e * ep_world_size
matrix = []
avg = bsk // col_size
tail_num = bsk % col_size
matrix = np.full((row_size, col_size), avg)
matrix[:, -1] += tail_num
return matrix
if __name__ == "__main__":
worksize = 8
e = 4
master_ip = '127.0.0.1'
master_port = '50001'
BS = 128
K = 8
x1_shape = [BS*K, 2048]
x2_shape = [2048, 2048]
send_counts = self.generate_matrix(e, worksize, BS*K)
recv_counts = np.hstack(np.split(mc2_send_counts.reshape(-1, e), epWorldSize, axis=0))
dtype = torch.float16
mp.spawn(run_npu_gmm_alltoallv, args=(worksize, master_ip, master_port, gmm_x, gmm_weight, send_counts, recv_counts, dtype), nprocs=worksize)
"""
)
_add_torch_npu_docstr(
"npu_quant_gmm_alltoallv",
"""
接口原型:
npu_quant_gmm_alltoallv(Tensor gmm_x, Tensor gmm_weight, Tensor gmm_x_scale, Tensor gmm_weight_scale, str hcom, int ep_world_size, int[] send_counts, int[] recv_counts, int gmm_y_dtype, *, Tensor? send_counts_tensor=None, Tensor? recv_counts_tensor=None, Tensor? mm_x=None, Tensor? mm_weight=None, Tensor? mm_x_scale=None, Tensor? mm_weight_scale=None, Tensor? gmm_x_offset=None, Tensor? gmm_weight_offset=None, Tensor? mm_x_offset=None, Tensor? mm_weight_offset=None, Tensor? comm_quant_scale=None, int? gmm_x_quant_mode=None, int? gmm_weight_quant_mode=None, int? mm_x_quant_mode=None, int? mm_weight_quant_mode=None, int? comm_quant_mode=None, int[]? group_size=None, int? gmm_x_dtype=None, int? gmm_weight_dtype=None, int? gmm_x_scale_dtype=None, int? gmm_weight_scale_dtype=None, int? mm_x_dtype=None, int? mm_weight_dtype=None, int? mm_x_scale_dtype=None, int? mm_weight_scale_dtype=None, int? comm_quant_dtype=None, int? mm_y_dtype=None) -> (Tensor, Tensor)
功能描述
grouped matmul和alltoallv的融合算子,对grouped matmul计算后的结果进行alltoallv通信的输出做操作,通信时间和计算时间进行掩盖, 添加量化操作。
参数说明
gmm_x(Tensor):必选输入,表示GroupedMatMul计算的做矩阵Tensor。数据类型为hifloat8,支持ND格式,不支持非连续Tensor,维度为2维,Shape为(BSK,H1)。
gmm_weight(Tensor):必选输入,GroupedMatmul的右矩阵,数据类型需与gmm_x一致(hifloat8),支持ND格式,不支持非连续Tensor,维度为3维,Shape为(e, H1, N1)。
gmm_x_scale(Tensor):必选输入,表示左矩阵的量化缩放系数,数据类型为float32,支持ND格式,不支持非连续Tensor,维度为1维,Shape通常为(1,)或(e,)。
gmm_weight_scale(Tensor):必选输入,表示右矩阵的量化缩放系数,数据类型为float32,支持ND格式,不支持非连续Tensor,维度为1维,Shape通常为(1,)或(e,)。
hcom(str):必选输入,表示专家并行的通信域名称,字符串长度需在(0,128)范围内。
ep_world_size(int):必选输入,表示专家并行通信域的size,支持2、4、6、16、32、64、128、256等数值。
send_counts(List[int]):必选输入,表示发送给其他卡的token数列表,数据类型支持int64,数组大小为e*ep_world_size,当前暂不支持且不支持非连续Tensor。
recv_counts(List[int]):必选输入,表示从其他卡接收的token数列表,数据类型支持int64,数组大小为e*ep_world_size,当前暂不支持且不支持非连续Tensor。
gmm_y_dtype(int):必选输入,表示GroupedMatmul输出矩阵y的目标数据类型,支持float16或bfloat16。
send_counts_tensor(Tensor):可选输入,默认为None,表示设备侧接口计数Tensor,维度为1维,shape为(e*ep_world_size),当前暂不支持且不支持非连续Tensor。
recv_counts_tensor(Tensor):可选输入,默认为None,表示设备侧接口计数Tensor,维度为1维,shape为(e*ep_world_size),当前暂不支持且不支持非连续Tensor。
mm_x(Tensor):可选输入,默认为None,表示共享专家MatMul计算中的左矩阵,数据类型为hifloat8,支持ND格式,不支持非连续Tensor,维度为2维,Shape维(BS, H2)。
mm_weight(Tensor):可选输入,默认为None,表示共享专家MatMul计算中的右矩阵,数据类型为hifloat8,支持ND格式,不支持非连续Tensor,维度为2维,Shape维(H2, N2)。
mm_x_scale(Tensor):可选输入,默认值None,共享专家matmul计算中左矩阵的量化参数,数据类型为float32,不支持非连续Tensor。
mm_weight_scale(Tensor):可选输入,默认值None,共享专家matmul计算中右矩阵的量化参数,数据类型为float32,不支持非连续Tensor。
gmm_x_offset(Tensor):可选输入,默认值None,表示左矩阵的量化偏置,数据类型为float32,不支持非连续Tensor,目前暂不支持此参数。
gmm_weight_offset(Tensor):可选输入,默认值None,表示右矩阵的量化偏置,数据类型为float32,不支持非连续Tensor,目前暂不支持此参数。
mm_x_offset(Tensor):可选输入,默认值None,表示共享专家左矩阵的量化偏置,数据类型为float32,不支持非连续Tensor,目前暂不支持此参数。
mm_weight_offset(Tensor):可选输入,默认值None,表示共享专家左矩阵的量化偏置,数据类型为float32,不支持非连续Tensor,目前暂不支持此参数。
comm_quant_scale(Tensor):可选输入,默认值None,低比特通信的量化参数,数据类型为float32,不支持非连续Tensor,当为None时表示不进行低比特通信。
gmm_x_quant_mode(int):可选输入,默认值为None,表示左矩阵量化模式,当前仅支持配置为1(pertensor量化)。
gmm_weight_quant_mode(int):可选输入,默认值为None,表示右矩阵量化模式,当前仅支持配置为1(pertensor量化)。
mm_x_quant_mode(int):可选输入,默认值为None,表示共享专家左矩阵量化模式,当前仅支持配置为1(pertensor量化)。
mm_weight_quant_mode(int):可选输入,默认值为None,表示共享专家右矩阵量化模式,当前仅支持配置为1(pertensor量化)。
comm_quant_mode(int):可选输入,默认值为None,低比特通信量化模式,当前仅支持0,表示不进行低比特通信。
group_size(int):可选输入,默认值为None,表示通信域的规模属性列表,用于Matmul计算三个方向上的量化分组大小,预留参数,当前不支持。
gmm_x_dtype(int):可选输入,默认值为None,表示左矩阵在计算内部的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
gmm_weight_dtype(int):可选输入,默认值为None,表示右矩阵在计算内部的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
gmm_x_scale_dtype(int):可选输入,默认值为None,表示左矩阵量化缩放系数精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
gmm_weight_scale_dtype(int):可选输入,默认值为None,表示右矩阵量化缩放系数的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
mm_x_dtype(int):可选输入,默认值为None,表示共享专家左矩阵在计算内部的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
mm_weight_dtype(int):可选输入,默认值为None,表示共享专家右矩阵在计算内部的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
mm_x_scale_dtype(int):可选输入,默认值为None,表示共享专家左矩阵量化缩放系数的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
mm_weight_scale_dtype(int):可选输入,默认值为None,表示共享专家右矩阵量化缩放系数的精度类型,默认为输入Tensor的原始类型,用于适配pytorch原生不支持的数据类型(hifloat8)。
comm_quant_dtype(int):可选输入,低比特通信量化后的数据类型,当前不支持。
mm_y_dtype(int):可选输入,默认值为None,表示共享专家输出矩阵的目标数据类型。
输出说明
y(Tensor):主计算输出。GroupedMatmul的最终计算结果。支持2维Tensor,Shape为(BSK,N1),数据类型由输入参数gmm_y_dtype决定。
mm_y(Tensor):共享专家计算输出,当且仅当输入参数中提供了mm_x与mm_weight时,该输出才包括有效数据。数据类型由mm_y_dtype决定(未指定时与mm_x保持一致),支持2维Tensor,形状为(BS,N2),未启用共享专家返回空的Tensor。
约束说明
* 该接口支持训练、推理场景下使用
* 该接口仅支持单算子模式调用。
* 参数说明里shape使用的变量:
- BSK:本卡接收的token数,是recvCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
- N1:表示路由专家的head_num,取值范围(0, 65536)。
- N2:表示共享专家的head_num,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡发送的token数,是sendCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
支持的型号
Atlas A5训练系列产品
调用示例:
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_npu_quant_gmm_alltoallv(rank, world_size, master_ip, master_port):
torch_npu.npu.set_device(rank)
init_method = f"tcp://{master_ip}:{master_port}"
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
BS, K = 128, 2
H1, N1 = 256, 256
H2, N2 = 256, 128
e = 2
ep_world_size = world_size
total_tokens = BS * K
out_dtype = torch.float16
gmm_x = torch.randint(0, 30, (total_tokens, H1), dtype=torch.uint8).npu()
gmm_weight = torch.randint(0, 30, (e, H1, N1), dtype=torch.uint8).npu()
gmm_x_scale = torch.tensor([0.5], dtype=torch.float32).npu()
gmm_weight_scale = torch.tensor([0.3], dtype=torch.float32).npu()
mm_x = torch.randint(0, 30, (BS, H2), dtype=torch.uint8).npu()
mm_weight = torch.randint(0, 30, (H2, N2), dtype=torch.uint8).npu()
mm_x_scale = torch.tensor([0.4], dtype=torch.float32).npu()
mm_weight_scale = torch.tensor([0.2], dtype=torch.float32).npu()
send_counts = [total_tokens // (e * ep_world_size)] * (e * ep_world_size)
recv_counts = [total_tokens // (e * ep_world_size)] * (e * ep_world_size)
y, mm_y = torch_npu.npu_quant_gmm_alltoallv(
gmm_x=gmm_x,
gmm_weight=gmm_weight,
gmm_x_scale=gmm_x_scale,
gmm_weight_scale=gmm_weight_scale,
hcom=hcom_info,
ep_world_size=ep_world_size,
send_counts=send_counts,
recv_counts=recv_counts,
gmm_y_dtype=out_dtype,
mm_x=mm_x,
mm_weight=mm_weight,
mm_x_scale=mm_x_scale,
mm_weight_scale=mm_weight_scale,
gmm_x_quant_mode=1,
gmm_weight_quant_mode=1,
mm_x_quant_mode=1,
mm_weight_quant_mode=1,
comm_quant_mode=0,
gmm_x_dtype=torch_npu.hifloat8,
gmm_weight_dtype=torch_npu.hifloat8,
gmm_x_scale_dtype=torch.float32,
gmm_weight_scale_dtype=torch.float32,
mm_x_dtype=torch_npu.hifloat8,
mm_weight_dtype=torch_npu.hifloat8,
mm_x_scale_dtype=torch.float32,
mm_weight_scale_dtype=torch.float32,
mm_y_dtype=out_dtype,
)
if __name__ == "__main__":
world_size = 2
master_ip = 'your-master-address'
master_port = 'your-master-port'
mp.spawn(run_npu_quant_gmm_alltoallv, args=(world_size, master_ip, master_port), nprocs=world_size, join=True)
"""
)
_add_torch_npu_docstr(
"npu_nsa_compress",
"""
torch_npu.npu_nsa_compress(input, weight, compress_block_size, compress_stride, actual_seq_len=None)
功能描述
实现Native Sparse Attention算法中训练场景下的压缩功能。
参数说明
input(Tensor):必选参数,待压缩张量,shape支持[T,N,D],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
weight(Tensor):必选参数,压缩的权重,shape支持[compress_block_size, N],weight和input的shape满足broadcast关系,数据类型支持bfloat16、float16,数据类型与input保持一致,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
compress_block_size(int):必选参数,压缩滑窗的大小。
compress_stride(int):必选参数,两次压缩滑窗间隔大小。
actual_seq_len(list[int]):必选参数,长度表示query有多少个batch,值表示各batch的token长度的前缀和,例如,actual_seq_len[0]=s0,actual_seq_len[1]=s0+s1,...,actual_seq_len[-1]=T。
输出说明
代表压缩后的结果。
约束说明
input.shape[1] = weight.shape[1] = head_num
compress_block_size、compress_stride 必须是16的整数倍,且compress_block_size>=compress_stride
input.shape[0] = act_seq_len[-1]
input.shape[2] = head_dim必须是16的整数倍
目前仅支持head_num<=128,compress_block_size <= 128, head_dim <= 256
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> import numpy as np
>>> actual_seq_len = np.random.randint(0, 100, [48])
>>> actual_seq_len = np.cumsum(actual_seq_len).astype(np.int64)
>>> head_num = 4
>>> head_dim = 128
>>> compress_block_size = 16
>>> compress_stride = 16
>>> input = torch.randn(actual_seq_len[-1], head_num, head_dim, dtype=torch.float16).npu()
>>> weight = torch.randn(compress_block_size, head_num, dtype=torch.float16).npu()
>>> torch_npu.npu_nsa_compress(input, weight, compress_block_size, compress_stride, actual_seq_len=actual_seq_len)
"""
)
_add_torch_npu_docstr(
"npu_nsa_compress_infer",
"""
torch_npu.npu_nsa_compress_infer(input, weight, slot_mapping, compress_block_size, compress_stride, page_block_size, block_table=None, actual_seq_len=None, cache)
功能描述
Native Sparse Attention算法中推理场景下,实现对KV压缩的计算。
参数说明
input(Tensor):必选输入,待压缩张量,shape支持[block_num,page_block_size,head_num,head_dim],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
weight(Tensor):必选输入,压缩的权重,shape支持[compress_block_size, head_num],weight和input的shape满足broadcast关系,数据类型支持bfloat16、float16,数据类型与input保持一致,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
slot_mapping(Tensor):必选输入,表示每个batch尾部压缩数据存储的位置的索引,shape支持[batch_num],数据类型支持int32,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor。
compress_block_size(int):必选输入,压缩滑窗的大小。
compress_stride(int):必选输入,两次压缩滑窗间隔大小。
page_block_size(int):必选输入,page_attention场景下page的block_size大小。
block_table(Tensor):可选输入,page_attention场景下kv缓存使用的block映射表,不支持非连续的Tensor。
actual_seq_len(list[int]):必选输入,表示每个batch对应的token的长度。
cache(Tensor):必选输入,推理场景下的kv缓存,支持非连续的Tensor,不支持空Tensor。
输出说明
代表对KV压缩计算后的结果。
约束说明
input和weight满足broadcast关系,input的第三维大小与weight的第二维大小相等。
compress_block_size、compress_stride 必须是16的整数倍,且compress_block_size>=compress_stride,compress_block_size <= 64。
actual_seq_len目前仅支持取值1。
page_block_size只能是64或者128。
headDim是16的整数倍,且headDim <= 256。
需保证slotMapping的值无重复,否则会导致计算结果不稳定。
blockTable的值不应超过blockNum,否则会发生越界。
actual_seq_len的值不应该超过最大序列长度。
headNum <= 64,且headNum>50时headNum%2=0。
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> input = torch.randn(1, 128, 1, 192, dtype=torch.float16).npu()
>>> weight = torch.randn(32, 1, dtype=torch.float16).npu()
>>> slot_mapping = torch.randn([1]).int().npu()
>>> compress_block_size = 32
>>> compress_stride = 16
>>> page_block_size = 128
>>> act_seq_lens = [43]
>>> block_table = torch.randn([1, 1]).int().npu()
>>> cache = torch.zeros([1, 1, 192],dtype=torch.float16).npu()
>>> torch_npu.npu_nsa_compress_infer(input, weight,slot_mapping,compress_block_size,compress_stride,page_block_size,actual_seq_len=act_seq_lens,block_table=block_table,cache=cache)
"""
)
_add_torch_npu_docstr(
"npu_nsa_compress_attention",
"""
torch_npu.npu_nsa_compress_attention(query, key, value, scale_value, head_num, compress_block_size, compress_stride, select_block_size, select_block_count, topk_mask=None, atten_mask=None, actual_seq_qlen=None, actual_cmp_seq_kvlen=None, actual_sel_seq_kvlen=None)
功能描述
实现Native Sparse Attention算法中训练场景下的压缩注意力功能。
参数说明
query(Tensor):必选参数,shape支持[T,N,D],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
key(Tensor):必选参数,shape支持[T,N2,D],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
value(Tensor):必选参数,shape支持[T,N2,D2],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
scale_value(double):必选参数,表示缩放系数,一般设置为D^-0.5。
head_num(int):必选参数,表示query的head个数。
compress_block_size(int):必选参数,压缩滑窗的大小。
compress_stride(int):必选参数,两次压缩滑窗间隔大小。
select_block_size(int):必选参数,表示select窗口的大小。
select_block_count(int):必选参数,表示select窗口的数量。
topk_mask(Tensor):可选参数,shape支持[S,S],SS分别是max_sq和max_skv,数据类型支持bool。
atten_mask(Tensor):可选参数,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,数据类型支持bool,数据格式支持ND,输入shape类型支持[S,S]格式,SS分别是maxSq和maxSkv。
actual_seq_qlen(list[int]):必选参数,长度表示query有多少个batch,值表示各batch的token长度的前缀和,例如,actual_seq_qlen[0]=s0,actual_seq_qlen[1]=s0+s1,...,actual_seq_qlen[-1]=T。
actual_cmp_seq_kvlen(list[int]):必选参数,长度表示compress attention的key或value有多少个batch,值表示各batch的token长度的前缀和,例如,actual_cmp_seq_kvlen[0]=cmp_skv[0],actual_cmp_seq_kvlen[1]=cmp_skv[0]+cmp_skv[1],...,actual_cmp_seq_kvlen[-1]=T。
actual_sel_seq_kvlen(list[int]):必选参数,长度表示select attention的key/value有多少个batch,值表示各batch的token长度的前缀和,例如,actual_sel_seq_kvlen[0]=sel_skv[0],actual_sel_seq_kvlen[1]=sel_skv[0]+sel_skv[1],...,actual_sel_seq_kvlen[-1]=T。
输出说明
Tensor:代表压缩注意力attention的结果。
Tensor:代表选择出的topk。
Tensor:代表softmax计算的max中间结果,用于反向计算。
Tensor:代表softmax计算的sum中间结果,用于反向计算。
约束说明
compress_block_size、compress_stride、select_block_size必须是16的整数倍;且compress_block_size >= compress_stride,select_block_size >= compress_block_size,select_block_size % compress_stride == 0;selectBlockCount <= selKvLen。
目前仅支持compress_block_size=32, compress_stride=16, select_block_size=64, select_block_count=16。
cmp_skv[i] <= 14000。
sel_skv[i] = CeilDiv(cmp_skv[i], select_block_size // compress_stride)。
query、key、value的数据类型必须一致。
query、key、value的B:batchsize必须相等。
query、key、value的D:Head-Dim必须满足(qD == kD && kD >= vD)。
query、key、value的input_layout属性必须一致。
query、key、value的N:qN >= kN && kN == vN,qN与kN必须成比例关系,即qN / kN必须是非0整数。
G=qN / kN, G必须满足:G<128 && 128 % G == 0。
SparseMode:当前仅支持1;attenMask可传入[masS1, maxCmpS2]的下三角或none,topkMask可传入[maxS1, maxSelS2]的对角线或none(attenMask和topkMask数据填充也必须符合约束)。
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> query = torch.randn(65536, 64, 192, dtype=torch.bfloat16).npu()
>>> key = torch.randn(4096, 4, 192, dtype=torch.bfloat16).npu()
>>> value = torch.randn(4096, 4, 128, dtype=torch.bfloat16).npu()
>>> scale_value = 1 / (192**0.5)
>>> head_num = 64
>>> compress_block_size = 32
>>> compress_stride = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> actual_seq_qlen = [65536]
>>> actual_cmp_seq_kvlen = [4096]
>>> actual_sel_seq_kvlen = [1024]
>>> torch_npu.npu_nsa_compress_attention(query, key, value, scale_value, head_num, compress_block_size, compress_stride, select_block_size, select_block_count, actual_seq_qlen=actual_seq_qlen, actual_cmp_seq_kvlen=actual_cmp_seq_kvlen, actual_sel_seq_kvlen=actual_sel_seq_kvlen)
"""
)
_add_torch_npu_docstr(
"npu_nsa_compress_attention_infer",
"""
torch_npu.npu_nsa_compress_attention_infer(query, key, value, scale_value, head_num, key_value_head_num, select_block_size, select_block_count, page_block_size, compress_block_size, compress_stride, layout='TND', atten_mask=None, block_table=None, topk_mask=None, actual_seq_qlen=None, actual_cmp_seq_kvlen=None, actual_sel_seq_kvlen=None)
功能描述
Native Sparse Attention算法中推理场景下,实现对KV压缩的计算。
参数说明
query(Tensor):必选输入,layout为TND时,shape支持3维输入,为[batch, key_value_head_num * group_size, head_size_qk],layout为BSND时,shape支持4维输入,为[batch, query_seq_len, key_value_head_num * group_size, head_size_qk],数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor,不支持inf,nan。
key(Tensor):必选输入,shape支持3维输入,为[block_num, page_block_size, head_size_qk * key_value_head_num],数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor,不支持inf,nan。
value(Tensor):必选输入,shape支持3维输入,为[block_num, page_block_size, head_size_v * key_value_head_num],数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor,不支持inf,nan。
scale_value(double):必选输入,表示缩放系数。
head_num(int):必选输入,表示query的head个数。
key_value_head_num(int):必选输入,表示key或者value的head个数。
select_block_size(int):必选输入,表示选择窗口的大小。
select_block_count(int):必选输入,表示选择窗口的数量。
page_block_size**(int):必选输入,page_attention场景下page的block_size大小。
compress_block_size**(int):必选输入,压缩滑窗的大小。
compress_stride**(int):必选输入,两次压缩滑窗间隔大小。
layout(str):可选输入,表示输入的数据排布格式,支持TND、BSND,默认为TND。
atten_mask(Tensor):可选输入,当前不支持。
block_table**(Tensor):可选输入,shape支持2维输入,数据类型支持‘int32’,page_attention场景下kv缓存使用的block映射表,不支持非连续的Tensor,不支持空tensor,不支持inf,nan。
topk_mask**(Tensor):可选输入,当前不支持。
actual_seq_qlen(list[int]):可选输入,当前不支持。
actual_cmp_seq_kvlen(list[int]):必选输入,表示压缩注意力的key/value的每个S的长度。
actual_sel_seq_kvlen(list[int]):可选输入,当前不支持。
输出说明
代表对KV压缩计算后的结果。
约束说明
- query的数据排布格式中,T代表B(Batch)与S(Seq-Length)合轴后的结果、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
- key和value的数据排布格式当前(paged attention)支持(block_num, block_size, H),H(Head-Size)表示隐藏层的大小,H=N∗D。
- 参数query中的N和head_num值相等,key、value的N和key_value_head_num值相等,并且head_num是key_value_head_num的倍数关系。
- 参数query中的D和key的D(H/key_value_head_num)值相等。
- 参数query中的B、block_table的B、actual_cmp_seq_kvlen的shape值相等,B取值范围1-20。
- 参数key中的block_num和参数value中的block_num值相等。
- 参数key中的block_size、参数value中的block_size和page_block_size值相等。
- query,key,value输入,功能使用限制如下:
- 支持query的N轴必须是key/value的N轴(H/D)的整数倍。
- 支持query的N轴与key/value的N轴(H/D)的比值小于等于128,且128是group的整数倍。
- 支持query与Key的D轴小于等于192,scale_value取值D^-0.5。
- 支持value的D轴小于等于128。
- 支持query与Key的D轴大于等于value的D轴。
- 支持key与value的block_size小于等于128,且是16的整数倍。
- 仅支持query的S轴等于1。
- 仅支持paged attention。
- 仅支持key/value的S轴小于等于8192。
- 仅支持compress_block_size取值16、32、48、64、80、96、112、118。
- 仅支持compress_stride取值16、32、48、64。
- 仅支持select_block_size取值16、32、48、64、80、96、112、118。
- 仅支持compress_block_size大于等于compress_stride , select_block_size大于等于compress_block_size , select_block_size是compress_stride的整数倍。
- 压缩前的kv_seq_len的上限可以表示为:no_cmp_kv_seq_len_ceil = (cmp_kv_seq_len − 1) ∗ compress_block_stride + compress_block_size,需要满足no_cmp_kv_seq_len_ceil / select_block_size <= 4096,且需要满足select_block_count <= no_cmp_kv_seq_len_ceil / select_block_size。
- block_size第2维的取值需满足公式(max(cmp_kv_seq_len) + page_block_size - 1) // page_block_size。
- block_num的取值需满足公式B * (max(cmp_kv_seq_len) + page_block_size - 1) // page_block_size。
- block_table的取值范围需满足[0, block_num]。
- query,key,value的数据类型需保持一致。
- actual_cmp_seq_kvlen的取值范围为[128, 4096]。
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> query = torch.randn([1, 32, 65], dtype=torch.float16).npu()
>>> key = torch.randn([25, 48, 65], dtype=torch.float16).npu()
>>> value = torch.randn([25, 48, 18], dtype=torch.float16).npu()
>>> scale_value = 0.01
>>> head_num = 32
>>> key_value_head_num = 1
>>> select_block_size = 32
>>> select_block_count = 397
>>> page_block_size = 48
>>> compress_block_size = 32
>>> compress_stride = 16
>>> block_table = torch.tensor([[23, 2, 20, 22, 4, 21, 7, 12, 3, 20, 20, 0, 15, 0, 4, 8, 10, 20, 21, 18, 18, 18, 11, 12, 20]]).int().npu()
>>> actual_cmp_seq_kvlen = [1180]
>>> torch_npu.npu_nsa_compress_attention_infer(query, key, value, scale_value, head_num, key_value_head_num, select_block_size, select_block_count, page_block_size, compress_block_size, compress_stride, block_table=block_table, actual_cmp_seq_kvlen=actual_cmp_seq_kvlen)
"""
)
_add_torch_npu_docstr(
"npu_nsa_select_attention",
"""
torch_npu.npu_nsa_select_attention(query, key, value, topk_indices, scale_value, head_num, select_block_size, select_block_count, atten_mask=None, actual_seq_qlen=None, actual_seq_kvlen=None)
功能描述
实现Native Sparse Attention算法中训练场景下选择注意力的计算。
参数说明
query(Tensor):必选参数,shape支持[T1,N1,D1],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
key(Tensor):必选参数,shape支持[T2,N2,D1],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
value(Tensor):必选参数,shape支持[T2,N2,D2],数据类型支持bfloat16、float16,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
topk_indices(Tensor):必选参数,shape为[T1, N2, select_block_count],数据类型支持int32,数据格式支持ND,支持非连续的Tensor,不支持空Tensor。
scale_value(double):必选参数,表示缩放系数,一般设置为D^-0.5。
head_num(int):必选参数,表示单卡的head个数,即query的N1轴长度。
select_block_size(int):必选参数,表示select窗口的大小。
select_block_count(int):必选参数,表示select窗口的数量。
atten_mask(Tensor):可选参数,当前暂不支持。
actual_seq_qlen(list[int]):必选参数,长度表示query有多少个batch,值表示各batch的token长度的前缀和,例如,actual_seq_qlen[0]=s0,actual_seq_qlen[1]=s0+s1,...,actual_seq_qlen[-1]=T1。
actual_seq_kvlen(list[int]):必选参数,,长度表示key或value有多少个batch,值表示各batch的token长度的前缀和,例如,actual_seq_kvlen[0]=s0,actual_seq_kvlen[1]=s0+s1,...,actual_seq_kvlen[-1]=T2。
输出说明
Tensor:代表经过选择后的注意力attention结果。
Tensor:代表softmax计算的max中间结果,用于反向计算。
Tensor:代表softmax计算的sum中间结果,用于反向计算。
约束说明
1. 输入query、key、value的batchsize必须相等,即要求传入的actual_seq_qlen和actual_seq_kvlen具有相同的长度。
2. 输入query、key、value的D(head_dim)必须满足D_q == D_k,D_k >= D_v。
3. 输入query、key、value的数据类型必须一致。
4. 输入query、key、value的input_layout必须一致,且只支持TND。
5. select_block_size目前仅支持64,与此对应的select_block_count为16。
6. topk_indices必须大于等于0且小于等于B对应的S2 / 64。
7. 支持输入query的N和key/value的N不相等,但必须成比例关系,即N_q / N_kv必须是非0整数,称为G(group),且需满足G <= 32
- B(batchsize):取值范围为1\~65536。
- N(head_num):取值范围为1\~128。
- G(group):取值范围为1\~32。
- S(seq_length):取值范围为1\~128K。且对于KV的S >= select_block_size * select_block_count,且为select_block_size的倍数。
- D(head_dim):D_qk=192,D_v=128。
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> import numpy as np
>>> query = torch.randn(256, 16, 192, dtype=torch.float16).npu()
>>> key = torch.randn(3072, 4, 192, dtype=torch.float16).npu()
>>> value = torch.randn(3072, 4, 128, dtype=torch.float16).npu()
>>> topk_indices = torch.randn(256, 4, 16).int().npu()
>>> scale_value = 1.0
>>> head_num = 16
>>> select_block_size = 64
>>> select_block_count = 16
>>> atten_mask = torch.randn(512, 2048).bool().npu()
>>> actual_seq_qlen = [128, 256]
>>> actual_seq_kvlen = [2048, 3072]
>>> torch_npu.npu_nsa_select_attention(query, key, value, topk_indices, scale_value, head_num, select_block_size, select_block_count, atten_mask=atten_mask, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen)
"""
)
_add_torch_npu_docstr(
"npu_nsa_select_attention_infer",
"""
torch_npu.npu_nsa_select_attention_infer(query, key, value, topk_indices, scale_value, head_num, key_value_head_num, select_block_size, select_block_count, page_block_size, layout='BSND', atten_mask=None, block_table=None, actual_seq_qlen=None, actual_seq_kvlen=None)
功能描述
Native Sparse Attention算法中推理场景下,实现选择注意力的计算。
参数说明
query (Tensor):必选输入,shape支持3维或者4维,数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor。
key (Tensor):必选输入,shape支持3维或者4维,数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor。
value (Tensor):必选输入,shape支持3维或者4维,数据类型支持bfloat16、float16,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor。
topk_indices (Tensor):必选输入,shape为[batch_size, key_value_head_num, select_block_count],数据类型支持int32,数据格式支持ND,不支持非连续的Tensor,不支持空Tensor。
scale_value (double):必选输入,表示缩放系数。
head_num (int):必选输入,表示query的head个数。
key_value_head_num (int):必选输入,表示key或者value的head个数。
select_block_size (int):必选输入,表示选择窗口的大小。
select_block_count (int):必选输入,表示选择窗口的数量。
page_block_size(int):必选输入,page_attention场景下page的block_size大小。
atten_mask (Tensor):可选输入,当前暂不支持。
block_table(Tensor):可选输入,page_attention场景下kv缓存使用的block映射表,数据类型支持int32,不支持非连续的Tensor,不支持空tensor。
layout(str):可选输入,表示输入的数据排布格式,支持BSH、BSND、TND,默认为BSND。
actual_seq_qlen(list[int]):可选输入,当前暂不支持。
actual_seq_kvlen(list[int]):必选输入,表示key或value每个S的长度。
输出说明
代表经过选择后的注意力结果。
约束说明
query的数据排布格式中,B即Batch,S即Seq-Length,N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。key和value的数据排布格式当前(paged attention场景)支持(block_num, block_size, H)或(block_num, block_size, N, D),H(Head-Size)表示隐藏层的大小,H = N * D。
参数query中的N和head_num值相等,key、value的N和key_value_head_num值相等,并且head_num是key_value_head_num的倍数关系。
参数query中的D和key的D(H/key_value_head_num)值相等。
query,key,value输入,功能使用限制如下:
支持B轴小于等于3072;
支持key/value的N轴(H/D)小于等于256;
支持query的N轴与key/value的N轴(H/D)的比值小于等于16;
支持query与key的D轴等于192;
支持value的D轴等于128;
支持query与key的block_size小于等于64或128;
仅支持query的S轴等于1。
仅支持paged attention。
仅支持select_block_size取值为16的整数倍。
selectBlockCount上限满足select_block_count * select_block_size <= MaxKvSeqlen,MaxKvSeqlen = Max(actual_seq_kvlen)。
支持的型号
Atlas A2训练系列产品
调用示例
>>> import torch
>>> import torch_npu
>>> query = torch.randn([1, 1, 768], dtype=torch.float16).npu()
>>> key = torch.randn([246, 64, 384], dtype=torch.float16).npu()
>>> value = torch.randn([246, 64, 256], dtype=torch.float16).npu()
>>> topk_indices = torch.tensor([[[0, -1], [0, -1]]], device="npu", dtype=torch.int32)
>>> block_table = torch.tensor([[1, 0]], device="npu", dtype=torch.int32)
>>> scale_value = 2.0
>>> head_num = 4
>>> key_value_head_num = 2
>>> select_block_size = 64
>>> select_block_count = 2
>>> page_block_size = 64
>>> layout = 'BSH'
>>> actual_seq_qlen = None
>>> actual_seq_kvlen = [82] * query.size(0)
>>> atten_mask = None
>>> torch_npu.npu_nsa_select_attention_infer(query, key, value, topk_indices, scale_value, head_num, key_value_head_num, select_block_size, select_block_count, page_block_size, layout=layout, atten_mask=atten_mask, block_table=block_table, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen)
"""
)
_add_torch_npu_docstr(
"npu_gather_sparse_index",
"""
接口原型:
torch_npu.npu_gather_sparse_index(input, index) -> torch.Tensor
功能描述:
从输入Tensor的指定维度dim,按照index中的下标序号提取元素,保存到out Tensor中。
参数说明:
input(torch.Tensor): 输入张量,数据格式支持ND。
在Atlas A2/Atlas A3上数据类型支持torch.float32, torch.float16, torch.bfloat16, torch.int64, torch.int32, torch.int16,
torch.int8, torch.uint8, torch.bool, torch.float64, torch.complex64, torch.complex128
index(torch.Tensor): 包含目标元素下标序号的张量。数据维度不超过8维。数据类型支持torch.int64, torch.int32。取值范围[0, input.shape[0] - 1], 不支持负数索引。
输出说明:
out(torch.Tensor): 接口计算获得的结果,包含按照index中的下标序号提取的元素。数据类型与input一致,输出维度为index.dim + input.dim - 1。
例如input.shape = [16, 32], index.shape = [2, 3],则输出张量 out.shape = [2, 3, 32]
约束说明:
1. input 的维度与 index 的维度之和减1不能超过8,即index.dim + input.dim - 1<=8。
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
inputs = torch.randn(16, 32).npu()
index = torch.randint(0, 16, [2, 3]).npu()
out = torch_npu.npu_gather_sparse_index(inputs, index)
"""
)
_add_torch_npu_docstr(
"npu_moe_update_expert",
"""
torch_npu.npu_moe_update_expert(Tensor expert_ids, Tensor eplb_table, *, Tensor? expert_scales=None, Tensor? pruning_threshold=None, Tensor? active_mask=None, int local_rank_id=-1, int world_size=-1, int balance_mode=0) -> (Tensor, Tensor)
功能描述
完成冗余专家部署场景下每个token的topK个专家逻辑卡号到物理卡号的映射。
支持根据阈值对token发送的topK个专家进行剪枝
参数说明
expert_ids:每个token的topK个专家索引,Device侧的Tensor,要求为一个2D的Tensor,shape为 (BS, K)。数据类型支持INT32,INT64,数据格式要求为ND,支持非连续的Tensor。
eplb_table:逻辑专家到物理专家的映射表,外部调用者需保证输入Tensor的值正确:每行第一列为行号对应逻辑专家部署的实例数count,值需大于等于1,每行[1, count]列为对应实例的卡号,取值范围[0, moe_expert_num),Device侧的Tensor,要求是一个2D的Tensor,shape为(moe_expert_num, F)。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor。其中F表示输入映射表的列数,第一列为各行号对应Moe专家部署的实例个数(值>0),后F-1列为该Moe专家部署的物理卡号,取值范围[2, world_size+1]。
expert_scales:每个token的topK个专家的scale权重,用户需保证scale在token内部按照降序排列,可选择传入有效数据或空指针,该参数传入有效数据时,pruning_threshold也需要传入有效数据。Device侧的Tensor,要求是一个2D的Tensor,shape为 (BS, K)。数据类型支持FP16、BF16、FLOAT,数据格式要求为ND,支持非连续的Tensor。
pruning_threshold:专家scale权重的最小阈值,当某个token对应的某个topK专家scale小于阈值时,该token将对该专家进行剪枝,即token不发送至该专家处理,可选择传入有效数据或空指针,该参数传入有效数据时,expert_scales也需要传入有效数据。Device侧的Tensor,要求是一个1D或2D的Tensor,shape为(K,)或(1, K)。数据类型支持FLOAT,数据格式要求为ND,支持非连续的Tensor。
active_mask:表示token是否参与通信,可选择传入有效数据或空指针。传入有效数据时,expert_scales、pruning_threshold也必须传入有效数据,参数为true表示对应的token参与通信,true必须排到false之前,例:{true, false, true}为非法输入;传入空指针时是表示所有token都会参与通信。Device侧的Tensor,要求是一个1D的Tensor,shape为(BS,)。数据类型支持bool,数据格式要求为ND,支持非连续的Tensor。
local_rank_id:本卡ID,数据类型支持INT64,当balance_mode设置0时,本属性取值范围为[0, world_ize)。
world_size:通信域size,数据类型支持INT64,当balance_mode设置0时,本属性取值范围为[2, 768]
balance_mode:均衡规则,数据类型支持INT64,取值支持0和1,0表示用local_rank_id进行负载均衡,1表示使用token_id进行负载均衡。当本属性取值为0时,local_rank_id和world_size必须传入有效值。
输出说明
balanced_expert_ids:映射后每个token的topK个专家所在物理卡的卡号,Device侧的Tensor,要求是一个2D的Tensor,shape为(BS, K),数据类型、数据格式与expert_ids保持一致。
balanced_active_mask:剪枝后的active_mask,当expert_scales、pruning_threshold传入有效数据时该输出有效。Device侧的Tensor,要求是一个2的Tensor,shape为(BS, K),数据类型支持BOOL,数据格式要求为ND,支持非连续的Tensor。
支持的型号
Atlas A3训练系列产品
调用示例
import os
import math
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestMoeUpdateExpert(TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.bs = 128
self.k = 8
self.log_ep_size = 256
self.pyh_ep_size = 8
self.F = 5
self.is_pruning = True
self.world_size = 8
self.balance_mode = 0
self.expert_ids = []
self.eplb_table = []
self.expert_scales = []
self.pruning_threshold = []
self.active_mask = []
self.balanced_expert_ids = []
self.balanced_active_mask = []
self.gen_exp_result()
@classmethod
def _init_dist_hccl(cls, rank, world_size):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '50000'
os.environ['HCCL_WHITELIST_DISABLE'] = '1'
torch_npu.npu.set_device(rank)
dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
return dist
@classmethod
def _test_npu_moe_update_expert(cls, rank_id, input_list):
expert_ids, eplb_table, world_size, expert_scales, pruning_threshold, active_mask, balance_mode, init_pg, c2p, p2c = input_list
_ = init_pg(rank_id, world_size)
out_expert_idx, out_mask = torch_npu.npu_moe_update_expert(expert_ids=expert_ids.npu(),
eplb_table=eplb_table.npu(),
local_rank_id=rank_id,
world_size=world_size,
expert_scales=expert_scales.npu(),
pruning_threshold=pruning_threshold.npu(),
active_mask=active_mask.npu(),
balance_mode=balance_mode)
c2p.put((rank_id, out_expert_idx.cpu(), out_mask.cpu()))
p2c.get()
def gen_exp_result(self):
for rank_id in range(self.world_size):
eplb_table = np.zeros((self.log_ep_size, self.F - 1))
count_column = np.random.randint(1, self.F, size=(self.log_ep_size, 1))
all_ranks = np.arange(self.pyh_ep_size)
for i in range(self.log_ep_size):
np.random.shuffle(all_ranks)
for j in range(count_column[i][0]):
eplb_table[i][j] = all_ranks[j]
eplb_table = np.hstack((count_column, eplb_table))
expert_ids = np.random.randint(low=0, high=self.log_ep_size, size=(self.bs, self.k))
if self.is_pruning:
expert_scales = -np.sort(-np.random.uniform(low=0, high=0.25, size=(self.bs, self.k)), axis=1)
pruning_threshold = np.random.uniform(low=0, high=0.15, size=(1, self.k))
num_true = np.random.randint(0, self.bs + 1)
active_mask = np.concatenate([np.ones(num_true, dtype=bool), np.zeros(self.bs - num_true, dtype=bool)])
eplb_table_tensor = torch.from_numpy(eplb_table).to(torch.int32)
self.eplb_table.append(eplb_table_tensor)
expert_ids_tensor = torch.from_numpy(expert_ids).to(torch.int32)
self.expert_ids.append(expert_ids_tensor)
if self.is_pruning:
expert_scales_tensor = torch.from_numpy(expert_scales).to(torch.float32)
self.expert_scales.append(expert_scales_tensor)
pruning_threshold_tensor = torch.from_numpy(pruning_threshold).to(torch.float32)
self.pruning_threshold.append(pruning_threshold_tensor)
active_mask_tensor = torch.from_numpy(active_mask).to(torch.bool)
self.active_mask.append(active_mask_tensor)
balanced_expert_ids = np.zeros((self.bs, self.k))
if self.is_pruning:
balanced_active_mask = np.zeros((self.bs, self.k))
for i in range(self.bs):
for j in range(self.k):
log_ep_id = expert_ids_tensor[i][j]
if self.balance_mode == 0:
mod_val = math.ceil(self.world_size / eplb_table_tensor[log_ep_id][0].item())
phy_ep_id = eplb_table_tensor[log_ep_id][(rank_id // mod_val) + 1]
balanced_expert_ids[i][j] = phy_ep_id
if self.balance_mode == 1:
phy_ep_id = eplb_table_tensor[log_ep_id][(i % eplb_table_tensor[log_ep_id][0].item()) + 1]
balanced_expert_ids[i][j] = phy_ep_id
if self.is_pruning:
if not active_mask_tensor[i]:
balanced_active_mask[i][j] = 0
else:
if expert_scales_tensor[i][j] < pruning_threshold_tensor[0][j] * sum(expert_scales_tensor[i]):
balanced_active_mask[i][j] = 0
else:
balanced_active_mask[i][j] = 1
self.balanced_expert_ids.append(torch.from_numpy(balanced_expert_ids).to(torch.int64))
self.balanced_active_mask.append(torch.from_numpy(balanced_active_mask).to(torch.bool))
@SupportedDevices(['Ascend910_93', 'Ascend950'])
def test_npu_moe_update_expert(self):
ctx = mp.get_context('spawn')
c2p = ctx.Queue(self.world_size)
p2c = ctx.Queue(self.world_size)
ps = []
for rank_id in range(self.world_size):
p = ctx.Process(
target=self._test_npu_moe_update_expert,
args=(rank_id, [self.expert_ids[rank_id], self.eplb_table[rank_id], self.world_size,
self.expert_scales[rank_id], self.pruning_threshold[rank_id], self.active_mask[rank_id],
self.balance_mode, self._init_dist_hccl, c2p, p2c]))
p.start()
ps.append(p)
for _ in range(self.world_size):
rank_id, output_0, output_1 = c2p.get()
self.assertEqual(output_0, self.balanced_expert_ids[rank_id],
("rank {} Expect receive tensor {} but got {}.").format(rank_id, self.balanced_expert_ids[rank_id], output_0))
self.assertEqual(output_1, self.balanced_active_mask[rank_id],
("rank {} Expect receive tensor {} but got {}.").format(rank_id, self.balanced_active_mask[rank_id], output_1))
for _ in range(self.world_size):
p2c.put(0)
for p in ps:
p.join()
if __name__ == '__main__':
run_tests()
"""
)
_add_torch_npu_docstr(
"npu_top_k_top_p",
"""
接口原型:
torch_npu.npu_top_k_top_p(logits, p, k) -> torch.Tensor
功能描述:
对原始输入logits进行top-k和top-p采样过滤
计算公式:
1. 对输入logits按最后一轴进行升序排序,得到对应的排序结果sortedValue和sortedIndices。
sortedValue, sortedIndices = sort(logits, dim=-1, descend=false, stable=true)
2. 计算保留的阈值(第k大的值)。
topKValue[b][v] = sortedValue[b][sortedValue.size(1) - k[b]]
3. 生成top-k需要过滤的mask。
topKMask = sortedValue < topKValue
4. 通过topKMask将小于阈值的部分置为-inf。
sortedValue[b][v] =
-inf if topKMask[b][v] == true else sortedValue[b][v]
5. 通过softmax将经过top-k过滤后的数据按最后一轴转换为概率分布。
probsValue = softmax(sortedValue, dim=-1)
6. 按最后一轴计算累计概率(从最小的概率开始累加)。
probsSum = cumsum(probsValue, dim=-1)
7. 生成top-p的mask,累计概率小于等于1-p的位置需要过滤掉,并保证每个batch至少保留一个元素。
topPMask[b][v] = probsSum[b][v] <= 1-p[b]
topPMask[b][-1] = false
8. 通过topPMask将小于阈值的部分置为-inf。
sortedValue[b][v] =
-inf if topPMask[b][v] == true else sortedValue[b][v]
9. 将过滤后的结果按sortedIndices还原到原始顺序。
out[b][v] = sortedValue[b][sortedIndices[b][v]]
其中 0 <= b < logits.size(0), 0 <= v < logits.size(1)。
参数说明:
logits(torch.Tensor): 输入张量,支持2维,数据类型支持torch.bfloat16, torch.float16, torch.float32。
p(torch.Tensor): 可选张量,默认值为None,不支持p和k同时传None。表示top-p的阈值,值域为[0, 1],数据类型支持torch.bfloat16, torch.float16, torch.float32,数据类型需要与logits一致,shape支持1维且需要与logits的首轴相同,支持非连续Tensor,支持空tensor,支持ND
k(torch.Tensor): 可选张量,默认值为None,不支持p和k同时传None。表示top-k的阈值,值域为[1, 1024],且最大值需要小于等于logits.size(1),数据类型支持torch.int32,shape支持1维且需要与logits的首轴相同,支持非连续Tensor,支持空tensor,支持ND
输出说明:
out(torch.Tensor): 表示过滤后的数据。数据类型支持torch.bfloat16, torch.float16, torch.float32,数据类型需要与logits一致,shape支持2维且需要与logits一致,支持非连续Tensor,数据格式支持ND
约束说明:
无
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
logits = torch.randn(16, 2048).npu()
p = torch.rand(16).npu()
k = torch.randint(10, 1024, (16,)).npu().to(torch.int32)
out = torch_npu.npu_top_k_top_p(logits, p, k)
"""
)
_add_torch_npu_docstr(
"ffn_worker_scheduler_",
"""
接口原型:
ffn_worker_scheduler_(Tensor(a!) self, *, int sync_group_size=1, int execute_mode=0) -> Tensor(a!)
功能描述:
Attention和Ffn分离部署场景下,Ffn侧数据扫描功能,扫描并原地完成数据整理。
参数说明:
scheduler_context(torch.Tensor): 输入张量, scheduler_context的定义与生成可参照torch_npu._afd包;
sync_group_size(int): 可选,默认值为1;
execute_mode(int): 可选,默认值为0。
约束说明:
无
支持版本:
PyTorch 2.1及更高版本
支持的型号:
Atlas A3训练系列产品
Atlas A3推理系列产品
调用示例:
import torch
import torch_npu
import os
window_size = 209715200
ffn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()
attn_workers = 2
micro_batch_number = 3
batch_size = 6
top_k = 8
hidden_size = 7168
expert_num = 288
attn_to_ffn_token_size = (7168 + 4 + 511) // 512 * 512
ffn_to_attn_token_size = 7168 * 2
ffn_window = ffn_window_tensor.data_ptr()
context_holder = torch_npu._afd.create_schedule_context_holder(schedule_mode = 0, session_num = attn_workers,
micro_batch_num = micro_batch_number, micro_batch_size = batch_size, selected_expert_num = top_k + 1,
expert_num = expert_num, attn_to_ffn_token_size = attn_to_ffn_token_size, ffn_to_attn_token_size = ffn_to_attn_token_size,
ffn_window = ffn_window, ffn_window_size = window_size)
schedule_context = context_holder.get_schedule_context_tensor()
def _set_all_flags():
num_int8 = attn_workers * micro_batch_number * (8 + batch_size * top_k * 4)
per_session_num = micro_batch_number * (8 + batch_size * top_k * 4)
int32_view = ffn_window_tensor[:num_int8].view(torch.int32)
int32_view[:] = 1
_set_all_flags()
torch_npu.ffn_worker_scheduler_(schedule_context, sync_group_size = 2)
"""
)
_add_torch_npu_docstr(
"ffn_worker_scheduler",
"""
接口原型:
ffn_worker_scheduler(Tensor self, *, int sync_group_size=1, int execute_mode=0) -> Tensor
功能描述:
Attention和Ffn分离部署场景下,Ffn侧数据扫描功能,扫描并完成数据整理输出。
参数说明:
scheduler_context(torch.Tensor): 输入张量, scheduler_context的定义与生成可参照torch_npu._afd包;
sync_group_size(int): 可选,默认值为1;
execute_mode(int): 可选,默认值为0。
输出说明:
scheduler_context(torch.Tensor): 输出结果张量
约束说明:
无
支持版本:
PyTorch 2.1及更高版本
支持的型号:
Atlas A3训练系列产品
Atlas A3推理系列产品
调用示例:
import torch
import torch_npu
import os
window_size = 209715200
ffn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()
attn_workers = 2
micro_batch_number = 3
batch_size = 6
top_k = 8
hidden_size = 7168
expert_num = 288
attn_to_ffn_token_size = (7168 + 4 + 511) // 512 * 512
ffn_to_attn_token_size = 7168 * 2
ffn_window = ffn_window_tensor.data_ptr()
context_holder = torch_npu._afd.create_schedule_context_holder(schedule_mode = 0, session_num = attn_workers,
micro_batch_num = micro_batch_number, micro_batch_size = batch_size, selected_expert_num = top_k + 1,
expert_num = expert_num, attn_to_ffn_token_size = attn_to_ffn_token_size, ffn_to_attn_token_size = ffn_to_attn_token_size,
ffn_window = ffn_window, ffn_window_size = window_size)
schedule_context = context_holder.get_schedule_context_tensor()
def _set_all_flags():
num_int8 = attn_workers * micro_batch_number * (8 + batch_size * top_k * 4)
per_session_num = micro_batch_number * (8 + batch_size * top_k * 4)
int32_view = ffn_window_tensor[:num_int8].view(torch.int32)
int32_view[:] = 1
_set_all_flags()
schedule_context_out = torch_npu.ffn_worker_scheduler(schedule_context, sync_group_size = 2)
"""
)
_add_torch_npu_docstr(
"attention_worker_scheduler_",
"""
接口原型:
attention_worker_scheduler_(Tensor(a!) self) -> Tensor(a!)
功能描述:
Attention和Ffn分离部署场景下,Attention侧数据扫描功能,扫描并原地确保数据就绪。
参数说明:
scheduler_context(torch.Tensor): 输入张量, scheduler_context的定义与生成可参照torch_npu._afd包。
约束说明:
无
支持版本:
PyTorch 2.1及更高版本
支持的型号:
Atlas A3训练系列产品
Atlas A3推理系列产品
调用示例:
import torch
import torch_npu
import os
window_size = 209715200
attn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()
attn_workers = 144
micro_batch_number = 3
batch_size = 30
top_k = 8
hidden_size = 7168
expert_num = 288
attn_to_ffn_token_size = (7168 + 4 + 511) // 512 * 512
ffn_to_attn_token_size = 7168 * 2
attn_window = attn_window_tensor.data_ptr()
context_holder = torch_npu._afd.create_schedule_context_holder(schedule_mode = 1, session_num = attn_workers,
micro_batch_num = micro_batch_number, micro_batch_size = batch_size, selected_expert_num = top_k + 1,
expert_num = expert_num, attn_to_ffn_token_size = attn_to_ffn_token_size, ffn_to_attn_token_size = ffn_to_attn_token_size,
attention_window = attn_window, attention_window_size = window_size)
schedule_context = context_holder.get_schedule_context_tensor()
def _set_all_flags():
num_int8 = batch_size * (top_k + 1) * 4 * micro_batch_number
int32_view = attn_window_tensor[:num_int8].view(torch.int32)
int32_view[:] = 1
_set_all_flags()
torch_npu.attention_worker_scheduler_(schedule_context)
"""
)
_add_torch_npu_docstr(
"attention_worker_scheduler",
"""
接口原型:
attention_worker_scheduler(Tensor self) -> Tensor
功能描述:
Attention和Ffn分离部署场景下,Attention侧数据扫描功能,扫描并确保数据就绪。
参数说明:
scheduler_context(torch.Tensor): 输入张量, scheduler_context的定义与生成可参照torch_npu._afd包。
输出说明:
scheduler_context(torch.Tensor): 输出结果张量
约束说明:
无
支持版本:
PyTorch 2.1及更高版本
支持的型号:
Atlas A3训练系列产品
Atlas A3推理系列产品
调用示例:
import torch
import torch_npu
import os
window_size = 209715200
attn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()
attn_workers = 144
micro_batch_number = 3
batch_size = 30
top_k = 8
hidden_size = 7168
expert_num = 288
attn_to_ffn_token_size = (7168 + 4 + 511) // 512 * 512
ffn_to_attn_token_size = 7168 * 2
attn_window = attn_window_tensor.data_ptr()
context_holder = torch_npu._afd.create_schedule_context_holder(schedule_mode = 1, session_num = attn_workers,
micro_batch_num = micro_batch_number, micro_batch_size = batch_size, selected_expert_num = top_k + 1,
expert_num = expert_num, attn_to_ffn_token_size = attn_to_ffn_token_size, ffn_to_attn_token_size = ffn_to_attn_token_size,
attention_window = attn_window, attention_window_size = window_size)
schedule_context = context_holder.get_schedule_context_tensor()
def _set_all_flags():
num_int8 = batch_size * (top_k + 1) * 4 * micro_batch_number
int32_view = attn_window_tensor[:num_int8].view(torch.int32)
int32_view[:] = 1
_set_all_flags()
schedule_context_out = torch_npu.attention_worker_scheduler(schedule_context)
"""
)
_add_torch_npu_docstr(
"npu_top_k_top_p_sample",
"""
接口原型:
torch_npu.npu_top_k_top_p_sample(logits, top_k, top_p, q=None, min_ps=None, eps=1e-8, is_need_logits=False, top_k_guess=32, ks_max=1024, input_is_logits=True, post_sample='qSample', generator=None) -> (Tensor, Tensor)
功能描述:
根据输入词频logits、top_k/top_p/min_ps采样参数、随机采样权重分布q,进行topK-topP-minP-Sample采样计算,输出每个batch的最大词频logits_select_idx,以及topK-topP采样后的词频分布logits_top_kp_select。
算子包含4个可单独使能,但上下游处理关系保持不变的采样算法(从原始输入到最终输出):topK采样、topP采样、minP显著性采样、不采样 / 指数采样 / 多项式随机采样 。目前支持以下12种计算场景。如下表所示:
| 计算场景 | topK采样 | topP采样 | minP采样 | 后继处理 |备注|
| :-------:| :------:|:-------:|:-------:|:-------:|:-------:|
|Argmax采样|×|×|×|None|对输入logits每个batch取最大logits和对应索引,结果作为logits_select_idx[batch,1]。|
|topK采样|√|×|×|None|无|
|topP采样|×|√|×|None|无|
|qSample采样|×|×|×|qSample|对输入logits每个batch使用q[i]进行指数采样,从结果中取最大值和索引,作为logits_select_idx[batch,1]。|
|topK-topP采样|√|√|×|None|无|
|topK-qSample采样|√|×|×|qSample|无|
|topK-multiNomial采样|√|×|×|multiNomial|无|
|topK-minP-multiNomial采样|√|×|√|multiNomial|无|
|topP-qSample采样|×|√|×|qSample|无|
|topK-topP-qSample采样|√|√|×|qSample|VLLM框架标准完整功能。|
|topK-topP-multiNomial采样|√|√|×|multiNomial|min_ps为无效值,但仍执行多项式采样|
|topK-topP-minP-multiNomial采样|√|√|√|multiNomial|Sglang框架标准完整功能。|
计算公式:
输入logits为大小是[batch, voc_size]的词频表,其中每个batch对应一条输入序列,而voc_size则是约定每个batch的统一长度。<br>
logits中的每一行logits[batch][:]根据相应的top_k[batch]、top_p[batch]、q[batch, :]、min_ps[batch],执行不同的计算场景。<br>
下述公式中使用b和v来分别表示batch和voc_size方向上的索引。
topK采样
1. 按分段长度v采用分段topK归并排序,用{s-1}块的topK对当前{s}块的输入进行预筛选,渐进更新单batch的topK,减少冗余数据和计算。
2. top_k[batch]对应当前batch采样的k值,有效范围为1≤top_k[batch]≤min(voc_size[batch], 1024),如果top_k[batch]超出有效范围,则视为跳过当前batch的topK采样阶段,也同样会则跳过当前batch的排序,将输入logits[batch]直接传入下一模块。<br>
* 具体计算流程如下所示:
* 根据输入top_k[b]与`ks_max`的关系,判断是否进行topK采样:top_k[b]≤0,跳过topK采样;
(1)1≤top_k≤min(voc_size,ks_max),执行topK采样
(2)top_k>min(voc_size,ks_max),跳过topK采样
* 对当前batch分割为若干子段,滚动计算top_k_value[b]:
top_k_value[b]={Max(top_k[b])}_{s=1}^{lceilfrac{S}{v}rceil}{top_k_value[b]{s-1}cup{logits[b][v]getop_k_min[b][s-1]}}
Card(top_k_value[b])=top_k[b]
其中:
top_k_min[b][s]=Min(top_k_value[b]{s})
v表示预设的滚动topK时固定的分段长度:
v=8*ks\_max
ks_max有效取值范围[1,1024],默认为1024,并且需要向上对齐到8的整数倍。
* 生成需要过滤的mask:
top_k_mask = sorted_value>top_k_value
* 将小于阈值的部分通过mask置为默认无效值defLogit:
sorted_value[b][v]=begin{cases} -inf & top_k_mask[b][v]=true}
sorted_value[b][v] & top_k_mask[b][v]=false} & end{cases}
* 其中defLogits取决于入参属性Attr.optional.Bool.input_is_logits,该属性控制输入logits和输出logits_top_kp_select的归一化:
defLogit = begin{cases} -inf, & inputIsLogits = True 0, & inputIsLogits = False end{cases}
topP采样
* 根据入参约束属性Attr.optional.Bool.input_is_logits(false),如果该属性为True,则对排序后结果进行归一化:
logit_sortProb = begin{cases} softmax(logits_sort), & inputIsLogits = True
logits_sort, & inputIsLogits = False
* 根据输入top_p[b]的数值,本模块的处理策略如下:top_p[b]≤0,保留1个最大词频token
(1)0<top_p<1,执行topP采样
(2)top_p≥1,跳过topP采样
* 如果执行常规topP采样,且如果前序topK环节已有排序输出结果,则根据topK采样输出计算累积词频,并根据top_p截断采样:
topPMask[b] = begin{cases} 0, & sum_{topKMask[b]}^{} {logits_sortProb}[b][*] > p[b]
1, & sum_{topKMask[b]}^{} logits_sortProb[b][*] leq p[b] end{cases}
* 如果执行常规topP采样,但前序topK环节被跳过,则计算top-p的mask:
topPMask[b] = begin{cases} topKMask[b][0:GuessK], & sum_{GuessK}^{} probValue[b][*] ge p[b] probSum[b][v] le 1 - p[b], & others end{cases}
* 将需要过滤的位置设置为默认无效值defLogit,得到logits_sort,记为sortedValue[b][v]:
sortedValue[b][v] = begin{cases}
defLogit & \quad topPMask[b][v] = false logit_sortProb[b][v] & quad topPMask[b][v] = true end{cases}
* 取过滤后sortedValue[b][v]每行中前topK个元素,查找这些元素在输入中的原始索引,整合为logits_idx:
logitsIdx[b][v] = Index(sortedValue[b][v] in Logits)
* 从输入Logtis中按logitsIdx顺序遍历取出元素,其余位置填入defLogit,作为logitsSortMasked:
logitsSortMasked[b, :Len(logitsIdx[b][:])] = Logits[b, logitsIdx[b][:]]
logitsSortMasked[b, Len(logitsIdx[b][:]):] = defLogit
* (sglang框架支持更新)直接使用截断后的sortedValue作为logitsSortMasked:
logitsSortMasked[b,:] = sortedValue[b]
minP采样
* 如果min_ps[b]∈(0, 1),则执行minP采样:
logitsMax[b] = Max(logitsSortMasked[b])
minPThd = logitsMax[b] * minPs[b]
minPMask[b] = begin{cases} 0, & logitsSortMasked[b] < minPThd 1, & logitsSortMasked[b] geq minPThd end{cases}
logitsSortMasked[b,:] = begin{cases} defLogit, & minPMask[b] = 0 logitsSortMasked[b,:], & minPMask[b] = 1 end{cases}
* 其他情况:
logitsSortMasked[b, :] = begin{cases} logitsSortMasked[b, :], & if minPs[b] leq 0 max(logitsSortMasked[b, :]), & if minPs[b] geq 1 end{cases}
min_ps[b]≥1时,每个batch仅取1个最大token,其余位置填充defLogit。
可选输出
* 如果入参属性Attr.Bool.is_need_logits=True,则使用topK-topP-minP联合采样后的logitsIndexMasked,进行logits_top_kp_select输出。
logitsIndex[b][v] = Index(logitsSortMasked[b][v] in Logits)
logitsIndexMasked[b,:] = logitsIndex[b,:] * topKMask[b] * topPMask[b] * minPMask[b]
其中,topK、topP、minP采样环节如果被跳过,则相应mask为全1。
* 接下来使用logitsIndexMasked对输入logits进行Select,过滤输入logits中的高频token作为logits_top_kp_select输出:
logitsTopKpSelect}[b][v] = begin{cases} logits[b][v], & if logitsIndexMasked[b,v] = True
defLogit, & if logitsIndexMasked[b,v] = False end{cases}
后继处理
* 此阶段输入为前序对前序topK-topP-minP采样的联合结果logitsSortMasked。
* 此处输入须要确保logitsSortMasked∈(0,1),根据输入logits的实际情况,配置入参约束属性Attr.optional.Bool.input_is_logits,即:
inputIsLogits = begin{cases} True, & Logits notin [0,1]
False, & Logits in [0,1] end{cases}
使得
probs[b] = logitsSortMasked[b, :]
接下来有三种模式:None,qSample,multiNomial,通过入参约束属性attr.optional.Str.post_sample加以控制。
* None
* 直接对每个batch通过Argmax取最大元素和索引,并通过gatherOut输出。
logitsSelectIdx[b] = LogitsIdx[b][ArgMax(probs[b][:])]
* qSample
* 先对probs进行指数分布采样:
qCnt = Sum(MinPMask == 1)
probsOpt[b] = frac{probs[b]}{q[b, :qCnt] + eps}
* 再进行Argmax-GatherOut输出结果:
logitsSelectIdx[b] = LogitsIdx[b][ArgMax(probsOpt[b][:])]
* multiNomial
* 使用多项式随机采样,根据logitsSortMasked中的概率值,执行无放回的多项式采样,对每个batch取1个样本,将采样结果作为当期batch的输出:
sampleIdx}[b] = multinomial(logitsSortMasked[b,:], numSamples=1, seed[b], offset[b])
logitsSelectIdx[b] = LogitsIdx[b][sampleIdx[b]]
* 对于采样种子,当attr.optional.Str.post_sample="multinomial"时,q约束为INT64,分别从第一列和第二列获取multiNomial采样的seed和offset:
seed[b] = begin{cases} q[b, 0], & b < qRows
q[-1, 0], & b ge qRows end{cases}
offset[b] = begin{cases} q[b, 1], & b < qRows
q[-1, 1], & b ge qRows end{cases}
* 该采样过程以aclnn.Multinomial为基准,可参看:https://gitcode.com/cann/ops-math-dev/blob/master/random/dsa_random_uniform/docs/aclnnMultinomial.md
* pta调用时,采样种子和偏移默认使用内建值,可参看:https://gitcode.com/Ascend/op-plugin/blob/master/op_plugin/ops/opapi/MultinomialKernelNpuOpApi.cpp
参数说明:
logits(Tensor):必选参数,表示待采样的输入词频,目前支持2维,词频索引固定为最后一维。数据类型支持`float16`、`bfloat16`和`float32`,数据格式支持$ND$,支持非连续Tensor。
top_k(Tensor):必选参数,表示每个batch采样的k值,有效范围为1≤top_k[batch]≤min(voc_size[batch], 1024),无效范围则跳过topK,目前支持1维。数据类型支持`int32`,数据格式支持$ND$,支持非连续Tensor。
top_p(Tensor):必选参数,表示每个batch采样的p值,有效范围为0<$top\_p[batch]<1$,目前支持1维。数据类型和数据格式与`logits`保持一致,支持非连续Tensor。
- 在任何情况下,topP对每个batch的输出都会保留至少1个token。
- top_p[batch] ≤0时,对当前batch仅保留概率最大的1个token。
- top_p[batch]处于合法值范围(0,1)时,对当前batch执行标准topP采样。
- p>=1时跳过相应batch的topP步骤,提取整个batch信息并生成ones掩模作为输出。
q(Tensor):可选参数,topK-topP采样输出的随机采样权重分布矩阵,数据类型支持`float32`,数据格式支持$ND$,支持非连续Tensor,默认值为None, 此时跳过后继采样,从probs计算logits_select_idx。
- 根据post_sample的模式不同,该参数约束如下:
- post_sample = qSample时, 尺寸约束为[batch, voc_size], 数据类型必须为float32,指数分布采样矩阵,维度需与logits的一致。
- post_sample = multinomial时, multinomial随机采样参数矩阵,数据类型必须为int64,用于为aclnnmultiNomial采样提供控制参数。合法的尺寸为[q_row, 2],其中q_row≥1:
- 第1列对应aclnnMultinomial.seed参数:对应当前batch的随机数种子。
- 第2列对应aclnnMultinomial.offset参数:随机数生成器的偏移量,它影响生成的随机数序列的位置。设置偏移量后,生成的随机数序列会从指定位置开始。
- 如果qrow<batch,则默认使用最后一个batch的采样参数作为后续batch的multinomial采样参数。
eps(float):可选参数,在softmax和权重采样中防止除零,默认值为1e-8。
is_need_logits(bool):可选参数,控制`logits_top_kp_select`的输出条件,默认值为False。
top_k_guess(int):可选参数,仅在当前batch的top_k为无效值时使能,适用于跳过topK的top_k_guess-TopP加速采样。有效值范围top_k_guess>0,默认为32,用于TopP加速采样中基于top_k_guess的直接索引过滤。如果传入非正数,视为跳过top_k_guess环节,直接使用基于cumsum的标准topP实现,对当前batch做topP全排序采样,保持基准性能。
ks_max(int):可选参数,约束topK采样中允许的topk[batch]合法值上限,影响跳过topK采样的条件,允许传入任意非零正整数。有效值范围[1,1024]之间的整数,传入超过1024的值会自动设为1024。
input_is_logits(bool):可选参数,该参数控制输入logits在topP及后续步骤之前,是否进行归一化处理,并决定可选输出logits_top_kp_select中的无效logits默认值类型。Logits表示“未经归一化的原始值”,而相对地已经过归一化的则定义为“probs”。该参数的取值影响如下:
- 若该参数取值为True,输入的logits中的数值不能确保在[0,1]区间内。由于logits未进行归一化,在进行top_p采样等后续步骤之前,先对输入进行softmax处理。logits_top_kp_select中的无效logits默认值defLogit=-inf。
- 若该参数取值为False,输入logits中的所有元素都确保在[0,1]区间内。输入logits已经归一化,为避免梯度平滑化,top_p采样等后续步骤直接使用前级处理的结果。logits_top_kp_select中的无效logits默认值defLogit=0。
post_sample(str):可选参数,该参数控制topk-topp采样之后的后继处理策略。第一优先级:判断q是否为None,如果q=None,则无视参数提供的post_sample内容,强制后继处理模式一概设为None。参数合法值允许:
- qSample(默认值):倾向于使用qSample采样。
- multinomial:使用multinomial采样(多项式随机抽样),此时入参中的q矩阵将被解析为随机种子,执行multinomial-gather。
- None:显式强调不使用任何后继处理,此时传入任何q!=None都被无视。
generator(Generator):可选参数,Multinomial使用的随机数生成器,必须指定seed才能传入。
输出说明:
logits_select_idx(Tensor):表示经过topK-topP-sample计算流程后,每个batch中词频最大元素max(probs_opt[batch, :])在输入logits中的位置索引。数据类型支持int64,数据格式支持ND。
logits_top_kp_select(Tensor):表示经过topK-topP-minP采样获得mask,对原输入logits中高频token的过滤结果。仅在is_need_logits=true时使能输出计算和搬运,否则直接输出相应尺寸的空tensor。数据类型支持float32,数据格式支持$ND$。
约束说明:
该接口支持推理场景下使用。
该接口目前不支持图模式。
logits、q、logits_top_kp_select的尺寸和维度必须完全一致。
logits、top_k、top_p、logits_select_idx除最后一维以外的所有维度必须顺序和大小完全一致。目前logits只能是2维,top_k、top_p、logits_select_idx必须是1维非空Tensor。logits、top_k、top_p不允许空Tensor作为输入,如需跳过相应模块,需按相应规则设置输入。
如果需要单独跳过topK模块,请传入[batch, 1]大小的Tensor,并使每个元素均为无效值。
如果1024<top_k[batch]<voc_size[batch],则视为选择当前batch的全部有效元素并跳过topK环节。
如果需要单独跳过topP模块,请传入[batch, 1]大小的Tensor,并使每个元素均≥1。
如果需要单独跳过sample模块,使用其默认值或设置q为None;如需使用Sample模块,则必须传入对应尺寸的Tensor。
支持的型号:
Atlas A3 训练系列产品
Atlas A3 推理系列产品
Atlas A2 训练系列产品
Atlas 800I A2 推理产品
A200I A2 Box 异构组件
调用示例:
>>> import numpy as np
>>> import torch
>>> import torch_npu
>>> logits = torch.from_numpy(np.random.uniform(-2, 2, size=[2, 4])).type(torch.float16).npu()
>>> top_ks = torch.from_numpy(np.random.uniform(1, 2, size=[2, ])).type(torch.int32).npu()
>>> top_ps = torch.from_numpy(np.random.uniform(0.4, 0.5, size=[2, ])).type(torch.float16).npu()
>>> q = None
>>> min_ps = torch.from_numpy(np.random.uniform(0.1, 0.5, size=[2, ])).type(torch.float16).npu()
>>> post_sample = 'multiNomial'
>>> if post_sample == "multiNomial":
>>> generator_npu = torch.Generator(device="npu")
>>> generator_npu.manual_seed(1)
>>> else:
>>> generator_npu = None
>>> npu_out_index, logits_top_kp_select = torch_npu.npu_top_k_top_p_sample(logits, top_ks, top_ps, q=q, min_ps=min_ps, eps=1e-8, is_need_logits=True, top_k_guess=32, ks_max=1024, input_is_logits=True, post_sample=post_sample, generator=generator_npu)
>>> print(npu_out_index)
>>> print(logits_top_kp_select)
"""
)
_add_torch_npu_docstr(
"npu_moe_token_permute",
"""
接口原型:
torch_npu.npu_moe_token_permute(tokens, indices, num_out_tokens=None, padded_mode=False) -> (Tensor, Tensor)
功能描述
MoE的permute计算,根据索引indices将tokens广播并排序。
参数说明:
tokens(torch.Tensor):必选输入,2维Tensor, shape为(num_tokens,hidden_size),数据类型torch.bfloat16,支持非连续Tensor,支持ND
indices(torch.Tensor): 必选输入,2维Tensor,shape为(num_tokens,topK),数据类型torch.int64,支持非连续Tensor,支持ND
num_out_tokens(int, optional):可选输入,默认为None,数据类型int64,表示有效输出token数。设置为0时,表示不会删除任何token。不为0时,会按照num_tokens进行切片丢弃按照indices排序好的token中超过num_tokens的部分,为负数时按照切片索引为负数时处理。
padded_mode(bool, optional): 可选输入,默认为False,如果为True,表示indices已被填充为代表每个专家选中的token索引,此时不对indices进行排序,目前仅支持为False
输出说明:
permuted_tokens(torch.Tensor):2维Tensor,数据类型torch.bfloat16(当前版本permuted_tokens仅支持bfloat16)
sorted_indices(torch.Tensor):1维Tensor,数据类型torch.int32(当前版本sorted_indices仅支持int32)
约束说明:
indices 要求元素个数小于16777215,值大于等于0小于16777215(单点支持int32或int64的最大或最小值,其余值不在范围内排序结果不正确)
topK小于等于512
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
dtype = torch.bfloat16
tokens = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [0, 0, 0]]).npu().to(dtype)
indices = torch.tensor([[0, 4], [4, 3], [4, 2], [1, 1]]).npu()
num_out_tokens = indices.numel()
probs = torch.ones_like(indices) / 2
probs = probs.npu().to(dtype)
permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices, num_out_tokens)
"""
)
_add_torch_npu_docstr(
"npu_moe_token_unpermute",
"""
接口原型:
torch_npu.npu_moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None) -> Tensor
功能描述
根据sorted_indices存储的下标,获取permuted_tokens中存储的输入数据;如果存在probs数据,permuted_tokens会与probs相乘;最后进行累加求和,并输出计算结果
参数说明:
permuted_tokens(torch.Tensor):必选输入,2维Tensor, shape为(num_tokens*topK,hidden_size),数据类型torch.bfloat16,支持非连续Tensor,支持ND
sorted_indices(torch.Tensor): 必选输入,1维Tensor,shape为(num_tokens*topK),数据类型torch.int64,支持非连续Tensor,支持ND
probs(torch.Tensor, optional):可选输入,默认为None,当probs传时,topK等于probs的第二维;当probs不传时,topK=1。shape为(num_tokens,topK),支持的数据类型BFLOAT16。数据格式支持ND,支持非连续输入
padded_mode(bool, optional): 可选输入,默认为False,数据类型int64,目前仅支持为False
restore_shape(torch.size, optional): 可选输入,默认为None,表示permute前输入的shape,只在padded_mode为True时生效。数据类型torch.size
输出说明:
unpermuted_tokens(torch.Tensor):2维Tensor,数据类型torch.bfloat16,padded_mode=False时,shape为(num_tokens,hidden_size)
约束说明:
目前仅支持padded_mode为False
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
dtype = torch.bfloat16
permuted_tokens = torch.tensor([[1., 1., 1.],
[0., 0., 0.],
[0., 0., 0.],
[3., 3., 3.],
[2., 2., 2.],
[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]]).npu().to(dtype)
sorted_indices = torch.tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=torch.int32).npu()
indices = torch.tensor([[0, 4], [4, 3], [4, 2], [1, 1]]).npu()
probs = torch.ones_like(indices) / 2
unpermuted_tokens = torch_npu.npu_moe_token_unpermute(permuted_tokens, sorted_indices, probs=probs)
"""
)
_add_torch_npu_docstr(
"npu_dynamic_block_quant",
"""
接口原型:
torch_npu.npu_dynamic_block_quant(x, *, min_scale=0.0, round_mode="rint", dst_type=1, row_block_size=1, col_block_size=128) -> (Tensor, Tensor)
功能描述
对输入张量,通过给定的`row_block_size`和`col_block_size`将输入划分成多个数据块,以数据块为基本粒度进行量化。在每个块中,先计算出当前块对应的量化参数`scale`,并根据`scale`对输入进行量化。输出最终的量化结果,以及每个块的量化参数`scale`。
参数说明:
x (Tensor):必选参数,输入张量,数据类型支持float16、bfloat16,支持非连续的Tensor,数据格式支持ND。当前shape支持2维和3维。
min_scale (float):可选参数,参与scale计算的最小scale值。当前支持取值大于等于0。
round_mode (str):可选参数,指定cast到输出的转换方式。当前仅支持取值rint。
dst_type (int):可选参数,指定输出y的数据类型。当前仅支持取值1,表示代码输出y的数据类型为int8。
row_block_size (int):可选参数,指定一个block的行大小。当前仅支持取值1。
col_block_size (int):可选参数,指定一个block的列大小,当前仅支持取值128。
输出说明:
y (Tensor):量化结果。
scale (Tensor):量化时使用的量化参数。
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
>>> import torch
>>> import torch_npu
>>> x = torch.rand(3, 4).to("npu").to(torch.float16)
>>> min_scale = 0
>>> dst_type = 1
>>> row_block_size = 1
>>> col_block_size = 128
>>> y, scale = torch_npu.npu_dynamic_block_quant(x, min_scale=min_scale, dst_type=dst_type, row_block_size=row_block_size, col_block_size=col_block_size)
>>> y
tensor([[ 92, 65, 15, 127],
[100, 127, 116, 64],
[ 95, 15, 87, 127]], device='npu:0', dtype=torch.int8)
>>> scale
tensor([[0.0063],
[0.0076],
[0.0073]], device='npu:0')
"""
)
_add_torch_npu_docstr(
"obfuscation_initialize",
"""
功能描述:
该接口用于完成PMCC(Privacy&Model Confidential Computing)模型混淆引擎的资源初始化,即与PMCC混淆引擎CA(普通OS中的Client Application)建立socket连接、对CA、TA(TEE OS中的Trusted Application)进行初始化,并返回socket连接符。
接口原型:
torch_npu.npu.obfuscation_initialize(hidden_size, tp_rank, cmd, data_type, model_obf_seed_id, data_obf_seed_id, thread_num, obf_coefficient) -> Tensor
参数说明
- hidden_size(`int`):必选参数,隐藏层的维度,数据类型为`int32`,支持输入范围为1-10000,仅在`cmd`设置为1或2时需要填写有效值,否则填0。
- tp_rank(`int`):必选参数, 张量并行TP Rank,数据类型为`int32`,支持输入范围为0-1024,仅在`cmd`设置为1或2时需要填写有效值,否则填0。
- cmd(`int`):必选参数,资源初始化的指令编号,数据类型为`int32`,取值范围为{1, 2, 3}。
* 1:进行浮点推理模式资源初始化。
* 2:进行量化推理模式资源初始化。
* 3:进行资源释放。
- data_type(`int`):可选参数, 代表Tensor数据类型的编号,数据类型为`int32`,仅在`cmd`设置为1或2时需要填写有效值,否则填0。
* <term>Atlas 推理系列产品</term>: Tensor数据类型支持`float16` 、`float32`、`int8`。
* <term>Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件</term>: Tensor数据类型支持`float16`、`float32`、`bfloat16`、`int8`。
- model_obf_seed_id(`int`):可选参数, 模型混淆因子id,用于`TA`从`TEE KMC`查询模型混淆因子,数据类型为`int32`,仅在`cmd`设置为1或2时需要填写已注册的有效混淆因子id,否则填0。
- data_obf_seed_id(`int`):必选参数, 数据混淆因子id,用于`TA`从`TEE KMC`查询数据混淆因子,数据类型为`int32`,仅在`cmd`设置为1或2时需要填写已注册的有效混淆因子id,否则填0。
- thread_num(`int`):可选参数, `CA`/`TA`进行混淆处理使用的线程数,数据类型为`int32`,取值范围为{1, 2, 3, 4, 5, 6},仅在`cmd`设置为1或2时需要填写有效值,否则填0。
- obf_coefficient(`float`):可选参数,混淆系数,支持输入范围为0-1,默认值1.0。
输出说明:
`Tensor`
代表socket连接符,1D,shape为(1),数据类型为`int32`。
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas 推理系列产品
调用示例:
import torch
import torch_npu
device = "npu:0"
hidden_size = int(3584)
cmd = 1
data_type = torch.bfloat16
model_obf_seed = 0
data_obf_seed = 0
thread_num = 4
tp_rank = 0
i = 0
hidden_states = torch.randn((1024,3584), dtype=torch.bfloat16, device=device)
obf_cft = 1.0
fd = torch_npu.npu.obfuscation_initialize(hidden_size, tp_rank, cmd, data_type=data_type, thread_num= thread_num, obf_coefficient=obf_cft)
"""
)
_add_torch_npu_docstr(
"obfuscation_finalize",
"""
功能描述:
该接口用于完成PMCC(Privacy&Model Confidential Computing)模型混淆引擎的资源释放,即与PMCC混淆引擎CA(普通OS中的Client Application)断开socket连接。
接口原型:
torch_npu.npu.obfuscation_finalize(fd_to_close) -> Tensor
参数说明:
fd_to_close(`Tensor`):填写调用[obfuscation_initialize](./torch_npu-npu-obfuscation_initialize.md)接口的返回值,数据类型为`int32`。
输出说明:
`Tensor`
代表socket连接符,1D,shape为(1),数据类型为`int32`。
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas 推理系列产品
调用示例:
import torch
import torch_npu
device = "npu:0"
hidden_size = int(3584)
cmd = 1
data_type = torch.bfloat16
model_obf_seed = 0
data_obf_seed = 0
thread_num = 4
tp_rank = 0
i = 0
hidden_states = torch.randn((1024,3584), dtype=torch.bfloat16, device=device)
obf_cft = 1.0
fd = torch_npu.npu.obfuscation_initialize(hidden_size, tp_rank, cmd, data_type=data_type, thread_num= thread_num, obf_coefficient=obf_cft)
torch_npu.npu.obfuscation_finalize(fd)
"""
)
_add_torch_npu_docstr(
"obfuscation_calculate",
"""
功能描述:
该接口用于将张量x和配置参数(如param)发送至PMCC(Privacy&Model Confidential Computing)混淆引擎。引擎的CA(普通OS中的Client Application)模块调用TA(TEE OS中的Trusted Application)模块,进行张量混淆处理,最终返回混淆结果。
接口原型:
torch_npu.npu.obfuscation_calculate(fd, x, param, obf_coefficient) -> Tensor
参数说明:
- fd(`Tensor`):必选参数,socket连接符,数据类型为`int32`,填写调用[obfuscation_initialize](./torch_npu-npu-obfuscation_initialize.md)接口的返回值。
- x(`Tensor`):必选参数,待混淆处理的`Tensor`输入,对`Tensor`维度不作限制,shape为( , *, ... , hiddenSize),即最后一维的size是[obfuscation_initialize](./torch_npu-npu-obfuscation_initialize.md)的入参`hiddenSize`。数据格式支持ND。
* <term>Atlas 推理系列产品</term>: `Tensor`数据类型支持`float16` 、`float32`、`int8`。
* <term>Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件</term>: `Tensor`数据类型支持`float16`、`float32`、`bfloat16`、`int8`。
- param(`Tensor`):必选参数,张量`x`的最后一维的维度,数据类型为`int32`。
- obf_coefficient(`float`):可选参数,混淆系数,支持输入范围为(0.0,1.0],默认值1.0。
输出说明:
`Tensor`
代表`obfuscation_calculate`的计算结果,输出数据类型及shape与`x`相同。
支持的芯片型号:
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas 推理系列产品
调用示例:
import torch
import torch_npu
device = "npu:0"
hidden_size = int(3584)
cmd = 1
data_type = torch.bfloat16
model_obf_seed = 0
data_obf_seed = 0
thread_num = 4
tp_rank = 0
i = 0
hidden_states = torch.randn((1024,3584), dtype=torch.bfloat16, device=device)
obf_cft = 1.0
fd = torch_npu.npu.obfuscation_initialize(hidden_size, tp_rank, cmd, data_type=data_type, thread_num= thread_num, obf_coefficient=obf_cft)
param = torch.tensor([3584], device=device)
x_obf_out = torch_npu.npu.obfuscation_calculate(fd, hidden_states, param, obf_coefficient=obf_cft)
"""
)
_add_torch_npu_docstr(
"npu_gelu_mul",
"""
接口原型:
torch_npu.npu_gelu_mul(input, *, approximate='none') -> Tensor
功能描述:
将输入Tensor按照最后一个维度分为左右两个Tensor:x1和x2,对左边的x1进行Gelu计算,将计算结果与x2相乘。
计算公式:
给定输入张量 input,最后一维的长度为 2d,函数 GeluMul 进行以下计算:
(1)将 input 分割为两部分:x₁ = input[...,:d], x₂ = input[...,d:]
(2)对 x1 应用 GELU 激活函数,"tanh"模式公式如下:GELU(x) = 0.5 * x * [1 + tanh( √(2/π) * (x + 0.044715 * x³) )]
“none”对应的erf模式公式如下:GELU(x)= 0.5 * x * [1 + erf( x / √2 )]
因此,计算:x₁ = GELU(x₁)
(3)最终输出是 x₁ 和 x₂ 的逐元素乘积:out = x₁ * x₂
参数说明:
input (Tensor类型):必选参数,输入张量,数据类型支持BFLOAT16、FLOAT16、FLOAT。支持非连续的Tensor,数据格式支持ND,shape维度2至8维,
且shape满足如下要求:(1)最后一维值为偶数且小于等于1024。(2)其他维度的乘积小于等于200000。
approximate(String类型):可选参数,计算输入, Gelu计算的模式,只支持“none”和“tanh”,分别对应Gelu的erf模式和tanh模式,默认值为“none”。
输出说明:
out (Tensor):输出张量,数据类型支持BFLOAT16、FLOAT16、FLOAT。shape维度2至8维。支持非连续的Tensor,数据格式支持ND,输出的数据类型与输入保持一致,输出shape和输入shape其他维度一致,最后一维的值为输入shape最后一维值的二分之一。
约束说明
典型场景尾轴为16的倍数,当尾轴为非32B对齐时,建议走小算子拼接逻辑。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
shape = [100, 400]
mode = "none"
input = torch.rand(shape, dtype=torch.float16).npu()
output = torch_npu.npu_gelu_mul(input, approximate=mode)
"""
)
_add_torch_npu_docstr(
"npu_sparse_lightning_indexer_grad_kl_loss",
"""
接口原型:
npu_sparse_lightning_indexer_grad_kl_loss(query, key, query_index, key_index, weights, sparse_indices, softmax_max, softmax_sum, scale_value, *, query_rope=None, key_rope=None, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=2^63-1, next_tokens=2^63-1) -> (Tensor, Tensor, Tensor, Tensor)
功能描述:
该接口实现了npu_lightning_indexer的反向功能,并融合了Loss的计算。npu_lightning_indexer用于筛选Attention的query与key间最高内在联系的Top-k项,存放在sparse_indices中,以减少长序列场景下的Attention计算量,提升训练性能。
参数说明:
query(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1, D)、(T1, N1, D)。
key(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2, D)、(T2, N2, D)。
query_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1index, D)、(T1, N1index, D)。
key_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2index, D)、(T2, N2index, D)。
weights(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1)、(T1, N1)。
sparse_indices(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, topK)、(T1, topK)。
softmax_max(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, N2, S1, G)、(N2, T1, G)。
softmax_sum(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, N2, S1, G)、(N2, T1, G)。
scale_value(float):必选参数,表示缩放系数,数据类型支持FLOAT。
query_rope(Tensor):可选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1, Dr)、(T1, N1, Dr)。
key_rope(Tensor):可选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2, Dr)、(T2, N2, Dr)。
actual_seq_qlen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示query每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
actual_seq_klen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示key每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
layout(str):可选参数,用于标识输入query的数据排布格式,数据类型支持str。当前支持BSND、TND,默认值为"BSND"。
sparse_mode(int):可选参数,表示sparse的模式,数据类型支持INT32,默认值为3。
pre_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
next_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
输出说明:
d_query_index(Tensor):表示query_index的梯度,数据类型支持BFLOAT16、FLOAT16。
d_key_index(Tensor):表示key_index的梯度,数据类型支持BFLOAT16、FLOAT16。
d_weights(Tensor):表示weights的梯度,数据类型支持BFLOAT16、FLOAT16。
loss(Tensor):表示网络正向输出和golden值的差异,数据类型支持FLOAT。
支持版本:
PyTorch 2.1
PyTorch 2.5及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
def gen_inputs(seqlens_list_array, seqlens_list_kv_array, isTnd):
B = 1
NQuery = 64
NQueryIndex = 64
N2 = 1
S1 = 128
S2 = 128
topK = 2048
D = 512
DIndex = 128
DR = 64
output_dtype = torch.float16
q = torch.randn(B, S1, NQuery, D, dtype=output_dtype, device=torch.device('npu'))
k = torch.randn(B, S2, N2, D, dtype=output_dtype, device=torch.device('npu'))
q_index = torch.randn(B, S1, NQueryIndex, DIndex, dtype=output_dtype, device=torch.device('npu'))
k_index = torch.randn(B, S2, N2, DIndex, dtype=output_dtype, device=torch.device('npu'))
if DR != 0:
q_rope = torch.randn(B, S1, NQuery, DR, dtype=output_dtype, device=torch.device('npu'))
k_rope = torch.randn(B, S2, N2, DR, dtype=output_dtype, device=torch.device('npu'))
else:
q_rope = None
k_rope = None
weights = torch.randn(B, S1, NQueryIndex, dtype=output_dtype, device=torch.device('npu'))
a = -0.05 # 最小值
b = 0.05 # 最大值
kk = 3.0 # 控制分布范围(3σ 覆盖绝大多数值)
scale = (b - a) / (2 * kk)
shift = (a + b) / 2
weights = weights * scale + shift
if isTnd:
sparse_indices = torch.zeros(S1, N2, topK).to(torch.int32).npu()
tIdx = 0
for bIdx in range(B):
for s1Idx in range(seqlens_list_array[bIdx]):
s2RealSize = (int)((seqlens_list_kv_array[bIdx] - seqlens_list_array[bIdx]) + s1Idx + 1)
if s2RealSize <= 0:
s2RealSize = seqlens_list_kv_array[bIdx]
if s2RealSize > topK:
s2RealLen = topK
else:
s2RealLen = s2RealSize
#处理S2无效行场景,把对应的sparse indices置为-1
sparse_indices[tIdx, :, 0 : s2RealLen] = (torch.randint(0, s2RealSize, (s2RealLen,)).to(torch.int32)).npu()
sparse_indices[tIdx, :, s2RealLen : topK] = -1
tIdx = tIdx + 1
q_tnd = q.squeeze(dim=0)
k_tnd = k.squeeze(dim=0)
q_index_tnd = q_index.squeeze(dim=0)
k_index_tnd = k_index.squeeze(dim=0)
if q_rope is not None:
q_rope_tnd = q_rope.squeeze(dim=0)
k_rope_tnd = k_rope.squeeze(dim=0)
else :
q_rope_tnd = None
k_rope_tnd = None
weights_tnd = weights.squeeze(dim=0)
softmax_max = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=torch.device('npu'))
softmax_sum = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=torch.device('npu'))
return q_tnd, k_tnd, q_index_tnd, k_index_tnd, q_rope_tnd, k_rope_tnd, weights_tnd, sparse_indices, softmax_max, softmax_sum
else :
sparse_indices = torch.zeros(B, S1, N2, topK).to(torch.int32).npu()
for s1Idx in range(S1):
s2RealSize = (int)(S2 - S1 + s1Idx + 1)
if s2RealSize <= 0:
s2RealSize = S2
if s2RealSize > topK:
s2RealLen = topK
else:
s2RealLen = s2RealSize
sparse_indices[:, s1Idx, 0, 0 : s2RealLen] = (torch.randint(0, s2RealSize, (s2RealLen,)).to(torch.int32)).npu()
sparse_indices[:, s1Idx, 0, s2RealLen : topK] = -1
softmax_max = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=torch.device('npu'))
softmax_sum = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=torch.device('npu'))
return q, k, q_index, k_index, q_rope, k_rope, weights, sparse_indices, softmax_max, softmax_sum
actual_seq_qlen = [128]
actual_seq_kvlen = [128]
input_layout = 'TND'
isTnd = True
sparse_mode = 3
scale = 1.0
q, k, q_index, k_index, q_rope, k_rope, weights, sparse_indices, softmax_max, softmax_sum = gen_inputs(actual_seq_qlen, actual_seq_kvlen, isTnd)
torch_npu.npu_sparse_lightning_indexer_grad_kl_loss(
q, k, q_index, k_index, weights, sparse_indices, softmax_max, softmax_sum, scale,
query_rope=q_rope, key_rope=k_rope, actual_seq_qlen=actual_seq_qlen, actual_seq_klen=actual_seq_kvlen, layout=input_layout, sparse_mode=sparse_mode, pre_tokens=65536, next_tokens=65536
)
"""
)
_add_torch_npu_docstr(
"npu_sim_exponential_",
"""
接口原型:
torch_npu.npu_sim_exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
功能描述:
根据参数lambd生成指数分布随机数,并原地填充至输入张量中。
计算公式:
f(x) = -1/λ * ln(1-u), u ~ Uniform(0, 1]
参数说明:
self(Tensor):必选参数,源数据张量,公式中的f(x)。要求为连续的Tensor,数据类型支持bfloat16、float16、float32,数据格式支持ND,shape支持0~8维。
lambd(double):可选参数,指数分布的参数,公式中的λ,可配置为任意正实数,默认值为1。
generator(Generator)::可选参数,用于生成seed和offset,供aclnnSimThreadExponential算子使用,默认为None。
返回值说明:
out(Tensor):表示公式中的f(x),即原地更新后的input张量。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
shape = [100, 400]
gen = torch.Generator(device="npu")
gen.manual_seed(0)
input = torch.zeros(shape, dtype=torch.float32).npu()
torch_npu.npu_sim_exponential_(input, lambd=1, generator=gen)
"""
)
_add_torch_npu_docstr(
"npu_fused_floyd_attention",
"""
接口原型:
npu_fused_floyd_attention(Tensor query_ik, Tensor key_ij, Tensor value_ij, Tensor key_jk, Tensor value_jk, *, Tensor? atten_mask=None, float scale_value=1.) -> (Tensor, Tensor, Tensor)
功能描述:
训练场景下,FloydAttn相较于传统FA主要是计算qk/pv注意力时会额外将seq作为batch轴从而转换为batchMatmul。
计算公式:
P=Softmax(Mask(scale*(query_ik * key_ij^T + query_ik * key_jk^T), atten_mask))
attention_out=P * value_ij + P * value_jk
参数说明:
query_ik (Tensor类型):必选参数,输入张量,数据类型支持BFLOAT16、FLOAT16。数据类型与key_ij/value_ij/key_jk/value_jk的数据类型一致,数据格式支持ND,输入shape支持[BHMND]。
key_ij (Tensor类型):必选参数,输入张量,代表从节点 i 到其直接邻居 j 的关系或特征,数据类型支持BFLOAT16、FLOAT16。数据类型与query_ik/value_ij/key_jk/value_jk的数据类型一致,数据格式支持ND,输入shape支持[BHMKD]。
value_ij (Tensor类型):必选参数,输入张量,代表从节点 i 到其直接邻居 j的信息内容,数据类型支持BFLOAT16、FLOAT16。数据类型与query_ik/key_ij/key_jk/value_jk的数据类型一致,数据格式支持ND,输入shape支持[BHKND]。
key_jk (Tensor类型):必选参数,输入张量,代表从直接邻居 j 到支点 k 的关系或特征,数据类型支持FLOAT16、BFLOAT16,数据类型与query_ik/key_ij/value_ij/value_jk的数据类型一致,数据格式支持ND,输入shape支持[BHKND]。
value_jk (Tensor类型):必选参数,输入张量,代表从节点 j到其直接邻居 k的信息内容,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据类型与query_ik/key_ij/value_ij/key_jk的数据类型一致,数据格式支持ND,输入shape支持[BHMKD]。
atten_mask (Tensor类型):可选张量,数据类型支持BOOL、UINT8,数据格式支持ND,输入shape类型需为[BNK],默认值为None。
scale_value (float):代表缩放系数,数据类型支持float。一般设置为D^-0.5。
输出说明:
softmax_max_out (Tensor):输出张量,Softmax计算的Max中间结果,用于反向计算。数据类型支持FLOAT,输出的shape类型为[BHMN8]。数据格式支持ND。
softmax_sum_out (Tensor):输出张量,Softmax计算的Sum中间结果,用于反向计算。数据类型支持FLOAT,输出的shape类型为[BHMN8]。数据格式支持ND。
attention_out (Tensor):输出张量,计算公式的最终输出。数据类型支持FLOAT16、BFLOAT16。数据类型和shape类型与query_ik保持一致,数据格式支持ND,输入shape支持[BHMND]。
约束说明
关于数据shape的约束,其中:
B:取值范围为1~2K。
H:取值范围为1~256。
M:取值范围为1~1M。
N:取值范围为1~1M。
K:取值范围为1~1M。
D:取值范围为32~256。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
import math
def truncated_normal(mean, std, min, max, size):
x = torch.normal(mean, std, size)
x = torch.where((x < min) | (x > max), torch.tensor(0.0), x)
return x
B, N, S1, S2, S3, D = 1, 1, 16, 256, 256, 64
Q = truncated_normal(0.0, 1, -10, 10, (B, N, S1, S2, D)).to(torch.bfloat16).npu()
K1 = truncated_normal(0.0, 1, -10, 10, (B, N, S1, S3, D)).to(torch.bfloat16).npu()
K2 = truncated_normal(0.0, 1, -10, 10, (B, N, S3, S2, D)).to(torch.bfloat16).npu()
V1 = truncated_normal(0.0, 1, -10, 10, (B, N, S1, S3, D)).to(torch.bfloat16).npu()
V2 = truncated_normal(0.0, 1, -10, 10, (B, N, S3, S2, D)).to(torch.bfloat16).npu()
atten_mask = torch.randint(0, 2, [B, 1, S1, 1, S3]).to(torch.bool).npu()
scale = 1.0/math.sqrt(D)
x_max_npu, x_sum_npu, output_npu = torch_npu.npu_fused_floyd_attention(
Q,
K1,
V1,
K2,
V2,
atten_mask = atten_mask,
scale_value = scale
)
"""
)
_add_torch_npu_docstr(
"npu_dense_lightning_indexer_softmax_lse",
"""
接口原型:
npu_dense_lightning_indexer_softmax_lse(query_index, key_index, weights, *, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807) -> (Tensor, Tensor)
功能描述:
是npu_dense_lightning_indexer_grad_kl_loss接口的前置接口,通过把Lightning Indexer组件的Softmax求最大值和求和运算提前来降低接口的显存占用。
参数说明:
query_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1index, D)、(T1, N1index, D)。
key_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2index, D)、(T2, N2index, D)。
weights(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16、FLOAT。支持输入shape(B, S1, N1index)、(T1, N1index)。
actual_seq_qlen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示query每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
actual_seq_klen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示key每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
layout(str):可选参数,用于标识输入query的数据排布格式,数据类型支持str。当前支持BSND、TND,默认值为"BSND"。
sparse_mode(int):可选参数,表示sparse的模式,数据类型支持INT32,默认值为3。
pre_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
next_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
输出说明:
softmax_max_index(Tensor):表示softmax计算使用的max值,数据类型支持FLOAT。
softmax_sum_index(Tensor):表示softmax计算使用的sum值,数据类型支持FLOAT。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
B = 1
N1 = 64
N2 = 1
S1 = 128
S2 = 256
D = 128
output_dtype = torch.float16
q_index = torch.randn(B, S1, N1, D, dtype=output_dtype, device=torch.device('npu'))
k_index = torch.randn(B, S2, N2, D, dtype=output_dtype, device=torch.device('npu'))
weights = torch.randn(B, S1, N1, dtype=output_dtype, device=torch.device('npu'))
actual_seq_qlen = None
actual_seq_klen = None
input_layout = 'BSND'
sparse_mode = 3
softmax_max_index, softmax_sum_index = torch_npu.npu_dense_lightning_indexer_softmax_lse(q_index, k_index, weights, actual_seq_qlen=actual_seq_qlen, actual_seq_klen=actual_seq_klen, layout=input_layout, sparse_mode=sparse_mode)
"""
)
_add_torch_npu_docstr(
"npu_dense_lightning_indexer_grad_kl_loss",
"""
接口原型:
npu_dense_lightning_indexer_grad_kl_loss(query, key, query_index, key_index, weights, softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, scale_value, *, query_rope=None, key_rope=None, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807) -> (Tensor, Tensor, Tensor, Tensor)
功能描述:
该接口实现了Lightning Indexer组件warmup阶段训练的反向梯度计算,并融合了Loss的计算。
参数说明:
query(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1, D)、(T1, N1, D)。
key(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2, D)、(T2, N2, D)。
query_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1index, D)、(T1, N1index, D)。
key_index(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2index, D)、(T2, N2index, D)。
weights(Tensor):必选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16、FLOAT。支持输入shape(B, S1, N1index)、(T1, N1index)。
softmax_max(Tensor):必选参数,数据格式支持ND,数据类型支持FLOAT。支持输入shape(B, N2, S1, G)、(N2, T1, G)。
softmax_sum(Tensor):必选参数,数据格式支持ND,数据类型支持FLOAT。支持输入shape(B, N2, S1, G)、(N2, T1, G)。
softmax_max_index(Tensor):必选参数,数据格式支持ND,数据类型支持FLOAT。支持输入shape(B, N2index, S1)、(N2index, T1)。
softmax_sum_index(Tensor):必选参数,数据格式支持ND,数据类型支持FLOAT。支持输入shape(B, N2index, S1)、(N2index, T1)。
scale_value(float):必选参数,表示缩放系数,数据类型支持FLOAT。
query_rope(Tensor):可选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S1, N1, Dr)、(T1, N1, Dr)。
key_rope(Tensor):可选参数,数据格式支持ND,数据类型支持BFLOAT16、FLOAT16。支持输入shape(B, S2, N2, Dr)、(T2, N2, Dr)。
actual_seq_qlen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示query每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
actual_seq_klen(int[]):可选参数,int类型数组,TND场景时需传入此参数。表示key每个S的累加和长度,数据类型支持INT64,数据格式支持ND,默认值为None。
layout(str):可选参数,用于标识输入query的数据排布格式,数据类型支持str。当前支持BSND、TND,默认值为"BSND"。
sparse_mode(int):可选参数,表示sparse的模式,数据类型支持INT32,默认值为3。
pre_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
next_tokens(int):可选参数,数据类型支持INT64,默认值2^63-1。
输出说明:
d_query_index(Tensor):表示query_index的梯度,数据类型支持BFLOAT16、FLOAT16。
d_key_index(Tensor):表示key_index的梯度,数据类型支持BFLOAT16、FLOAT16。
d_weights(Tensor):表示weights的梯度,数据类型支持BFLOAT16、FLOAT16。
loss(Tensor):表示网络正向输出和golden值的差异,数据类型支持FLOAT。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A2训练系列产品
Atlas A3训练系列产品
调用示例:
import torch
import torch_npu
B = 1
N1 = 64
N2 = N1
N1_index = 64
N2_index = 1
S1 = 128
S2 = 256
D = 128
Dr = 64
output_dtype = torch.float16
q = torch.randn(B, S1, N1, D, dtype=output_dtype, device=torch.device('npu'))
k = torch.randn(B, S2, N2, D, dtype=output_dtype, device=torch.device('npu'))
q_index = torch.randn(B, S1, N1_index, D, dtype=output_dtype, device=torch.device('npu'))
k_index = torch.randn(B, S2, N2_index, D, dtype=output_dtype, device=torch.device('npu'))
q_rope = torch.randn(B, S1, N1, Dr, dtype=output_dtype, device=torch.device('npu'))
k_rope = torch.randn(B, S2, N2, Dr, dtype=output_dtype, device=torch.device('npu'))
weights = torch.randn(B, S1, N1_index, dtype=output_dtype, device=torch.device('npu'))
softmax_max = (torch.randn(B, N2, S1, 1, dtype=torch.float32, device=torch.device('npu')).abs() + 0.4) * D
softmax_sum = torch.ones(B, N2, S1, 1, dtype=torch.float32, device=torch.device('npu'))
actual_seq_qlen = None
actual_seq_klen = None
input_layout = 'BSND'
sparse_mode = 3
scale = 1.0
softmax_max_index, softmax_sum_index = torch_npu.npu_dense_lightning_indexer_softmax_lse(q_index, k_index, weights, actual_seq_qlen=actual_seq_qlen, actual_seq_klen=actual_seq_klen, layout=input_layout, sparse_mode=sparse_mode)
torch_npu.npu_dense_lightning_indexer_grad_kl_loss(q, k, q_index, k_index, weights, softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, scale, query_rope=q_rope, key_rope=k_rope, actual_seq_qlen=actual_seq_qlen, actual_seq_klen=actual_seq_klen, layout=input_layout, sparse_mode=sparse_mode)
"""
)
_add_torch_npu_docstr(
"_lstm_npu",
"""
接口原型:
_lstm_npu(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, bool? batch_first=False, Tensor? batch_sizes=None) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述:
LSTM(Long Short-Term Memory,长短时记忆)网络是一种特殊的循环神经网络(RNN)模型。进行LSTM网络计算,接收输入序列和初始状态,返回输出序列和最终状态。
计算公式:
$$
\begin{aligned}
(1)\qquad f_t &=\sigma(W_f[h_{t-1}, x_t] + b_f) \\
(2)\qquad i_t &=\sigma(W_i[h_{t-1}, x_t] + b_i) \\
(3)\qquad o_t &=\sigma(W_o[h_{t-1}, x_t] + b_o) \\
(4)\qquad \tilde{c}_t &=tanh(W_c[h_{t-1}, x_t] + b_c) \\
(5)\qquad c_t &=f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c}_t \\
(6)\qquad c_{o}^{t} &=tanh(c_t) \\
(7)\qquad h_t &=o_t ⊙ c_{o}^{t} \\
\end{aligned}
$$
- $x_t ∈ R^{d}$:LSTM单元的输入向量。
- $f_t ∈ (0, 1)^{h}$:遗忘门激活向量。
- $i_t ∈ (0, 1)^{h}$:输入门、更新门激活向量。
- $o_t ∈ (0, 1)^{h}$:输出门激活向量。
- $h_i ∈ (-1, 1)^{h}$:隐藏状态向量,也称为LSTM单元的输出向量。
- $\tilde{c}_t ∈ (-1, 1)^{h}$:cell输入激活向量。
- $c_t ∈ R^{h}$:cell状态向量。
- $W ∈ R^{h×d},(U ∈ R^{h×h})∩(b ∈ R^{h})$:训练中需要学习的权重矩阵和偏置向量参数。
参数说明:
input (Tensor类型):必选参数,LSTM单元的输入向量,数据格式支持ND,数据类型支持FLOAT16、FLOAT32,非连续Tensor。
(1)若batch_sizes传入空指针:当batch_first=False时shape应为(time_step, batch_size, input_size), 否则为(batch_size, time_step, input_size)。其中,batch_first表示batch是否在第一维;time_step表示时间维度;batch_size表示每个时刻需要处理的batch数量;input_size表示输入的特征数量。
(2)若传入有效batch_sizes:shape应为(time_step * batch_size, input_size),其内存排列与(time_step, batch_size, input_size)相同。
hx (TensorList):必选入参,表示LSTM运算中的初始hidden和cell状态列表,Device侧的TensorList,数据类型支持FLOAT16、FLOAT32,列表长度为2,列表中每个shape支持三维(D * num_layers, batch_size, hidden_size),若输入为空,则表示输入的初始hidden和cell状态为0。
params (TensorList):必选入参,表示LSTM运算中的权重和偏置张量列表,Device侧的TensorList,数据格式支持ND,数据类型支持FLOAT16、FLOAT32。
列表长度为 2 * D * B * num_layers, 其中,num_layers对应参数numLayers,表示LSTM层数,bidirection为True时 D = 2, 否则 D = 1, has_biases为True时 B = 2, 否则 B = 1;
其中bidirection为True, 且has_biases为True时,参数排布如下:[weight_ih_0, weight_hh_0, bias_ih_0, bias_hh_0, weight_ih_reverse_0, weight_hh_reverse_0, bias_ih_reverse_0, bias_hh_reverse_0],
其中 weight_ih_0 表示第0层输入的权重参数,其shape为(4 * hidden_size, cur_input_size),其中cur_intput_size 表示LSTM每层计算时的输入的特征数量(首层为input_size, 后续层为hidden_size, 如果bidirection为True,则为2 * hidden_size);
weight_hh_0 表示第0层隐藏层的权重参数,其shape为(4 * hidden_size, hidden_size),bias_ih_0 表示第0层输入权重参数的偏置,其shape为(4 * hidden_size),bias_hh_0 表示第0层隐藏层权重参数的偏置,其shape为(4 * hidden_size)。
has_biases (bool):必选入参,表示是否有biases。
num_layers (int):必选入参,表示LSTM层数。
dropout (float):表示随机掩码的概率。当前不支持该功能。
train (bool):必选入参,表示是否是训练模式。其中train = True时,在计算前向LSTM时会保存中间结果用于反向传播,train = False的时候,前向计算过程不保存中间结果。
bidirectional(bool):必选入参,表示是否是双向。
batch_first(bool):可选入参,表示输入数据格式是否是Batch在第一轴(B, T, H)。
batch_sizes(Tensor类型):可选参数,表示每个时间步实际参与计算的有效Batch数。传入nullptr时,代表输入input为定长模式数据,否则为不定长模式。shape为(time_step,),其中元素应按降序排列,元素值为正整数且最大不超过总Batch数量,且第一位元素值应与总Batch数量相等。
输出说明:
output (Tensor):输出张量,表示LSTM运算中最后一层每个时间步的输出结果,数据类型支持FLOAT16、FLOAT32,非连续Tensor。当batch_first=False时shape支持三维(time_step, batch_size, D * hidden_size),否则支持三维(batch_size,time_step, D * hidden_size)。
hy (Tensor):输出张量,表示进行LSTM运算中每层最后一个时间步的隐藏层(公式(7)的输出)。shape支持三维(D * num_layers, batch_size, hidden_size)。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。
cy (Tensor):输出张量,表示进行LSTM运算中每层最后一个时间步的Cell状态(公式(5)的输出)。shape支持三维(D * num_layers, batch_size, hidden_size)。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。
i_out(Tensor):输出张量,表示LSTM运算中每层输入门的激活值(sigmoid输出,公式(2)的输出),数据类型支持FLOAT16、FLOAT32。原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);
最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);当train=False时,输出为空张量。
j_out(Tensor):输出张量,表示LSTM运算中每层的候选cell状态(tanh输出,公式(4)的输出),数据类型支持FLOAT16、FLOAT32。原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);
最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);当train=False时,输出为空张量。
f_out(Tensor):输出张量,表示LSTM运算中每层遗忘门的激活值(sigmoid输出),数据类型支持FLOAT16、FLOAT32。原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);
最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);当train=False时,输出为空张量。
o_out(Tensor):输出张量,表示LSTM运算中每层输出门的激活值(sigmoid输出,公式(3)的输出),数据类型支持FLOAT16、FLOAT32。原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);
最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);当train=False时,输出为空张量。
h_out(Tensor):输出张量,表示LSTM运算中每层的隐藏层(公式(7)的输出),数据类型支持FLOAT16、FLOAT32。train=True时:原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);
train=False时:原始列表长度为 D * num_layers,列表中每个元素为二维张量(batch_size, hidden_size);最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, batch_size, hidden_size);
c_out(Tensor):输出张量,表示LSTM运算中每层的最终Cell状态(公式(5)的输出),数据类型支持FLOAT16、FLOAT32。train=True时:原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);
train=False时:原始列表长度为 D * num_layers,列表中每个元素为二维张量(batch_size, hidden_size);最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, batch_size, hidden_size);
tanh_c_out(Tensor):输出张量,表示LSTM运算中每层最终cell状态经过tanh激活函数后的输出(公式(6)的输出),数据类型支持FLOAT16、FLOAT32。原始列表长度为 D * num_layers,列表中每个元素为三维张量(time_step, batch_size, hidden_size);最终输出为将列表张量在维度0上堆叠后的张量,shape为 (D * num_layers, time_step, batch_size, hidden_size);当train=False时,输出为空张量。
约束说明
确定性计算:_lstm_npu默认支持确定性实现。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品
Atlas 训练系列产品
调用示例:
import torch
import torch_npu
torch.npu.set_device(0)
device = torch.device("npu:0")
dtype = torch.float32
input_tensor = torch.randn((2, 3, 8), dtype=dtype, device=device)
h0 = torch.randn((2, 2, 4), dtype=dtype, device=device)
c0 = torch.randn((2, 2, 4), dtype=dtype, device=device)
hx = [h0, c0]
hidden_size = 4
gate_size = 4 * hidden_size
weight_ih = torch.randn((gate_size, 8), dtype=dtype, device=device)
weight_hh = torch.randn((gate_size, 4), dtype=dtype, device=device)
params = [weight_ih, weight_hh, weight_hh, weight_hh]
c0 = torch.randn((1, 2, 16), dtype=dtype, device=device)
batch_sizes = torch.randn((4*hidden_size), dtype=dtype, device=device)
input_tensor = input_tensor.to(device, dtype=dtype)
hx = [
t.to(device=device, dtype=dtype)
for t in hx
]
params = [
p.to(device=device, dtype=dtype)
for p in params
]
out, out_h, out_c, _, _, _, _, hn_list, cn_list, _ = torch_npu._lstm_npu(
input_tensor, # 1. input
hx, # 2. hx
params, # 3. params
False, # 4. has_biases
2, # 5. num_layers
0.0, # 6. dropout
False, # 7. train
False, # 8. bidirectional
batch_first=True, # 9. batch_first(位置参数,非关键字)
batch_sizes=None # 10. batch_sizes
)
"""
)
_add_torch_npu_docstr(
"_lstm_npu_backward",
"""
接口原型:
_lstm_npu_backward(Tensor grad_y, Tensor grad_hy, Tensor grad_cy, Tensor input, Tensor[] hx, Tensor[] params, Tensor i, Tensor j, Tensor f, Tensor o, Tensor h, Tensor c, Tensor tanhc, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, bool? batch_first=False, Tensor? batch_sizes=None) -> (Tensor, Tensor[], Tensor[])
功能描述:
LSTM的反向传播,计算正向输入input、权重params、初始状态hx的梯度。
计算公式:
<details>
<summary> 单层LSTM反向传播计算公式</summary>
| 组件 | 公式 |
|:---|:---|
| 输入拼接 | $\mathbf{z}_t = \begin{bmatrix} \mathbf{h}_{t-1} \\ \mathbf{x}_t \end{bmatrix}$ |
| 遗忘门 | $\mathbf{f}_t = \sigma(\mathbf{W}_f \mathbf{z}_t + \mathbf{b}_f)$ |
| 输入门 | $\mathbf{i}_t = \sigma(\mathbf{W}_i \mathbf{z}_t + \mathbf{b}_i)$ |
| 候选状态 | $\mathbf{g}_t = \tanh(\mathbf{W}_g \mathbf{z}_t + \mathbf{b}_c)$ |
| 输出门 | $\mathbf{o}_t = \sigma(\mathbf{W}_o \mathbf{z}_t + \mathbf{b}_o)$ |
| 细胞状态 | $\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \mathbf{g}_t$ |
| 隐藏状态 | $\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$ |
其中:
- $\sigma$ 是 sigmoid 函数
- $\odot$ 表示逐元素乘法 (Hadamard product)
- $W_*$ 是可学习的权重矩阵
- $b_*$ 是可学习的偏置项
</details>
<details>
<summary> 反向传播变量定义</summary>
- 总损失:$L = \sum_{t=1}^{T} L_t$
- 隐藏状态梯度:$\delta\mathbf{h}_t = \frac{\partial L}{\partial \mathbf{h}_t}$
- 细胞状态梯度:$\delta\mathbf{c}_t = \frac{\partial L}{\partial \mathbf{c}_t}$
</details>
<details>
<summary> 反向传播算法(时间步 t -> t-1)</summary>
- **初始化**
$$
\delta\mathbf{h}_{T} = \mathbf{0}, \quad \delta\mathbf{c}_{T} = \mathbf{0}, \quad \mathbf{f}_{T} = \mathbf{0}
$$
- **循环 $t = T - 1$ 到 $0$**
1.**当前隐藏状态梯度**
$$
\delta\mathbf{h}_t = \frac{\partial L_t}{\partial \mathbf{h}_t} + \delta\mathbf{h}_{\text{next}}
$$
2.**当前细胞状态梯度**
$$
\delta\mathbf{c}_t = \delta\mathbf{h}_t \odot \mathbf{o}_t \odot (1 - \tanh^2(\mathbf{c}_t)) + \delta\mathbf{c}_{\text{next}} \odot \mathbf{f}_{\text{next}}
$$
3.**门控梯度计算**
$$
\delta\mathbf{o}_t = \delta\mathbf{h}_t \odot \tanh(\mathbf{c}_t) \odot \mathbf{o}_t \odot (1 - \mathbf{o}_t)
$$
$$
\delta\mathbf{g}_t = \delta\mathbf{c}_t \odot \mathbf{i}_t \odot (1 - \mathbf{g}_t^2)
$$
$$
\delta\mathbf{i}_t = \delta\mathbf{c}_t \odot \mathbf{g}_t \odot \mathbf{i}_t \odot (1 - \mathbf{i}_t)
$$
$$
\delta\mathbf{f}_t = \delta\mathbf{c}_t \odot \mathbf{c}_{t-1} \odot \mathbf{f}_t \odot (1 - \mathbf{f}_t)
$$
4.**参数梯度累加**
$$
\frac{\partial L}{\partial \mathbf{W}_f} \mathrel{+}= \delta\mathbf{f}_t \mathbf{z}_t^\top
$$
$$
\frac{\partial L}{\partial \mathbf{b}_f} \mathrel{+}= \delta\mathbf{f}_t
$$
$$
\frac{\partial L}{\partial \mathbf{W}_i} \mathrel{+}= \delta\mathbf{i}_t \mathbf{z}_t^\top
$$
$$
\frac{\partial L}{\partial \mathbf{b}_i} \mathrel{+}= \delta\mathbf{i}_t
$$
$$
\frac{\partial L}{\partial \mathbf{W}_g} \mathrel{+}= \delta\mathbf{g}_t \mathbf{z}_t^\top
$$
$$
\frac{\partial L}{\partial \mathbf{b}_g} \mathrel{+}= \delta\mathbf{g}_t
$$
$$
\frac{\partial L}{\partial \mathbf{W}_o} \mathrel{+}= \delta\mathbf{o}_t \mathbf{z}_t^\top
$$
$$
\frac{\partial L}{\partial \mathbf{b}_o} \mathrel{+}= \delta\mathbf{o}_t
$$
5.**传播到前一时刻**
$$
\delta\mathbf{z}_t = \mathbf{W}_f^\top \delta\mathbf{f}_t + \mathbf{W}_i^\top \delta\mathbf{i}_t + \mathbf{W}_g^\top \delta\mathbf{g}_t + \mathbf{W}_o^\top \delta\mathbf{o}_t
$$
$$
\delta\mathbf{h}_{\text{prev}} = \delta\mathbf{z}_t[1:\dim(\mathbf{h}_{t-1})]
$$
$$
\delta\mathbf{c}_{\text{prev}} = \delta\mathbf{c}_t \odot \mathbf{f}_t
$$
6.**更新传播变量**
$$
\delta\mathbf{h}_{\text{next}} \leftarrow \delta\mathbf{h}_{\text{prev}}
$$
$$
\delta\mathbf{c}_{\text{next}} \leftarrow \delta\mathbf{c}_{\text{prev}}
$$
$$
\mathbf{f}_{\text{next}} \leftarrow \mathbf{f}_t
$$
</details>
<details>
<summary> 梯度计算原理</summary>
- **细胞状态梯度推导**
$$
\delta\mathbf{c}_t = \frac{\partial L}{\partial \mathbf{h}_t} \frac{\partial \mathbf{h}_t}{\partial \mathbf{c}_t} + \frac{\partial L}{\partial \mathbf{c}_{t+1}} \frac{\partial \mathbf{c}_{t+1}}{\partial \mathbf{c}_t}
$$
其中:
$$
\frac{\partial \mathbf{h}_t}{\partial \mathbf{c}_t} = \mathbf{o}_t \odot (1 - \tanh^2(\mathbf{c}_t))
$$
$$
\frac{\partial \mathbf{c}_{t+1}}{\partial \mathbf{c}_t} = \mathbf{f}_{t+1}
$$
- **遗忘门梯度推导**
$$
\delta\mathbf{f}_t = \frac{\partial L}{\partial \mathbf{a}_f^t} = \delta\mathbf{c}_t \odot \mathbf{c}_{t-1} \odot \mathbf{f}_t \odot (1 - \mathbf{f}_t)
$$
- **参数梯度推导**
$$
\frac{\partial L}{\partial \mathbf{W}_f} = \sum_{t=1}^{T} \delta\mathbf{f}_t \mathbf{z}_t^\top
$$
- **LSTM 梯度流动特性**
**长程依赖处理**
$$
\frac{\partial \mathbf{c}_T}{\partial \mathbf{c}_1} = \prod_{k=2}^{T} \mathbf{f}_k \quad \text{(对角矩阵)}
$$
</details>
<details>
<summary> 多层LSTMBackward反向传播</summary>
在多层LSTM网络中,层与层之间的梯度传播仅关注隐藏状态的传递(忽略单层内部细节,如门控机制或单元状态)。设:
- $\mathbf{h}^{(l)}$:第 $l$ 层的隐藏状态($l = 1, 2, \dots, L$,其中 $L$ 为总层数)
- $L$:损失函数
- $\frac{\partial L}{\partial \mathbf{h}^{(l)}}$:损失函数对第 $l$ 层隐藏状态的梯度
**核心传播公式**
梯度从顶层($l = L$)向底层($l = 1$)传播,层间关系由链式法则给出:
$$
\frac{\partial L}{\partial \mathbf{h}^{(l-1)}} = \frac{\partial L}{\partial \mathbf{h}^{(l)}} \cdot \frac{\partial \mathbf{h}^{(l)}}{\partial \mathbf{h}^{(l-1)}}
$$
其中:
- $\frac{\partial L}{\partial \mathbf{h}^{(l)}}$:当前层 $l$ 的梯度(已由上一层反向传播得到)
- $\frac{\partial \mathbf{h}^{(l)}}{\partial \mathbf{h}^{(l-1)}}$:第 $l$ 层隐藏状态对第 $l-1$ 层隐藏状态的雅可比矩阵
- $\cdot$:矩阵乘法(梯度传播本质为向量-矩阵乘法)
即每层的输出的梯度dx为上一层输入的梯度dy。
</details>
参数说明:
grad_y (Tensor类型):必选参数,LSTM正向最后一层输出hidden的梯度。对应公式中的∂L/∂h^(l)。双向时数据沿最后一维按前后向排布。数据类型与input一致。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。
若传入有效batchSizesOptional,shape为[time_step * batch_size, hidden_size * D];若传入空指针batchSizesOptional,shape为[time_step, batch_size, hidden_size * D] 或 [batch_size, time_step, hidden_size * D]。
grad_hy (Tensor类型):必选入参,LSTM正向每层输出hidden在T时刻从下一个时间步传来的梯度。对应δh_next。多层双向时数据沿第0维按先双向后逐层排布。数据类型与input一致。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor,shape为[numLayers * D, batch_size, hidden_size]。
grad_cy (Tensor类型):必选入参,LSTM每层输出cell在T时刻从下一个时间步传来的梯度。对应δc_next。多层双向时数据沿第0维按先双向后逐层排布。数据类型与input一致。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。shape为[numLayers * D, batch_size, hidden_size]。
input (Tensor类型):必选参数,LSTM单元的输入向量,数据格式支持ND,数据类型支持FLOAT16、FLOAT32,非连续Tensor。若传入有效batchSizesOptional,shape为[time_step * batch_size, input_size];
若传入空指针batchSizesOptional,shape为[time_step, batch_size, input_size] 或 [batch_size, time_step, input_size],batch_size表示序列组数;time_step表示时间维度;input_size表示输入的特征数量。
hx (TensorList):必选入参,LSTM每层的初始hidden和cell状态。对应0时刻的h(t-1)与c(t-1)。列表长度为2,包含h_0和c_0。多层双向时每个tensor数据沿第0维按先双向后逐层排布。数据类型与input一致。数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。列表内每个tensor shape为[D * num_layers, batch_size, hidden_size]。
params (TensorList):必选入参,LSTM每层的权重和偏置张量列表,对应公式中的w与b。bidirection为True时 `D = 2`,否则 `D = 1`,hasBiases为True时 `B = 2`,否则 `B = 1`。列表长度为 D * B * num_layers。当bidirection和hasBias均为True时排布为:[weight_ih_0, weight_hh_0, bias_ih_0, bias_hh_0, weight_ih_reverse_0, weight_hh_reverse_0, bias_ih_reverse_0, bias_hh_reverse_0]。hasBias为False时无bias项;bidirection为False时无reverse项。多层时逐层排布。数据类型与input一致。
数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。weight_ih: [4*hidden_size, cur_input_size];weight_hh: [4*hidden_size, hidden_size];bias_ih: [4*hidden_size];bias_hh: [4*hidden_size]。
i (Tensor类型):必选入参,LSTM正向中每层输出的输入门的激活值。对应公式中的i。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
j (Tensor类型):必选入参,LSTM正向中每层输出的候选cell状态的激活值。对应公式中的g。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
f (Tensor类型):必选入参,LSTM正向中每层遗忘门的激活值。对应公式中的f。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
o (Tensor类型):必选入参,LSTM正向中每层输出门的激活值。对应公式中的o。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
h (Tensor类型):必选入参,LSTM正向中每层的隐藏hidden状态。对应公式中的h。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
c (Tensor类型):必选入参,LSTM正向中每层的最终cell状态。对应公式中的c。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
tanhc (Tensor类型):必选入参,LSTM正向中每层最终cell状态经过tanh激活函数后的输出。对应公式中的tanh(c)。
原始输入形态:为Tensor列表,列表长度为 D * num_layers;多层双向场景下,列表内Tensor按「先双向、后多层」的顺序排布;
列表中单个Tensor特征:非连续Tensor,数据类型与input一致,支持FLOAT16、FLOAT32;shape为[time_step, batch_size, hidden_size];
代码处理后形态:将列表内Tensor在维度0上堆叠为单个大Tensor,最终shape为[D * num_layers, time_step, batch_size, hidden_size];若列表为空则输出空Tensor。
has_biases(bool):必选入参,表示是否有偏置b。
num_layers(int):必选入参,表示LSTM层数。值大于0。
dropout (float):表示随机掩码的概率。当前不支持该功能。
train (bool):必选入参,表示是否是训练模式。其中train = True时,在计算前向LSTM时会保存中间结果用于反向传播,train = False的时候,前向计算过程不保存中间结果。
bidirectional(bool):必选入参,表示是否是双向。bidirectional为true时D为2,bidirectional为false时D为1。
batch_first(bool):可选入参,表示输入数据格式是否是Batch在第一轴(B, T, H)。
batch_sizes(Tensor类型):可选参数,变长LSTM输入序列各个时刻的有效序列batch数。变长序列时支持。shape为[time_step]。
输出说明:
dx_out (Tensor):输出张量,输入input上的梯度,对应公式中的δx。shape与input一致。数据类型与input一致。shape为[time_step, batch_size, input_size] 或 [batch_size, time_step, input_size],数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。
out_hx_prev (TensorList):输出张量列表,由张量dh_prev_out和dc_prev_out拼接结果,列表长度为2。dh_prev_out是LSTM每层初始hidden的梯度,对应t=0时的δh_prev。数据类型与input一致。shape为[D * num_layers, batch_size, hidden_size],数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。dc_prev_out是多层双向时数据沿第0维按先双向后逐层排布。数据类型与input一致。shape为[D * num_layers, batch_size, hidden_size],数据格式支持ND,数据类型支持FLOAT16、FLOAT32。非连续Tensor。
dparams_out(TensorList):输出张量列表,权重和偏置的梯度张量列表。对应公式中的δw和δb。列表长度为 D * B * num_layers。排布与输入params一致。数据类型与input一致。dweight_ih: [4*hidden_size, cur_input_size],dweight_hh: [4*hidden_size, hidden_size],dbias: [4*hidden_size]
约束说明
确定性计算:_lstm_npu_backward默认支持确定性实现。
支持版本:
PyTorch 2.6及更高版本
支持的型号:
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品
Atlas 训练系列产品
调用示例:
import torch
import torch_npu
def _lstm_backward_npu(
grad_y: torch.Tensor,
grad_hy: torch.Tensor,
grad_cy: torch.Tensor,
input: torch.Tensor,
hx: List[torch.Tensor],
params: List[torch.Tensor],
i: torch.Tensor,
j: torch.Tensor,
f: torch.Tensor,
o: torch.Tensor,
h: torch.Tensor,
c: torch.Tensor,
tanhc: torch.Tensor,
has_biases: bool = True,
num_layers: int = 1,
dropout: float = 0.0,
train: bool = True,
bidirectional: bool = False,
batch_first: Optional[bool] = None,
batch_sizes: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]:
# 1. 强制所有张量转到NPU设备,且确保类型为float32(解决double dtype警告)
npu_device = torch.device('npu:0')
dtype = torch.float32 # NPU不支持double,统一用float32
# 类型+设备转换
def to_npu_tensor(t: torch.Tensor) -> torch.Tensor:
if t is None:
return t
return t.to(device=npu_device, dtype=dtype)
grad_y = to_npu_tensor(grad_y)
grad_hy = to_npu_tensor(grad_hy)
grad_cy = to_npu_tensor(grad_cy)
input = to_npu_tensor(input)
i = to_npu_tensor(i)
j = to_npu_tensor(j)
f = to_npu_tensor(f)
o = to_npu_tensor(o)
h = to_npu_tensor(h)
c = to_npu_tensor(c)
tanhc = to_npu_tensor(tanhc)
hx = [to_npu_tensor(t) for t in hx]
params = [to_npu_tensor(p) for p in params]
if batch_sizes is not None:
batch_sizes = to_npu_tensor(batch_sizes)
# 2. 处理batch_first默认值(接口声明中默认False)
if batch_first is None:
batch_first = False
# 3. 核心调用:严格匹配接口声明的入参类型
# 关键:i/j/f/o/h/c/tanhc直接传原始Tensor,不做chunk拆分!
output = torch_npu._lstm_backward_npu(
grad_y, # 0: Tensor grad_y
grad_hy, # 1: Tensor grad_hy
grad_cy, # 2: Tensor grad_cy
input, # 3: Tensor input
hx, # 4: Tensor[] hx (Python列表)
params, # 5: Tensor[] params (Python列表)
i, # 6: Tensor i (核心修复:传单个Tensor,非tuple)
j, # 7: Tensor j
f, # 8: Tensor f
o, # 9: Tensor o
h, # 10: Tensor h
c, # 11: Tensor c
tanhc, # 12: Tensor tanhc
has_biases, # 13: bool has_biases
num_layers, # 14: int num_layers
dropout, # 15: float dropout
train, # 16: bool train
bidirectional, # 17: bool bidirectional
batch_first=batch_first, # 18: bool? batch_first(关键字参数,接口声明要求)
batch_sizes=batch_sizes # 19: Tensor? batch_sizes(关键字参数)
)
# 4. 解析输出
input_grad = output[0] # 输入序列梯度
hx_prev_grad = output[1] # 初始hx梯度
param_grads_list = output[2] # 参数梯度列表
return input_grad, hx_prev_grad, param_grads_list
if __name__ == "__main__":
# 1. 初始化测试参数(严格匹配接口要求)
seq_len = 8
batch_size = 4
input_size = 6
hidden_size = 4
num_layers = 1
bidirectional = False
num_dir = 2 if bidirectional else 1
# 2. 构造NPU张量(float32,避免double dtype警告)
device = torch.device('npu:0')
dtype = torch.float32
# 输入序列 [seq_len, batch, input_size]
input = torch.randn(seq_len, batch_size, input_size, device=device, dtype=dtype)
# 输出梯度 [seq_len, batch, hidden_size*num_dir]
grad_y = torch.randn(seq_len, batch_size, hidden_size * num_dir, device=device, dtype=dtype)
# 最后一层h/c梯度 [num_layers*num_dir, batch, hidden_size]
grad_hy = torch.randn(num_layers * num_dir, batch_size, hidden_size, device=device, dtype=dtype)
grad_cy = torch.randn(num_layers * num_dir, batch_size, hidden_size, device=device, dtype=dtype)
# 初始h0/c0 [num_layers*num_dir, batch, hidden_size]
h0 = torch.randn(num_layers * num_dir, batch_size, hidden_size, device=device, dtype=dtype)
c0 = torch.randn(num_layers * num_dir, batch_size, hidden_size, device=device, dtype=dtype)
hx = [h0, c0]
# LSTM参数(w_ih, w_hh, b_ih, b_hh)
w_ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=dtype)
w_hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=dtype)
b_ih = torch.randn(4 * hidden_size, device=device, dtype=dtype)
b_hh = torch.randn(4 * hidden_size, device=device, dtype=dtype)
params = [w_ih, w_hh, b_ih, b_hh]
# 前向中间结果(核心:不拆分,直接传原始Tensor)
# 形状:[seq_len, batch, num_layers*num_dir*hidden_size]
i = torch.randn(seq_len, batch_size, num_layers * num_dir * hidden_size, device=device, dtype=dtype)
j = torch.randn_like(i)
f = torch.randn_like(i)
o = torch.randn_like(i)
h = torch.randn_like(i)
c = torch.randn_like(i)
tanhc = torch.tanh(c).to(dtype=dtype) # 确保dtype一致
input_grad, h_prev_grad, c_prev_grad, param_grads = _lstm_backward_npu(
grad_y=grad_y,
grad_hy=grad_hy,
grad_cy=grad_cy,
input=input,
hx=hx,
params=params,
i=i, # 直接传单个Tensor,无chunk
j=j,
f=f,
o=o,
h=h,
c=c,
tanhc=tanhc,
has_biases=True,
num_layers=num_layers,
dropout=0.0,
train=True,
bidirectional=bidirectional,
batch_first=False,
batch_sizes=None
)
"""
)
_add_torch_npu_docstr(
"npu_rotate_quant",
"""
接口原型:
torch_npu.npu_rotate_quant(Tensor x, Tensor rotation, *, Tensor? alpha=None, int? dst_dtype=None, int? axis=-1, str? round_mode="rint", int? scale_alg=0, float? dst_type_max=0.0, bool? transpose_y=False) -> (Tensor, Tensor)
torch_npu.npu_rotate_quant(Tensor x, Tensor rotation, *, Tensor? alpha=None, int? dst_dtype=None, int? axis=-1, str? round_mode="rint", int? scale_alg=0, float? dst_type_max=0.0, bool? transpose_y=False) -> (Tensor, Tensor)
功能描述
`npu_rotate_quant`是一种融合旋转(Rotate)和量化(Quant)的计算方法。该方法适用于需要对输入数据进行旋转变换后进行量化的场景,融合算子在底层能够对部分过程并行,达到性能优化的效果。
参数说明:
x(Tensor):必选输入,输入tensor。shape支持2维[m,n],数据类型支持`bfloat16`和`float16`,数据格式支持ND,支持非连续的Tensor。
rotation(Tensor):必选输入,旋转矩阵tensor。shape支持2维[k,k],数据类型支持`bfloat16`和`float16`,数据格式支持ND,支持非连续的Tensor。
alpha(`Tensor`):可选输入,旋转角度缩放因子,数据类型为1维`Tensor`,默认值为None。
dst_dtype: int类型, 指定量化输出的类型, 可选输入, 传None时当做torch.int8处理。支持的量化输出类型包括: torch.int8(1), torch.quint4x2(16), torch_npu.float4_e2m1fn_x2(296), torch.float8_e5m2(23), torch.float8_e4m3fn(24)。
axis: int类型, 指定量化输出的轴, 可选输入, 默认值为-1。
round_mode: str类型, 指定取整模式, 可选输入, 默认值为"rint"。
scale_alg: int类型, 指定scale算法, 可选输入, 默认值为0。
dst_type_max: float类型, 指定量化输出的最大值, 可选输入, 默认值为0.0。
transpose_y: bool类型, 指定输出是否转置, 可选输入, 默认值为False, 当前版本仅支持False。
输出说明:
y(`Tensor`):输出的量化结果,数据类型根据dst_dtype决定。数据格式支持ND,int场景中支持非连续的Tensor。
scale(`Tensor`):输出的量化因子,数据类型根据dst_dtype决定。当dst_dtype为MX类型(float4_e2m1fn_x2/float8_e5m2/float8_e4m3fn)时,scale为MX格式量化因子(uint8表示float8_e8m0);其他类型时,scale为per-token量化因子(float32)。
支持的型号:
Atlas A5 训练系列产品/Atlas A5 推理系列产品
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
约束说明
关于数据shape的约束,其中:
n:取值范围为128~16000,8字节对齐, n可以整除k。
transpose_y当前版本仅支持False。
当dst_dtype为quint4x2时,输入最后一维必须可被8整除。
当dst_dtype为float4_e2m1fn_x2时,输入最后一维必须可被2整除。
调用示例:
import torch
import torch_npu
import numpy as np
def test_rotate_quant_int8(M=512, N=1024, K=1024):
x = torch.randn(M, N, dtype=torch.bfloat16).npu()
rotation = torch.randn(K, K, dtype=torch.bfloat16).npu()
output, output_scale = torch_npu.npu_rotate_quant(
x, rotation, dst_dtype=torch.int8
)
return output, output_scale
def test_rotate_quant_mxfp4(M=512, N=1024, K=1024):
x = torch.randn(M, N, dtype=torch.bfloat16).npu()
rotation = torch.randn(K, K, dtype=torch.bfloat16).npu()
output, output_scale = torch_npu.npu_rotate_quant(
x, rotation, dst_dtype=torch_npu.float4_e2m1fn_x2, axis=-1, round_mode="rint"
x, rotation, dst_dtype=torch.int8
)
return output, output_scale
def main():
output, output_scale = test_rotate_quant_int8()
if __name__ == "__main__":
main()
"""
)
_add_torch_npu_docstr(
"npu_quant_max",
"""
功能描述:
算子功能: 对输入张量x进行量化操作,同时计算并输出绝对值最大值amax。
计算公式: y = cast(x * scale, dst_dtype), amax = max(|x|)
接口原型:
torch_npu.npu_quant_max(Tensor x, Tensor scale, *, str round_mode="rint", int dst_dtype=291) -> (Tensor, Tensor)
参数说明:
x: Tensor类型, 输入张量, 待量化的数据。数据格式支持ND, 支持非连续的Tensor。输入最大支持8维。
数据类型支持`float32`、`bfloat16`和`float16`。
scale: Tensor类型, 量化缩放因子。仅支持1维Tensor, shape为(1,)。数据格式支持ND。
数据类型支持float32。
round_mode: String类型, 可选参数, 舍入模式, 默认值为"rint"。支持的取值:
- "rint": 四舍六入五成双舍入模式, 适用于float8_e5m2/float8_e4m3fn输出类型。
- "round": 向最近整数舍入模式, 适用于hifloat8输出类型。
- "hybrid": 混合舍入模式, 适用于hifloat8输出类型。
dst_dtype: int类型, 可选参数, 指定输出y的数据类型对应的枚举值, 默认值为torch.float8_e5m2。支持的取值:
- torch_npu.hifloat8
- torch.float8_e5m2
- torch.float8_e4m3fn
输出说明:
返回两个Tensor:
- y: Tensor类型, 量化后的输出, 数据类型由dst_dtype指定, shape与输入x相同。
- amax: Tensor类型, 输入x的绝对值最大值, 数据类型与输入x相同, 一维Tensor, shape为[1]。
约束说明:
该接口支持推理、训练场景下使用。
该接口支持图模式。
x、scale这两个输入中不能含有None。
round_mode与dst_dtype的搭配需遵循约束: float8_e5m2/float8_e4m3fn仅支持"rint", hifloat8仅支持"round"或"hybrid"。
支持的PyTorch版本
PyTorch 2.1及以上
支持的NPU产品
Atlas 350加速卡
调用示例
单算子调用
import torch
import torch_npu
# 输入数据
x_tensor = torch.randn((16, 128), dtype=torch.bfloat16).npu()
scale = torch.tensor([2.0], dtype=torch.float32).npu()
# 调用算子
y, amax = torch_npu.npu_quant_max(x_tensor, scale, round_mode="rint", dst_dtype=torch.float8_e5m2)
图模式调用
import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
config = CompilerConfig()
config.debug.graph_dump.type = 'pbtxt'
npu_backend = tng.get_npu_backend(compiler_config=config)
x_tensor = torch.randn((16, 128), dtype=torch.bfloat16).npu()
scale = torch.tensor([2.0], dtype=torch.float32).npu()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, scale):
return torch_npu.npu_quant_max(x, scale, round_mode="rint", dst_dtype=torch.float8_e5m2)
cpu_model = Model()
model = cpu_model.npu()
model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
y, amax = model(x_tensor, scale)
"""
)
_add_torch_npu_docstr(
"npu_apply_rotary_pos_emb",
"""
接口原型:
torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin, *, layout='BSND', rotary_mode='half') -> (Tensor, Tensor)
产品支持情况:
Atlas 350 加速卡
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品
功能描述:
为提升推理网络性能,将query和key两路算子融合为单路,在旋转位置编码计算中直接对结果执行原地更新。
参数说明:
query (Tensor): 必选参数,待执行旋转位置编码的第一个张量。数据类型支持bfloat16、float16、float32,数据格式支持ND。layout为TND时,shape为3维,其他layout场景下shape为4维。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持空Tensor,shape最后一维(D)必须等于128或者64。
Atlas 350 加速卡:支持空Tensor,shape最后一维(D)小于等于1024。
key (Tensor): 必选参数,待执行旋转位置编码的第二个张量。数据类型支持bfloat16、float16、float32,数据格式支持ND。layout为TND时,shape为3维,其他layout场景下shape为4维。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持空Tensor,shape最后一维(D)必须等于128或者64。
Atlas 350 加速卡:支持空Tensor,shape最后一维(D)小于等于1024。
cos (Tensor): 必选参数,旋转位置编码余弦值张量。数据类型支持bfloat16、float16、float32,数据格式支持ND。layout为TND时,shape为3维,其他layout场景下shape为4维。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持空Tensor,shape中B维度与query、key的B维度一致,shape第3维(N)必须等于1,shape最后一维(D)必须等于128或者64。
Atlas 350 加速卡:支持空Tensor,shape中B维度与query、key的B维度一致,或者等于1,shape中N维度必须等于1,shape最后一维(D)小于等于1024。
sin (Tensor): 必选参数,旋转位置编码正弦值张量。数据类型支持bfloat16、float16、float32,数据格式支持ND。layout为TND时,shape为3维,其他layout场景下shape为4维。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持空Tensor,shape中B维度与query、key的B维度一致,shape第3维(N)必须等于1,shape最后一维(D)必须等于128或者64。
Atlas 350 加速卡:支持空Tensor,shape中B维度与query、key的B维度一致,或者等于1,shape最后一维(D)小于等于1024。
layout (str): 可选参数,张量布局格式,支持"BSND"、"SBND"、"BNSD"、"TND"。默认"BSND"。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持BSND的4维Tensor、TND的3维Tensor。
Atlas 350 加速卡:支持BSND、SBND、BNSD的4维Tensor,TND的3维Tensor。
rotary_mode (str): 可选参数,旋转编码模式,支持"half"、"quarter"、"interleave",默认值为"half"。
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持"half"模式。
Atlas 350 加速卡:支持"half"、"interleave"、"quarter"模式。
返回值说明:
query_out (Tensor): 原地更新后的query张量。
key_out (Tensor): 原地更新后的key张量。
约束说明:
Atlas 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:
1. layout为"BSND",query、key、cos、sin输入shape的前2维(B、S)必须相等;layout为"TND"时,第1维(T)必须相等。
2. query、key输入shape的最后一维(D)必须相等,cos、sin输入shape的最后一维(D)必须相等。
3. 输入张量query、key、cos、sin的数据类型必须相同。
4. layout为"BSND"时,输入query的shape用(q_b, q_s, q_n, q_d)表示,key的shape用(q_b, q_s, k_n, q_d)表示,cos和sin的shape用(q_b, q_s, 1, cos_d)表示。其中,b表示batch_size,s表示seq_length,n表示head_num,d表示head_dim。layout为"TND"时,输入query的shape用(q_t, q_n, q_d)表示,key的shape用(q_t, k_n, q_d)表示,cos和sin的shape用(q_t, 1, cos_d)表示。其中,t表示b和s合轴,n表示head_num,d表示head_dim。
Atlas 350 加速卡:
1. 对于任意layout,query与key除N维度外其他维度必须相同;query、key输入shape的最后一维(D)必须相等,cos、sin输入shape的最后一维(D)必须相等,且小于等于query、key输入shape的最后一维(D)。
2. 输入张量query、key、cos、sin的数据类型必须相同。
3. rotary_mode为"half"和"interleave"时,输入shape最后一维必须被2整除;rotary_mode为"quarter"时,输入shape最后一维必须被4整除。
Atlas 推理系列产品不支持`bfloat16`。
支持版本:
PyTorch 2.7.1+
调用示例:
import torch
import torch_npu
def test_npu_apply_rotary_pos_emb():
batch = 1
seq_len = 64
num_heads = 8
head_dim = 64
query = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16).npu()
key = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16).npu()
cos = torch.randn(batch, seq_len, 1, head_dim, dtype=torch.float16).npu()
sin = torch.randn(batch, seq_len, 1, head_dim, dtype=torch.float16).npu()
q_out, k_out = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin,
layout="BSND",
rotary_mode="half"
)
print("API: npu_apply_rotary_pos_emb test passed!")
if __name__ == "__main__":
test_npu_apply_rotary_pos_emb()
"""
)
_add_torch_npu_docstr(
"npu_fused_linear_online_max_sum",
"""
torch_npu.npu_fused_linear_online_max_sum(Tensor input, Tensor weight, Tensor target, int vocab_start_index, int vocab_end_index, bool return_logits=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
功能描述
词汇表并行场景下融合矩阵乘与交叉熵前处理算子。支持vocabulary_size维度切卡融合MatMul与CELoss,需与npu_fused_cross_entropy_loss_with_max_sum配合使用。
参数说明
input (Tensor) - MatMul计算的左矩阵,2维,数据类型支持bfloat16、float16。
weight (Tensor) - MatMul计算的右矩阵,2维,数据类型与input一致。
target (Tensor) - 目标索引,1维,数据类型支持int32、int64。
vocab_start_index (int) - 本卡分配的词汇表起始索引。
vocab_end_index (int) - 本卡分配的词汇表结束索引。
return_logits (bool,默认值为False) - 是否返回MatMul结果。True走高性能分支,False走省显存分支。
输出说明
返回6个Tensor: logits_max(float32), sum_exp_logits(float32), predicted_logits(float32), target_mask(uint8), masked_target(与target类型一致), vocab_parallel_logits(与input类型一致,return_logits为False时返回None)。
约束说明
input与weight数据类型必须一致。vocabStartIndex不可小于0。vocabEndIndex不可小于vocabStartIndex。默认确定性实现。
示例
>>> input_tensor = torch.randn(128, 64, dtype=torch.float16).npu()
>>> weight_tensor = torch.randn(256, 64, dtype=torch.float16).npu()
>>> target_tensor = torch.randint(0, 256, (128,), dtype=torch.int32).npu()
>>> logits_max, sum_exp_logits, predicted_logits, target_mask, masked_target, vocab_parallel_logits = \\
... torch_npu.npu_fused_linear_online_max_sum(input_tensor, weight_tensor, target_tensor, 0, 64, return_logits=True)
"""
)
_add_torch_npu_docstr(
"npu_fused_cross_entropy_loss_with_max_sum",
"""
torch_npu.npu_fused_cross_entropy_loss_with_max_sum(Tensor logits_max, Tensor sum_exp_logits, Tensor predicted_logits, *, float? label_smoothing=0.0, Tensor? input=None, Tensor? weight=None, Tensor? vocab_parallel_logits=None) -> (Tensor, Tensor)
功能描述
词汇表并行场景下交叉熵计算模块的一部分,计算Loss与Softmax结果。需配合npu_fused_linear_online_max_sum使用,多卡场景下需先对logits_max和sum_exp_logits执行全局通信。
参数说明
logits_max (Tensor) - 全局通信后的MatMul结果各行最大值,1维,数据类型float32。
sum_exp_logits (Tensor) - 全局通信后的exp累加结果,1维,数据类型float32,shape与logits_max一致。
predicted_logits (Tensor) - 全局通信后的预测logits,1维,数据类型float32,shape与logits_max一致。
label_smoothing (float,默认值为0.0) - 标签平滑系数,当前仅支持0。
input (Tensor,可选,默认值为None) - 当前仅支持None。
weight (Tensor,可选,默认值为None) - 当前仅支持None。
vocab_parallel_logits (Tensor,可选,默认值为None) - MatMul计算结果,传入时计算Softmax输出,不传入时返回None。
输出说明
返回2个Tensor: loss(float32)和softmax(float32,vocab_parallel_logits为None时返回None)。
约束说明
logits_max、sum_exp_logits、predicted_logits的shape需一致。label_smoothing当前仅支持0。默认确定性实现。
示例
>>> logits_max = torch.randn(128, dtype=torch.float32).npu()
>>> sum_exp_logits = torch.randn(128, dtype=torch.float32).npu()
>>> predicted_logits = torch.randn(128, dtype=torch.float32).npu()
>>> vocab_parallel_logits = torch.randn(128, 256, dtype=torch.float16).npu()
>>> loss, softmax = torch_npu.npu_fused_cross_entropy_loss_with_max_sum(logits_max, sum_exp_logits, predicted_logits, vocab_parallel_logits=vocab_parallel_logits)
"""
)
_add_torch_npu_docstr(
"npu_fused_linear_cross_entropy_loss_with_max_sum_backward",
"""
torch_npu.npu_fused_linear_cross_entropy_loss_with_max_sum_backward(Tensor grad, Tensor input, Tensor weight, Tensor target_mask, Tensor masked_target, float label_smoothing=0.0, Tensor? logits_max=None, Tensor? sum_exp_logits=None, Tensor? softmax=None) -> (Tensor, Tensor)
功能描述
词汇表并行场景下交叉熵损失计算的梯度算子,计算叶子节点input和weight的梯度。支持高性能模式(传入Softmax)和省显存模式(传入logits_max和sum_exp_logits)。
参数说明
grad (Tensor) - 当前节点的梯度,1维,数据类型float32。
input (Tensor) - 矩阵乘的输入矩阵,2维,数据类型支持float16、bfloat16。
weight (Tensor) - 矩阵乘的权重矩阵,2维,数据类型与input一致,第0维长度不支持小于128。
target_mask (Tensor) - 目标词ID是否在范围内的位掩码,1维,数据类型uint8。
masked_target (Tensor) - 目标词ID映射到当前设备的局部索引,1维,数据类型支持int32、int64。
label_smoothing (float,默认值为0.0) - 标签平滑系数,当前仅支持0。
logits_max (Tensor,可选) - 全局logits最大值,softmax为None时必须提供。
sum_exp_logits (Tensor,可选) - 处理后的logits,softmax为None时必须提供。
softmax (Tensor,可选) - Softmax计算结果,传入时走高性能模式。
输出说明
返回2个Tensor: input_grad(与input类型一致)和weight_grad(与input类型一致)。
约束说明
input与weight数据类型必须一致。softmax为None时logits_max和sum_exp_logits必须同时提供。默认确定性实现。
示例
>>> grad = torch.randn(128, dtype=torch.float32).npu()
>>> input_tensor = torch.randn(128, 64, dtype=torch.float16).npu()
>>> weight_tensor = torch.randn(256, 64, dtype=torch.float16).npu()
>>> target_mask = torch.zeros((128 + 7) // 8, dtype=torch.uint8).npu()
>>> masked_target = torch.randint(0, 256, (128,), dtype=torch.int32).npu()
>>> softmax = torch.randn(128, 256, dtype=torch.float32).npu()
>>> input_grad, weight_grad = torch_npu.npu_fused_linear_cross_entropy_loss_with_max_sum_backward(
... grad, input_tensor, weight_tensor, target_mask, masked_target, softmax=softmax)
"""
)