Examples for Fused Operator API Replacement During Migration to Ascend
Some native torch APIs are dispatched and executed as many small operators, which makes them slow. You can replace these APIs with NPU APIs that use fused operators to improve training performance.
For details about the functions and parameters of torch_npu APIs, see the torch_npu APIs.
Optimizer Replacement
Replacing an optimizer generally provides significant performance benefits. Prioritize replacing native torch optimizers with Ascend affinity optimizers. The following example uses the AdamW optimizer. The replacement method also applies to other optimizers.
torch_npu.optim.NpuFusedAdamW
Native torch code example:
import torch
# torch.optim.AdamW takes the learning rate as `lr` and has no `momentum`
# argument (its moment decay rates are configured through `betas`), so only
# `lr` and `weight_decay` are passed here.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
torch_npu code example:
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Drop-in replacement for torch.optim.AdamW. Like AdamW, it is configured with
# `lr` / `weight_decay` (and optionally `betas`); there is no `momentum`
# argument, so it is not passed.
optimizer = torch_npu.optim.NpuFusedAdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
Affinity API Replacement
optimizer.clip_grad_norm_fused_
Before replacing the API with the NPU affinity gradient clipping API, ensure that an NPU affinity optimizer is already used in the code.
Native torch code example:
import torch
# Native AdamW optimizer over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
# Clip the gradients of all model parameters (max_norm=10, norm_type=2).
torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# The fused clipping API is a method of the NPU affinity optimizer, so that
# optimizer must be created first.
optimizer = torch_npu.optim.NpuFusedAdamW(model.parameters(), lr = lr)
# Same clipping settings as the native call (max_norm=10, norm_type=2); the
# parameters are taken from the optimizer instead of being passed in.
optimizer.clip_grad_norm_fused_(max_norm=10, norm_type=2)
torch_npu.npu_confusion_transpose
Example 1
Native torch code example:
import torch
# 4-D input; its dimensions are unpacked as (batch, channel, height, width).
data = torch.rand(64, 3, 64, 128).cuda()
batch, channel, height, width = data.shape
# Permute the axes to (0, 2, 1, 3) first, then reshape to
# (height, batch, channel*width).
result = torch.permute(data, (0, 2, 1, 3)).reshape(height, batch, channel*width)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# 4-D input; its dimensions are unpacked as (batch, channel, height, width).
data = torch.rand(64, 3, 64, 128).cuda()
batch, channel, height, width = data.shape
# Single fused call replacing permute(...).reshape(...) from the native
# example: transpose_first=True matches the permute-then-reshape order.
result = torch_npu.npu_confusion_transpose(data, (0, 2, 1, 3), (height, batch, channel*width), transpose_first=True)
Example 2
Native torch code example:
import torch
# 4-D input; its dimensions are unpacked as (batch, channel, height, width).
data = torch.rand(64, 3, 64, 128).cuda()
batch, channel, height, width = data.shape
# Flatten to (batch, height*channel*width) first, then swap the two axes.
result = data.view(batch, height*channel*width).transpose(1, 0)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# 4-D input; its dimensions are unpacked as (batch, channel, height, width).
data = torch.rand(64, 3, 64, 128).cuda()
batch, channel, height, width = data.shape
# Single fused call replacing view(...).transpose(...) from the native
# example: transpose_first=False matches the reshape-then-transpose order.
result = torch_npu.npu_confusion_transpose(data, (1, 0), (batch, height*channel*width), transpose_first=False)
torch_npu.npu_scaled_masked_softmax
Note that the size of the last dimension of the atten_scores tensor (x in the examples below) and of the atten_mask tensor (mask in the examples below) must be within the range [32, 8192] and must be a multiple of 32.
Native torch code example:
import torch
# Scores tensor and a boolean mask; the mask's two leading size-1 dimensions
# broadcast against x.
x = torch.randn([64, 8, 128, 256]).cuda()
mask = torch.randn([1, 1, 128, 256]).cuda() >= 1
scale = 0.8
# Scale the scores, set the positions where mask is True to -inf, then apply
# softmax over the last dimension.
output = torch.softmax((x * scale).masked_fill(mask, -1*torch.inf), dim=-1)
# output shape is (64, 8, 128, 256)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Same inputs as the native example: scores tensor, boolean mask and scale.
x = torch.randn([64, 8, 128, 256]).cuda()
mask = torch.randn([1, 1, 128, 256]).cuda() >= 1
scale = 0.8
# Single fused call replacing the scale / masked_fill / softmax chain of the
# native example.
output = torch_npu.npu_scaled_masked_softmax(x, mask, scale)
# output shape is (64, 8, 128, 256)
torch_npu.fast_gelu
Example 1
Replace torch.nn.functional.gelu with torch_npu.fast_gelu. The two implementations differ, so the output of the activation function is not numerically identical.
Native torch code example:
import torch
# Apply the native functional GELU to a (64, 32) input.
input_data = torch.rand(64, 32).cuda()
result = torch.nn.functional.gelu(input_data)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Same (64, 32) input, with torch_npu.fast_gelu called in place of
# torch.nn.functional.gelu.
input_data = torch.rand(64, 32).cuda()
result = torch_npu.fast_gelu(input_data)
Example 2
Inherit from torch.nn.GELU and rewrite the forward method based on torch_npu.fast_gelu.
Native torch code example:
import torch
# Apply the native GELU module to a (64, 32) input. The output is named
# `result`, as in the matching torch_npu example.
input_data = torch.rand(64, 32).cuda()
gelu_module = torch.nn.GELU().cuda()
result = gelu_module(input_data)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
class FastGelu(torch.nn.GELU):
    """GELU module whose forward pass is computed by torch_npu.fast_gelu.

    It inherits from torch.nn.GELU and overrides only forward(), so it is
    constructed the same way as torch.nn.GELU.
    """

    def forward(self, input_data):
        """Return torch_npu.fast_gelu applied to input_data."""
        return torch_npu.fast_gelu(input_data)


input_data = torch.rand(64, 32).cuda()
fast_gelu_module = FastGelu().cuda()
result = fast_gelu_module(input_data)
torch_npu.npu_rms_norm
The input dtype must be float16, bfloat16, or float (float32).
Native torch code example:
import torch
class TorchRMSNorm(torch.nn.Module):
    """Reference RMSNorm built from native torch operators.

    Args:
        dim: Size of the learnable weight, which is multiplied element-wise
            with the normalized input.
        eps: Constant added to the mean of squares before the reciprocal
            square root is taken.
    """

    def __init__(self, dim: int, eps = 1e-6):
        super().__init__()
        self.eps = eps
        # Move the tensor to the device *before* wrapping it: calling .cuda()
        # on a Parameter returns a plain tensor, which would not be
        # registered as a parameter of the module.
        self.weight = torch.nn.Parameter(torch.ones(dim).cuda())

    def _norm(self, x):
        """Scale x by the reciprocal root mean square of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalize x in float32, cast back to x's dtype, apply the weight."""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
input_data = torch.randn(128, 256).cuda()
# The norm is computed over the last dimension, so the weight size is the size
# of that dimension (256); this also matches the `dim: int` parameter.
torch_rms_norm = TorchRMSNorm(256)
result = torch_rms_norm(input_data)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
class NpuRMSNorm(torch.nn.Module):
    """RMSNorm module backed by the fused torch_npu.npu_rms_norm operator.

    Args:
        dim: Size of the learnable weight passed to the fused operator.
        eps: Passed to the fused operator as `epsilon`.
    """

    def __init__(self, dim: int, eps = 1e-6):
        super().__init__()
        self.eps = eps
        # Move the tensor to the device *before* wrapping it: calling .cuda()
        # on a Parameter returns a plain tensor, which would not be
        # registered as a parameter of the module.
        self.weight = torch.nn.Parameter(torch.ones(dim).cuda())

    def forward(self, x):
        """Return the first output of npu_rms_norm (the normalized tensor)."""
        return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]


input_data = torch.randn(128, 256).cuda()
# The weight covers only the last dimension (256), matching the `dim: int`
# parameter and the native example, which normalizes over the last dimension.
# NOTE(review): npu_rms_norm appears to normalize over the trailing dimensions
# covered by the weight - confirm against the torch_npu API reference.
npu_rms_norm = NpuRMSNorm(256)
result = npu_rms_norm(input_data)
torch_npu.npu_swiglu
The input dtype must be float16, bfloat16, or float (float32).
Native torch code example:
import torch
class TorchSwiGlu(torch.nn.Module):
    """Reference SwiGLU built from native torch operators.

    Args:
        dim: Dimension along which the input is split into two halves.
            Defaults to the last dimension.
    """

    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def _swiglu(self, x):
        """Split x in two along self.dim and return silu(first) * second."""
        # Use the configured dimension rather than a hard-coded -1, so that the
        # `dim` constructor argument actually takes effect.
        first_half, second_half = torch.chunk(x, 2, self.dim)
        return torch.nn.functional.silu(first_half) * second_half

    def forward(self, x):
        """Apply SwiGLU to x along self.dim."""
        return self._swiglu(x)
# Run the native SwiGLU module with its default dim on a (128, 256) input.
input_data = torch.randn(128, 256).cuda()
torch_swiglu = TorchSwiGlu()
result = torch_swiglu(input_data)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
class NpuSwiGlu(torch.nn.Module):
    """SwiGLU module backed by the fused torch_npu.npu_swiglu operator.

    Args:
        dim: Passed to the fused operator as `dim`. Defaults to -1.
    """

    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return torch_npu.npu_swiglu applied to x along self.dim."""
        # Use the configured dimension rather than a hard-coded -1, so that the
        # `dim` constructor argument actually takes effect.
        return torch_npu.npu_swiglu(x, dim=self.dim)


input_data = torch.randn(128, 256).cuda()
npu_swiglu = NpuSwiGlu()
result = npu_swiglu(input_data)
torch_npu.npu_rotary_mul
Native torch code example:
import torch
# x is the input tensor. r1 and r2 are the two multipliers; their size-1
# dimensions (0 and 2) broadcast against x.
x = torch.rand([2, 8192, 5, 128]).cuda()
r1 = torch.rand([1, 8192, 1, 128]).cuda()
r2 = torch.rand([1, 8192, 1, 128]).cuda()
def torch_func(x, r1, r2):
    """Return r1 * x + r2 * rotate_half(x) using native torch operators.

    rotate_half(x) splits the last dimension of x into a first half x1 and a
    second half x2 and concatenates them as (-x2, x1).

    Args:
        x: Input tensor.
        r1: Multiplier applied to x; must broadcast against x.
        r2: Multiplier applied to the rotated x; must broadcast against x.

    Returns:
        Tensor with the broadcast shape of the inputs.
    """
    half = x.shape[-1] // 2
    # For an even last dimension this is the same split as torch.chunk(x, 2, -1).
    x1, x2 = x[..., :half], x[..., half:]
    x_new = torch.cat((-x2, x1), dim=-1)
    return r1 * x + r2 * x_new
# Compute the native reference result for the inputs defined above.
result = torch_func(x, r1, r2)
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Same inputs as the native example: x plus two multipliers whose size-1
# dimensions (0 and 2) broadcast against x.
x = torch.rand([2, 8192, 5, 128]).cuda()
r1 = torch.rand([1, 8192, 1, 128]).cuda()
r2 = torch.rand([1, 8192, 1, 128]).cuda()
# Single fused call replacing the slice / cat / multiply-add sequence of the
# native example.
result = torch_npu.npu_rotary_mul(x, r1, r2)
torch_npu.npu_fusion_attention
Native torch code example:
import torch
# Needed even in the native example: it provides the .npu() tensor methods and
# the fused operator called by custom_op_exec.
import torch_npu


class TorchFlashAttention():
    """Attention computed with native torch operators, executed on the NPU."""

    def supported_op_exec(self, query, key, value, atten_mask=None):
        """Compute attention step by step with small torch operators.

        Args:
            query, key, value: 4-D tensors; the last two dimensions of `key`
                are swapped before it is multiplied with `query`.
            atten_mask: Optional boolean mask; positions where it is True are
                filled with -inf before the softmax.

        Returns:
            The attention output with dimensions 1 and 2 swapped and the last
            two dimensions merged into one.
        """
        scale = 0.099
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
        if atten_mask is not None:
            qk.masked_fill_(atten_mask.npu(), torch.tensor(-float('inf')).npu())
        # Softmax is computed in float32 and the result cast back to float16.
        softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(torch.float16)
        output = torch.matmul(softmax_res, value)
        output = output.transpose(1, 2)
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

    def custom_op_exec(self, query, key, value, atten_mask=None):
        """Compute the same attention with the fused NPU operator (BSH layout)."""
        scale = 0.099
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, atten_mask=atten_mask)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        """Swap dimensions 1 and 2 of a 4-D tensor and merge the last two."""
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def test_torch_flash_attention(self, device="npu"):
        """Run the native implementation on random inputs and return the result."""
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).npu() >= 0
        # The native path consumes the 4-D tensors directly, so no BSH copies of
        # query/key/value are created here.
        result = self.supported_op_exec(query.npu(), key.npu(), value.npu(), atten_mask=atten_mask)
        # result shape (1, 128, 4096)
        return result
torch_npu code example:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
class NPUFlashAttention():
    """Attention computed with the fused torch_npu.npu_fusion_attention operator."""

    def npu_exec(self, query, key, value, atten_mask=None):
        """Call the fused operator with BSH-layout inputs.

        Returns:
            The tuple returned by torch_npu.npu_fusion_attention; its first
            element is the attention output.
        """
        scale = 0.099
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, atten_mask=atten_mask)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        """Swap dimensions 1 and 2 of a 4-D tensor and merge the last two."""
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def test_npu_flash_attention(self, device="npu"):
        """Run the fused operator on random inputs and return the attention output."""
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).npu() >= 0
        # Convert (1, 32, 128, 128) inputs to the (1, 128, 4096) BSH layout the
        # fused operator is called with.
        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()
        result, softmax_max, softmax_sum, softmax_out, seed, offset, numels = self.npu_exec(q_npu, k_npu, v_npu, atten_mask)
        # result shape (1, 128, 4096)
        return result