#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
using tensor_list = std::tuple<at::Tensor &, at::Tensor &>;
namespace {
tensor_list batch_norm_gather_stats_update_npu_impl(at::Tensor &mean_all, at::Tensor &invstd_all,
const at::Tensor &self, const at::Tensor &sum,
const at::Tensor &square_sum, const at::Tensor &running_mean,
const at::Tensor &running_var, double momentum, double eps,
const at::Tensor &counts)
{
at::Tensor counts_cp =
counts.scalar_type() == at::kInt ? counts : at_npu::native::custom_ops::_npu_dtype_cast(counts, at::kInt);
auto running_mean_dtype = running_mean.scalar_type();
at::Tensor running_mean_ = at_npu::native::custom_ops::_npu_dtype_cast(
at_npu::native::custom_ops::npu_format_cast(
(running_mean.defined() ? running_mean : at::zeros({self.size(1)}, sum.options())), ACL_FORMAT_ND),
sum.scalar_type());
at::Tensor running_var_ = at_npu::native::custom_ops::_npu_dtype_cast(
at_npu::native::custom_ops::npu_format_cast(
(running_var.defined() ? running_var : at::ones({self.size(1)}, sum.options())), ACL_FORMAT_ND),
sum.scalar_type());
at_npu::native::OpCommand cmd;
cmd.Name("SyncBatchNormGatherStats")
.Input(sum)
.Input(square_sum)
.Input(counts_cp)
.Input(running_mean_)
.Input(running_var_)
.Output(mean_all)
.Output(invstd_all)
.Output(running_mean_)
.Output(running_var_)
.Attr("momentum", static_cast<float>(momentum))
.Attr("eps", static_cast<float>(eps))
.Run();
if (running_mean.defined()) {
if (running_mean_.scalar_type() != running_mean_dtype) {
running_mean_ = at_npu::native::custom_ops::_npu_dtype_cast(running_mean_, running_mean_dtype);
running_var_ = at_npu::native::custom_ops::_npu_dtype_cast(running_var_, running_mean_dtype);
}
running_mean.copy_(running_mean_);
running_var.copy_(running_var_);
}
return std::tie(mean_all, invstd_all);
}
}
// Public entry point: gathers the batch-norm statistics described by `sum` / `square_sum` /
// `counts` via the NPU SyncBatchNormGatherStats kernel.
//
// `self` must have at least 2 dimensions; dimension 1 gives the length of both returned
// tensors, which are allocated like `sum`. The optional running statistics are forwarded
// to the kernel (an absent optional is passed on as an undefined tensor) and, when present,
// are updated in place.
// Returns (mean_all, invstd_all).
std::tuple<at::Tensor, at::Tensor> batch_norm_gather_stats_update(const at::Tensor &self, const at::Tensor &sum,
                                                                  const at::Tensor &square_sum,
                                                                  const c10::optional<at::Tensor> &running_mean_opt,
                                                                  const c10::optional<at::Tensor> &running_var_opt,
                                                                  double momentum, double eps, const at::Tensor &counts)
{
    TORCH_CHECK(self.dim() > 1, "The dim input tensor [self] must more than 1." + OPS_ERROR(ErrCode::PARAM));

    // Unwrap the optionals; "not provided" becomes an undefined tensor.
    const at::Tensor running_mean = running_mean_opt.has_value() ? running_mean_opt.value() : at::Tensor();
    const at::Tensor running_var = running_var_opt.has_value() ? running_var_opt.value() : at::Tensor();

    // One output element per entry along dimension 1 of `self`.
    const int64_t channels = self.size(1);
    c10::SmallVector<int64_t, N> stat_shape = {channels};
    at::Tensor mean_out = npu_preparation::apply_tensor(sum, stat_shape);
    at::Tensor invstd_out = npu_preparation::apply_tensor(sum, stat_shape);

    batch_norm_gather_stats_update_npu_impl(mean_out, invstd_out, self, sum, square_sum,
                                            running_mean, running_var, momentum, eps, counts);
    return std::make_tuple(mean_out, invstd_out);
}
}