import torch
from mindspeed.op_builder.fused_adamw_v2_builder import FusedAdamWV2OpBuilder
__all__ = ["npu_apply_fused_adamw_v2"]
op_builder = FusedAdamWV2OpBuilder()
tmp_tensor = torch.Tensor([1])
def npu_apply_fused_adamw_v2(param: torch.Tensor,
grad: torch.Tensor,
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
max_exp_avg_sq: torch.Tensor,
state_step: int,
lr: float = 1e-3,
beta1: float = 0.9,
beta2: float = 0.999,
weight_decay: float = 0.0,
eps: float = 1e-8,
amsgrad: bool = False,
maximize: bool = False,
):
fused_adamw_ops = op_builder.load()
if max_exp_avg_sq is None:
max_exp_avg_sq = tmp_tensor
return fused_adamw_ops.npu_apply_fused_adamw_v2(param,
grad,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
state_step,
lr,
beta1,
beta2,
weight_decay,
eps,
amsgrad,
maximize)