昇腾迁移融合算子API替换样例

部分torch原生的API在下发和执行时会包括多个小算子,下发和执行耗时较长,可以通过替换成NPU API来使能融合算子,提升训练性能。

torch_npu API的功能和参数描述见API列表

优化器替换

替换优化器一般都能带来较大的性能收益,可以优先考虑将torch原生的优化器替换为昇腾提供的亲和优化器。下文以AdamW优化器为例,其他优化器的替换方式一致。

torch_npu.optim.NpuFusedAdamW

torch原生代码示例如下:

import torch
# AdamW takes no ``momentum`` keyword (its moment decay rates are configured
# through ``betas``), so passing one raises TypeError. The learning rate is
# passed by name to make the mapping explicit.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

torch_npu代码示例如下:

import torch_npu
from torch_npu.contrib import transfer_to_npu

# Same arguments as torch.optim.AdamW: there is no ``momentum`` keyword, the
# moment decay rates are configured through ``betas``.
optimizer = torch_npu.optim.NpuFusedAdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

亲和API替换

optimizer.clip_grad_norm_fused_

在替换为npu亲和梯度裁剪api之前,请确保代码中已使用npu亲和优化器。

torch原生代码示例如下:

import torch
# Native optimizer; gradients are clipped with the generic torch utility,
# which takes the parameters explicitly.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
torch.nn.utils.clip_grad_norm_(
    parameters=model.parameters(),
    max_norm=10,
    norm_type=2,
)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# The fused clipping routine is a method of the NPU fused optimizer, so no
# parameter list is passed to it.
optimizer = torch_npu.optim.NpuFusedAdamW(model.parameters(), lr=lr)
optimizer.clip_grad_norm_fused_(
    max_norm=10,
    norm_type=2,
)

torch_npu.npu_confusion_transpose

示例一

torch原生代码示例如下:

import torch

# 4-D input whose axes are unpacked as (batch, channel, height, width).
data = torch.rand((64, 3, 64, 128)).cuda()
batch, channel, height, width = data.size()
# Swap the channel and height axes first, then reshape to three axes.
result = data.permute(0, 2, 1, 3).reshape(height, batch, channel * width)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# 4-D input whose axes are unpacked as (batch, channel, height, width).
data = torch.rand((64, 3, 64, 128)).cuda()
batch, channel, height, width = data.size()
# transpose_first=True mirrors the native "permute, then reshape" order.
result = torch_npu.npu_confusion_transpose(
    data,
    (0, 2, 1, 3),
    (height, batch, channel * width),
    transpose_first=True,
)

示例二

torch原生代码示例如下:

import torch

# 4-D input whose axes are unpacked as (batch, channel, height, width).
data = torch.rand((64, 3, 64, 128)).cuda()
batch, channel, height, width = data.size()
# Flatten everything except the batch axis, then swap the two remaining axes.
result = torch.transpose(data.view(batch, height * channel * width), 1, 0)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# 4-D input whose axes are unpacked as (batch, channel, height, width).
data = torch.rand((64, 3, 64, 128)).cuda()
batch, channel, height, width = data.size()
# transpose_first=False mirrors the native "view, then transpose" order.
result = torch_npu.npu_confusion_transpose(
    data,
    (1, 0),
    (batch, height * channel * width),
    transpose_first=False,
)

torch_npu.npu_scaled_masked_softmax

注意atten_mask和atten_scores张量最后一维的取值范围为32-8192,且必须为32的整数倍。

torch原生代码示例如下:

import torch
x = torch.randn(64, 8, 128, 256).cuda()
# Boolean mask; True entries are replaced by -inf before the softmax.
mask = torch.randn(1, 1, 128, 256).cuda() >= 1
scale = 0.8

# Three separate steps: scale, mask, softmax over the last axis.
output = (x * scale).masked_fill(mask, float("-inf")).softmax(dim=-1)
# shape is (64, 8, 128, 256)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

x = torch.randn(64, 8, 128, 256).cuda()
# Boolean mask with the same construction as the native example.
mask = torch.randn(1, 1, 128, 256).cuda() >= 1
scale = 0.8

# A single fused call takes the scores, the mask and the scale factor.
output = torch_npu.npu_scaled_masked_softmax(
    x,
    mask,
    scale,
)
# shape is (64, 8, 128, 256)

torch_npu.fast_gelu

示例一

替换torch.nn.functional.gelu方法。两者实现上有些差异,激活函数输出结果会不同。

torch原生代码示例如下:

import torch
# Functional form of the native GELU activation.
input_data = torch.rand((64, 32)).cuda()
result = torch.nn.functional.gelu(
    input_data,
)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# Direct call to the NPU fast-GELU operator on the same kind of input.
input_data = torch.rand((64, 32)).cuda()
result = torch_npu.fast_gelu(
    input_data,
)

示例二

继承torch.nn.GELU,基于torch_npu.fast_gelu重写forward方法。

torch原生代码示例如下:

import torch
# Module form of the native GELU activation. The output is stored in
# ``result`` so it can be compared directly with the torch_npu variant.
input_data = torch.rand(64, 32).cuda()
gelu_module = torch.nn.GELU().cuda()
result = gelu_module(input_data)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# Subclass torch.nn.GELU and override forward with torch_npu.fast_gelu.
class FastGelu(torch.nn.GELU):
    """torch.nn.GELU subclass whose forward pass calls torch_npu.fast_gelu."""

    def forward(self, input_data):
        """Return ``torch_npu.fast_gelu`` applied to ``input_data``."""
        activated = torch_npu.fast_gelu(input_data)
        return activated


# Used exactly like the native torch.nn.GELU module.
input_data = torch.rand((64, 32)).cuda()
fast_gelu_module = FastGelu().cuda()
result = fast_gelu_module(input_data)

torch_npu.npu_rms_norm

输入数据dtype仅支持float16、bfloat16、float。

torch原生代码示例如下:

import torch

class TorchRMSNorm(torch.nn.Module):
    """RMS normalization over the last axis, built from small torch operators.

    Args:
        dim: Shape passed to ``torch.ones`` for the learnable scale; it must
            broadcast against the normalized input.
        eps: Constant added to the mean square before the reciprocal square
            root is taken.
    """

    def __init__(self, dim: int, eps = 1e-6):
        super().__init__()
        self.eps = eps
        # Only ``torch`` is imported in this example, so the fully qualified
        # ``torch.nn`` is required. The tensor is moved to the device before
        # it is wrapped: calling .cuda() on the Parameter would return a plain
        # tensor that the module does not register as a parameter.
        self.weight = torch.nn.Parameter(torch.ones(dim).cuda())

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last axis."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the dtype of ``x``, then scale."""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

input_data = torch.randn(128, 256).cuda()
# ``dim`` is declared as an int and only the last axis is normalized, so the
# scale vector has the length of that axis rather than the full input shape.
torch_rms_norm = TorchRMSNorm(256)
result = torch_rms_norm(input_data)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

class NpuRMSNorm(torch.nn.Module):
    """RMS normalization implemented with the fused torch_npu.npu_rms_norm.

    Args:
        dim: Shape passed to ``torch.ones`` for the learnable scale.
        eps: Epsilon forwarded to the fused operator.
    """

    def __init__(self, dim: int, eps = 1e-6):
        super().__init__()
        self.eps = eps
        # Only ``torch`` and ``torch_npu`` are imported in this example, so the
        # fully qualified ``torch.nn`` is required. The tensor is moved to the
        # device before it is wrapped: calling .cuda() on the Parameter would
        # return a plain tensor that the module does not register.
        self.weight = torch.nn.Parameter(torch.ones(dim).cuda())

    def forward(self, x):
        """Return the first output of the fused operator for ``x``."""
        # The operator returns a tuple; index 0 is used as the result here.
        return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]


input_data = torch.randn(128, 256).cuda()
# The scale covers only the last axis, like the native reference, which
# averages over axis -1. NOTE(review): a full-shape scale is expected to make
# the fused operator normalize over every axis it covers; confirm against the
# API reference.
npu_rms_norm = NpuRMSNorm(256)
result = npu_rms_norm(input_data)

torch_npu.npu_swiglu

输入数据dtype仅支持float16、bfloat16、float。

torch原生代码示例如下:

import torch
class TorchSwiGlu(torch.nn.Module):
    """SwiGLU activation built from ``torch.chunk`` and ``silu``.

    Args:
        dim: Axis along which the input is split into two halves.
    """

    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def _swiglu(self, x):
        """Split ``x`` in two along ``self.dim`` and return silu(first) * second."""
        # Honor the configured axis; the split used to be hard-coded to -1,
        # which silently ignored the constructor argument.
        first, second = torch.chunk(x, 2, self.dim)
        return torch.nn.functional.silu(first) * second

    def forward(self, x):
        """Apply SwiGLU to ``x``."""
        return self._swiglu(x)

# Default construction splits the input along the last axis.
input_data = torch.randn((128, 256)).cuda()
torch_swiglu = TorchSwiGlu()
result = torch_swiglu(
    input_data,
)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

class NpuSwiGlu(torch.nn.Module):
    """SwiGLU activation implemented with the fused torch_npu.npu_swiglu.

    Args:
        dim: Axis forwarded to the fused operator as its ``dim`` argument.
    """

    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Apply the fused operator to ``x`` along ``self.dim``."""
        # Use the configured axis; a local ``dim = -1`` used to shadow it, so
        # the constructor argument had no effect.
        return torch_npu.npu_swiglu(x, dim=self.dim)


input_data = torch.randn(128, 256).cuda()
npu_swiglu = NpuSwiGlu()
result = npu_swiglu(input_data)

torch_npu.npu_rotary_mul

torch原生代码示例如下:

import torch

# r1 and r2 have size 1 on the first and third axes, so they broadcast
# against x in the element-wise products.
x = torch.rand(2, 8192, 5, 128).cuda()
r1 = torch.rand(1, 8192, 1, 128).cuda()
r2 = torch.rand(1, 8192, 1, 128).cuda()

def torch_func(x, r1, r2):
    """Return ``r1 * x + r2 * rotated`` for a half-rotated copy of ``x``.

    The last axis of ``x`` is cut at its midpoint; ``rotated`` is the negated
    second half followed by the first half. For an even last axis the cut is
    the same one ``torch.chunk(x, 2, -1)`` would make.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    rotated = torch.cat((-back, front), dim=-1)
    return r1 * x + r2 * rotated

# Run the unfused reference implementation on x, r1 and r2.
result = torch_func(x, r1, r2)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# Same input shapes as the native example.
x = torch.rand(2, 8192, 5, 128).cuda()
r1 = torch.rand(1, 8192, 1, 128).cuda()
r2 = torch.rand(1, 8192, 1, 128).cuda()

# A single fused call takes the place of the slice / cat / multiply-add steps.
result = torch_npu.npu_rotary_mul(
    x,
    r1,
    r2,
)

torch_npu.npu_fusion_attention

torch原生代码示例如下:

import torch

class TorchFlashAttention():
    def supported_op_exec(self, query, key, value, atten_mask=None):
        scale = 0.099
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)

        if atten_mask is not None:
            qk.masked_fill_(atten_mask.npu(), torch.tensor(-float('inf')).npu())
        softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(torch.float16)
        output = torch.matmul(softmax_res, value)
        output = output.transpose(1, 2)
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

    def custom_op_exec(self, query, key, value, atten_mask=None):
        scale = 0.099
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, atten_mask=atten_mask)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def test_torch_flash_attention(self, device="npu"):
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).npu() >= 0

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()

        result = self.supported_op_exec(query.npu(), key.npu(), value.npu(), atten_mask=atten_mask)
        # result shape (1, 128, 4096)

torch_npu代码示例如下:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu


class NPUFlashAttention():
    """Attention computed with the fused torch_npu.npu_fusion_attention."""

    def npu_exec(self, query, key, value, atten_mask=None):
        """Call the fused operator on BSH inputs and return all its outputs."""
        scale = 0.099
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, atten_mask=atten_mask)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        """Swap axes 1 and 2 of ``tensor`` and merge the last two axes."""
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def test_npu_flash_attention(self, device="npu"):
        """Run the fused attention on random float16 inputs.

        Args:
            device: Device the inputs are moved to before the computation.

        Returns:
            The attention output of shape (1, 128, 4096).
        """
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).to(device) >= 0

        # The operator is called with input_layout="BSH", so the 4-D inputs
        # are converted before they are moved to the device.
        q_npu = self.trans_BNSD2BSH(query).to(device)
        k_npu = self.trans_BNSD2BSH(key).to(device)
        v_npu = self.trans_BNSD2BSH(value).to(device)

        # The operator returns seven values (result, softmax_max, softmax_sum,
        # softmax_out, seed, offset, numels); only the first is needed here.
        result = self.npu_exec(q_npu, k_npu, v_npu, atten_mask)[0]
        # result shape (1, 128, 4096)
        return result