// Commit 82f94caa, created on 2024-12-07 (commit history).
#include <torch/extension.h>
#include <torch/script.h>
#include <torch_npu/csrc/include/ops.h>
#include <torch/torch.h>
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "inc/aclnn_common.h"

/**
 * Casts a defined tensor to the ACL_FORMAT_ND layout.
 *
 * An undefined tensor is handed back untouched. A defined tensor must be an
 * NPU tensor (enforced with TORCH_CHECK); it is then passed through
 * at_npu::native::npu_format_cast with ACL_FORMAT_ND.
 *
 * @param at_tensor Tensor to convert; may be undefined.
 * @return The result of npu_format_cast, or the input itself when undefined.
 */
at::Tensor format_trans(const at::Tensor &at_tensor)
{
    // Nothing to convert for an undefined tensor.
    if (!at_tensor.defined()) {
        return at_tensor;
    }

    TORCH_CHECK(torch_npu::utils::is_npu(at_tensor), "only npu tensor is supported");
    return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>npu_apply_fused_ema_adamw(
    at::Tensor grad,
    at::Tensor var,
    at::Tensor m,
    at::Tensor v,
    at::Tensor s,
    at::Tensor step,
    c10::optional<double> lr,
    c10::optional<double> ema_decay,
    c10::optional<double> beta1,
    c10::optional<double> beta2,
    c10::optional<double> eps,
    c10::optional<int64_t> mode,
    c10::optional<bool> bias_correction,
    c10::optional<double> weight_decay)
{
    at::Tensor grad_ = format_trans(grad);
    at::Tensor var_ = format_trans(var);
    at::Tensor m_ = format_trans(m);
    at::Tensor v_ = format_trans(v);
    at::Tensor s_ = format_trans(s);
    at::Tensor step_ = format_trans(step);
    double lr_ = double(lr.value());
    double ema_decay_ = double(ema_decay.value());
    double beta1_ = double(beta1.value());
    double beta2_ = double(beta2.value());
    double eps_ = double(eps.value());
    int64_t mode_ = int64_t(mode.value());
    bool bias_correction_ = bool(bias_correction.value());
    double weight_decay_ = double(weight_decay.value());
    ACLNN_CMD(aclnnApplyFusedEmaAdam,
              grad_,
              var_,
              m_,
              v_,
              s_,
              step_,
              lr_,
              ema_decay_,
              beta1_,
              beta2_,
              eps_,
              mode_,
              bias_correction_,
              weight_decay_);
    return std::tie(var_, m_, v_, s_);
}
// Python binding for npu_apply_fused_ema_adamw. The tensor arguments are
// required; the scalar hyper-parameters carry the defaults listed below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("npu_apply_fused_ema_adamw",
          &npu_apply_fused_ema_adamw,
          "npu_apply_fused_ema_adamw",
          pybind11::arg("grad"),
          pybind11::arg("var"),
          pybind11::arg("m"),
          pybind11::arg("v"),
          pybind11::arg("s"),
          pybind11::arg("step"),
          // Double literals: the bound parameters are double, and a float
          // literal here would expose a float-rounded default to Python
          // (e.g. 0.0010000000474974513 instead of 0.001).
          pybind11::arg("lr") = 1e-3,
          pybind11::arg("ema_decay") = 0.9999,
          pybind11::arg("beta1") = 0.9,
          pybind11::arg("beta2") = 0.999,
          pybind11::arg("eps") = 1e-8,
          pybind11::arg("mode") = 1,
          pybind11::arg("bias_correction") = true,
          pybind11::arg("weight_decay") = 0.0);
}