npu_apply_fused_adamw_v2 对外接口
接口原型
npu_apply_fused_adamw_v2(var, grad, m, v, max_grad_norm, step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize)
npu_apply_fused_adamw_v2接口用于更新AdamW优化器中的以下四个张量:var(模型参数)、m(一阶矩动量)、v(二阶矩动量)、max_grad_norm(训练过程中二阶矩动量的历史最大值,仅在amsgrad为True时被更新)。
import math
import torch
import torch_npu
import numpy as np
# Reference example of the interface's internal computation logic (pure PyTorch).
def npu_apply_fused_adamw_v2(var, grad, m, v, max_grad_norm, step, lr, beta1, beta2,
                             weight_decay, eps, amsgrad, maximize):
    """Apply one AdamW update step, writing the results back into the inputs.

    All tensor math is carried out in float32. Results are cast back to each
    destination tensor's own dtype by the final in-place ``copy_`` calls.

    Args:
        var: model parameters (float32 / float16 / bfloat16); updated in place.
        grad: gradient of ``var``; same shape as ``var``.
        m: first-moment estimate; same dtype/shape as ``var``; updated in place.
        v: second-moment estimate; same dtype/shape as ``var``; updated in place.
        max_grad_norm: running maximum of the second moment. Required, and
            updated in place, when ``amsgrad`` is True; may be None otherwise.
        step: one-element tensor holding the current step number; an int64
            step is converted to float32 before being read. Expected to be
            >= 1: at step 0 both bias corrections are zero.
        lr, beta1, beta2, weight_decay, eps: hyper-parameters, rounded to
            float32 before use.
        amsgrad: normalise by the running maximum of ``v`` instead of ``v``.
        maximize: maximise the objective (the gradient is negated).

    Returns:
        None. The update is written into the input tensors in place.

    Raises:
        ValueError: if ``amsgrad`` is True and ``max_grad_norm`` is None.
    """
    if amsgrad and max_grad_norm is None:
        raise ValueError("max_grad_norm is required when amsgrad is True")

    # Float32 working tensors. ``.to`` returns the tensor itself when it is
    # already float32, so for float32 inputs the in-place ops below act on the
    # caller's tensors directly and the final copy_ calls are self-copies.
    param = var.to(torch.float32)
    grad32 = grad.to(torch.float32)
    exp_avg = m.to(torch.float32)
    exp_avg_sq = v.to(torch.float32)
    # Always upcast max_grad_norm too (not only when var/grad are half):
    # otherwise the amsgrad denominator is computed in float16, where eps=1e-8
    # rounds to 0 and zero moments produce 0/0 = NaN.
    max_exp_avg_sq = max_grad_norm.to(torch.float32) if amsgrad else None

    # Scalar hyper-parameters are float32 attributes of the interface.
    lr = np.float32(lr)
    beta1 = np.float32(beta1)
    beta2 = np.float32(beta2)
    weight_decay = np.float32(weight_decay)
    eps = np.float32(eps)

    if step.dtype == torch.int64:
        step = step.to(torch.float32)
    step_value = step.item()

    if maximize:
        grad32 = -grad32

    # Decoupled weight decay (applied to the parameter, not the gradient).
    param = param * (1 - lr * weight_decay)
    exp_avg.lerp_(grad32, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad32, grad32, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** step_value
    bias_correction2 = 1 - beta2 ** step_value
    step_size = lr / bias_correction1
    bias_correction2_sqrt = math.sqrt(bias_correction2)

    if amsgrad:
        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt) + eps
    else:
        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + eps
    param.addcdiv_(exp_avg, denom, value=-step_size)

    # copy_ casts each float32 result to the destination tensor's dtype.
    var.copy_(param)
    m.copy_(exp_avg)
    v.copy_(exp_avg_sq)
    if amsgrad:
        max_grad_norm.copy_(max_exp_avg_sq)
输入
- var:必选输入,数据类型为tensor(float32)、tensor(float16)或tensor(bfloat16),表示模型参数。接受任意shape,但需保持var、grad、m、v、max_grad_norm的shape相同。
- grad:必选输入,数据类型为tensor(float32)、tensor(float16)或tensor(bfloat16),表示模型参数的梯度。接受任意shape,但需保持var、grad、m、v、max_grad_norm的shape相同。
- m:必选输入,数据类型必须与var完全一致,表示一阶矩动量。接受任意shape,但需保持var、grad、m、v、max_grad_norm的shape相同。
- v:必选输入,数据类型必须与var完全一致,表示二阶矩动量。接受任意shape,但需保持var、grad、m、v、max_grad_norm的shape相同。
- max_grad_norm:amsgrad为True时为必选输入,amsgrad为False时为可选输入;数据类型为tensor(float32)、tensor(float16)或tensor(bfloat16),表示训练过程中二阶矩动量的历史最大值。接受任意shape,但需保持var、grad、m、v、max_grad_norm的shape相同。
- step:必选输入,数据类型为tensor(int64),shape为(1,),表示当前为第几步。
- lr:可选属性,数据类型为float32,默认值为1e-3,表示学习率。
- beta1:可选属性,数据类型为float32,默认值为0.9,表示一阶矩动量的衰减率。
- beta2:可选属性,数据类型为float32,默认值为0.999,表示二阶矩动量的衰减率。
- weight_decay:可选属性,数据类型为float32,默认值为0.0,表示权重衰减系数(直接作用于模型参数)。
- eps:可选属性,数据类型为float32,默认值为1e-8,是加在分母上、用于防止除零的极小值。
- amsgrad:可选属性,数据类型为bool,默认值为False,表示是否使用二阶矩动量的历史最大值(即AMSGrad变体)。
- maximize:可选属性,数据类型为bool,默认值为False,表示是否最大化目标函数(为True时对梯度取反)。
支持的输入数据类型组合:
| 参数名称 | 组合1 | 组合2 | 组合3 | 组合4 | 组合5 | 组合6 | 组合7 | 组合8 | 组合9 | 组合10 | 组合11 | 组合12 | 组合13 | 组合14 | 组合15 | 组合16 | 组合17 | 组合18 | 组合19 | 组合20 | 组合21 | 组合22 | 组合23 | 组合24 | 组合25 | 组合26 | 组合27 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| var | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) |
| grad | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) |
| m | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) |
| v | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float32) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(float16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) | tensor(bfloat16) |
| max_grad_norm | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) | tensor(float32) | tensor(float16) | tensor(bfloat16) |
| step | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) | tensor(int64) |
| lr | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 |
| beta1 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 |
| beta2 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 |
| weight_decay | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 |
| eps | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 | float32 |
| amsgrad | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool |
| maximize | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool | bool |
输出
该接口无返回值,调用后会原地(inplace)更新入参 var、m、v、max_grad_norm。
调用示例
- 输入 var, grad, m, v, max_grad_norm, step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize
- 调用 npu_apply_fused_adamw_v2 实现 var、m、v、max_grad_norm 的 inplace 更新(max_grad_norm 仅在 amsgrad 为 True 时发生变化)
import torch
import torch_npu
from mindspeed.ops.npu_apply_fused_adamw_v2 import npu_apply_fused_adamw_v2

# Five float32 tensors of the same shape, created on the host and moved to the NPU.
shape = (10, 10)
var, grad, m, v, max_grad_norm = (
    torch.full(shape, fill).to(torch.float32).npu()
    for fill in (0.5, 0.5, 0.9, 0.9, 0.9)
)

# One-element int64 tensor holding the current step number.
step = torch.full((1, ), 1).to(torch.int64).npu()

# Scalar hyper-parameters and flags.
lr = 1e-3
beta1 = 0.9999
beta2 = 0.9999
weight_decay = 0.0
eps = 1e-8
amsgrad = False
maximize = False

# Updates var, m and v (and max_grad_norm when amsgrad is True) in place; returns nothing.
npu_apply_fused_adamw_v2(var, grad, m, v, max_grad_norm, step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize)