npu_apply_fused_ema_adamw 对外接口

接口原型

npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay)-> var, m, v, s

npu_apply_fused_ema_adamw接口用于更新fused_ema_adamw优化器中的var(模型参数), m(一阶矩动量), v(二阶矩动量), s(ema模型参数)这四个参数。

# 接口内部计算逻辑示例如下
def npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay, 
                              beta1, beta2, eps, mode, bias_correction,
                              weight_decay):
    beta1_correction = 1 - torch.pow(beta1, step) * bias_correction
    beta2_correction = 1 - torch.pow(beta2, step) * bias_correction
    grad_ = grad + weight_decay * var * (1 - mode)
    m_ = beta1 * m + (1 - beta1) * grad_
    v_ = beta2 * v + (1 - beta2) * grad_ * grad_
    next_m = m_ / beta1_correction
    next_v = v_ / beta2_correction
    denom = torch.pow(next_v, 0.5) + eps
    update = next_m / denom + weight_decay * var * mode
    var_ = var - lr * update
    s_ = ema_decay * s + (1 - ema_decay) * var_
    return var_, m_, v_, s_       

输入

  • grad:必选输入,数据类型为tensor(float32),表示模型参数的梯度。接受任意shape但需保持接口调用时grad, var, m, v, s五个入参shape一致。
  • var:必选输入,数据类型为tensor(float32),表示模型参数。接受任意shape但需保持接口调用时grad, var, m, v, s五个入参shape一致。
  • m:必选输入,数据类型为tensor(float32),表示一阶矩动量。接受任意shape但需保持接口调用时grad, var, m, v, s五个入参shape一致。
  • v:必选输入,数据类型为tensor(float32),表示二阶矩动量。接受任意shape但需保持接口调用时grad, var, m, v, s五个入参shape一致。
  • s:必选输入,数据类型为tensor(float32),表示ema模型参数。接受任意shape但需保持接口调用时grad, var, m, v, s五个入参shape一致。
  • step:必选输入,数据类型为tensor(int64),shape:(1,),表示当前为第几步。
  • lr:可选属性,数据类型为float32,默认值:1e-3。表示学习率。
  • ema_decay:可选属性,数据类型为float32,默认值:0.9999。表示ema衰减超参数。
  • beta1:可选属性,数据类型为float32,默认值:0.9。表示一阶矩动量的衰减率。
  • beta2:可选属性,数据类型为float32,默认值:0.999。表示二阶矩动量的衰减率。
  • eps:可选属性,数据类型为float32,默认值:1e-8。表示一个极小的数。
  • mode:可选属性,数据类型为int,默认值:1。取1表示以adamw模式计算,取0表示以adam模式计算。
  • bias_correction:可选属性,数据类型为bool,默认值:True。表示是否开启偏置修正。
  • weight_decay:可选属性,数据类型为float32,默认值:0.0。表示模型参数的衰减率。

支持的输入数据类型组合:

参数名称 数据类型
grad tensor(float32)
var tensor(float32)
m tensor(float32)
v tensor(float32)
s tensor(float32)
step tensor(int64)
lr float32
ema_decay float32
beta1 float32
beta2 float32
eps float32
mode int
bias_correction bool
weight_decay float32

输出

  • var:必选输出,数据类型为tensor(float32),shape和入参var一致,表示更新后的模型参数。
  • m:必选输出,数据类型为tensor(float32),shape和入参m一致,表示更新后的一阶矩动量。
  • v:必选输出,数据类型为tensor(float32),shape和入参v一致,表示更新后的二阶矩动量。
  • s:必选输出,数据类型为tensor(float32),shape和入参s一致,表示更新后的ema模型参数。

调用示例

  • 输入 grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay
  • 输出 var, m, v, s
import torch
import torch_npu
from mindspeed.ops.npu_apply_fused_ema_adamw import npu_apply_fused_ema_adamw


grad = torch.full((10, 10), 0.5).to(torch.float32).npu()
var = torch.full((10, 10), 0.5).to(torch.float32).npu()
m = torch.full((10, 10), 0.9).to(torch.float32).npu()
v = torch.full((10, 10), 0.9).to(torch.float32).npu()
s = torch.full((10, 10), 0.5).to(torch.float32).npu()
step = torch.full((1, ), 1).to(torch.int64).npu()
lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay= 1e-8, 0.9999, 0.9999, 0.9999, 1e-8, 1, True, 0.001
var, m, v, s = npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode,
                                         bias_correction, weight_decay)