# 82f94caa — created 2024-12-07 (history commit)
import torch
from mindspeed.op_builder import FusedEmaAdamWOpBuilder

# Names exported by a star-import of this module.
__all__ = [
    "npu_apply_fused_ema_adamw",
]

# One builder instance, created at import time; the wrapper function calls its
# ``load()`` method to obtain the compiled extension module.
fused_ema_adamw_op_builder = FusedEmaAdamWOpBuilder()


def npu_apply_fused_ema_adamw(
    grad: torch.Tensor,
    var: torch.Tensor,
    m: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    step: torch.Tensor,
    lr: float = 1e-3,
    ema_decay: float = 0.9999,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    mode: int = 1,
    bias_correction: bool = True,
    weight_decay: float = 0.0,
):
    """Forward a fused EMA-AdamW call to the compiled extension kernel.

    On every call, the extension module is obtained via
    ``fused_ema_adamw_op_builder.load()``. All arguments are then passed
    positionally to its ``npu_apply_fused_ema_adamw`` entry point, in exactly
    the order of this signature. This wrapper performs no validation or
    conversion of its own.

    The parameter roles below are inferred from their names and are not
    checked here. Confirm them against the kernel implementation.

    Args:
        grad: Presumably the gradient of the parameter being updated.
        var: Presumably the parameter tensor.
        m: Presumably the first-moment (exp_avg) optimizer state.
        v: Presumably the second-moment (exp_avg_sq) optimizer state.
        s: Presumably the EMA copy of ``var``.
        step: Step counter, supplied as a tensor.
        lr: Learning rate forwarded to the kernel.
        ema_decay: EMA decay factor forwarded to the kernel.
        beta1: First-moment coefficient forwarded to the kernel.
        beta2: Second-moment coefficient forwarded to the kernel.
        eps: Numerical-stability epsilon forwarded to the kernel.
        mode: Integer selector interpreted by the kernel; its meaning is not
            defined in this module.
        bias_correction: Flag forwarded to the kernel.
        weight_decay: Weight-decay coefficient forwarded to the kernel.

    Returns:
        Whatever the kernel's ``npu_apply_fused_ema_adamw`` returns, unchanged.
    """
    extension = fused_ema_adamw_op_builder.load()
    # Everything is forwarded positionally, so this order must match the
    # kernel's own argument order.
    tensor_args = (grad, var, m, v, s, step)
    scalar_args = (lr, ema_decay, beta1, beta2, eps, mode,
                   bias_correction, weight_decay)
    return extension.npu_apply_fused_ema_adamw(*tensor_args, *scalar_args)