#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
/**
 * Computes the 2-D shape of the EmbeddingBag result tensor.
 *
 * The second dimension is always `weight.size(1)`. The first dimension is:
 *   - for 1-D `indices`: `offsets.size(0)`, minus one when `include_last_offset` is set;
 *   - otherwise: `indices.size(0)` (`offsets` is not consulted).
 *
 * @param weight               embedding table; must be 2-D.
 * @param indices              lookup indices; only its rank and first dimension are read.
 * @param offsets              bag start offsets; only its first dimension is read, and only for 1-D `indices`.
 * @param include_last_offset  when true, one entry of `offsets` is not counted as a bag.
 * @return {row_count, weight.size(1)}.
 * @throws via TORCH_CHECK when `weight` is not 2-D, or when `include_last_offset` is set for 1-D `indices`
 *         and `offsets` has no entries (the row count would otherwise be -1).
 */
c10::SmallVector<int64_t, SIZE> _embedding_bag_npu_output_size(const at::Tensor &weight, const at::Tensor &indices,
                                                               const at::Tensor &offsets, bool include_last_offset)
{
    TORCH_CHECK(weight.dim() == 2, "weight has to be a 2D Tensor, but got Tensor of dimension ", weight.dim(),
                OPS_ERROR(ErrCode::PARAM));

    int64_t row_count = 0;
    if (indices.dim() == 1) {
        row_count = offsets.size(0);
        if (include_last_offset) {
            // Without this guard an empty `offsets` would produce a row count of -1 and only fail later,
            // inside tensor allocation, with a message unrelated to the actual cause.
            TORCH_CHECK(row_count >= 1, "include_last_offset: number of offsets should be at least 1, but got ",
                        row_count, OPS_ERROR(ErrCode::PARAM));
            row_count -= 1;
        }
    } else {
        row_count = indices.size(0);
    }

    c10::SmallVector<int64_t, SIZE> output_size = {row_count, weight.size(1)};
    return output_size;
}
/**
 * Translates the integer reduction mode into the string form passed as the operator's "mode" attribute.
 *
 * @param mode 0 selects "sum", 1 selects "mean"; every other value selects "max".
 * @return the mode name.
 */
string get_mode_str(const int64_t mode)
{
    switch (mode) {
        case 0:
            return "sum";
        case 1:
            return "mean";
        default:
            // Any value other than 0 or 1 is treated as max; no validation is performed here.
            return "max";
    }
}
/**
 * Launches the "EmbeddingBag" operator into pre-allocated outputs, or synthesizes the result for empty input.
 *
 * Empty-input path (`indices` or `offsets` has no elements):
 *   - only sum mode (`mode == 0`) is accepted;
 *   - `output` keeps the shape the caller allocated and is zero-filled in place;
 *   - `offset2bag`, `bag_size` and `max_indices` are replaced by tensors allocated from `indices` with size 0;
 *   - the operator is not launched.
 *   NOTE(review): on this path the auxiliary outputs are returned without the int64 conversion applied on
 *   the compute path — confirm whether callers rely on that.
 *
 * Compute path:
 *   - inputs are `weight`, `indices`, `offsets` and, when defined, `per_sample_weights`;
 *   - in sum mode with `padding_idx == -1`, `offset2bag` is replaced by a size-0 tensor after the launch;
 *   - if either original index tensor (`indices_ori` / `offsets_ori`) is int64, the three auxiliary outputs
 *     are converted to int64 before being returned.
 *
 * @param indices, offsets          tensors handed to the operator (callers in this file pass int32 casts).
 * @param output                    pre-allocated result tensor; written by the operator or zero-filled.
 * @param offset2bag, bag_size, max_indices
 *                                  pre-allocated auxiliary outputs; may be reassigned as described above.
 * @param indices_ori, offsets_ori  the caller's original tensors, read only for their dtype.
 * @return (output, offset2bag, bag_size, max_indices).
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _embedding_bag_out_npu_nocheck(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const at::Tensor &per_sample_weights, bool include_last_offset, int64_t padding_idx,
    at::Tensor &output, at::Tensor &offset2bag, at::Tensor &bag_size, at::Tensor &max_indices,
    const at::Tensor &indices_ori, const at::Tensor &offsets_ori)
{
    if (indices.numel() == 0 || offsets.numel() == 0) {
        TORCH_CHECK(mode == 0, "The mode must be sum" + OPS_ERROR(ErrCode::PARAM));
        // Zero the caller's pre-sized output instead of swapping in a weight-shaped tensor, so the result
        // keeps its [bags, embedding_dim] shape. An output without elements needs no fill.
        if (output.numel() != 0) {
            acl_op::fill_(output, 0);
        }
        offset2bag = npu_preparation::apply_tensor(indices, 0);
        bag_size = npu_preparation::apply_tensor(indices, 0);
        max_indices = npu_preparation::apply_tensor(indices, 0);
        return std::tie(output, offset2bag, bag_size, max_indices);
    }

    const string mode_str = get_mode_str(mode);

    // The command object is created only on the path that actually runs it.
    at_npu::native::OpCommand cmd;
    cmd.Name("EmbeddingBag").Input(weight).Input(indices).Input(offsets);
    if (per_sample_weights.defined()) {
        cmd.Input(per_sample_weights);
    }
    cmd.Output(output)
        .Output(offset2bag)
        .Output(bag_size)
        .Output(max_indices)
        .Attr("mode", mode_str)
        .Attr("scale_grad_by_freq", scale_grad_by_freq)
        .Attr("sparse", sparse)
        .Attr("include_last_offset", include_last_offset)
        .Attr("padding_idx", padding_idx)
        .Run();

    if (mode_str == "sum" && padding_idx == -1) {
        offset2bag = npu_preparation::apply_tensor(indices, 0);
    }

    at::Tensor offset2bag_result = offset2bag;
    at::Tensor bag_size_result = bag_size;
    at::Tensor max_indices_result = max_indices;
    if (indices_ori.dtype() == at::kLong || offsets_ori.dtype() == at::kLong) {
        // Hand the auxiliary outputs back as int64 when the caller supplied int64 index tensors.
        offset2bag_result = offset2bag_result.to(at::kLong);
        bag_size_result = bag_size_result.to(at::kLong);
        max_indices_result = max_indices_result.to(at::kLong);
    }
    return std::tie(output, offset2bag_result, bag_size_result, max_indices_result);
}
}
/**
 * NPU entry point for `_embedding_bag`: validates ranks, prepares int32 index tensors, allocates the four
 * outputs and forwards everything to `_embedding_bag_out_npu_nocheck`.
 *
 * Steps:
 *   1. `indices`, `weight` and `offsets` must each have at least one dimension.
 *   2. int64 `indices` / `offsets` are cast to int32 copies (presumably what the NPU kernel expects —
 *      TODO confirm); other dtypes are passed through unchanged.
 *   3. An absent `per_sample_weights` becomes an undefined tensor.
 *   4. Output allocation:
 *        - result: shape from `_embedding_bag_npu_output_size`, options taken from `weight`;
 *        - offset2bag: same shape as the (cast) indices;
 *        - bag_size: same shape as the (cast) offsets, or `offsets.size(0) - 1` entries when
 *          `include_last_offset` is set;
 *        - max_indices: the result shape in max mode, otherwise the same rule as bag_size.
 *
 * @return (output, offset2bag, bag_size, max_indices) as produced by `_embedding_bag_out_npu_nocheck`.
 * @throws via TORCH_CHECK when any of the three tensors is 0-dimensional.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _embedding_bag(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const c10::optional<at::Tensor> &per_sample_weights, bool include_last_offset,
    int64_t padding_idx)
{
    TORCH_CHECK((indices.dim() > 0), "indices.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((weight.dim() > 0), "weight.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((offsets.dim() > 0), "offsets.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));

    at::Tensor indices_cast = indices;
    if (indices.dtype() == at::kLong) {
        indices_cast = at_npu::native::custom_ops::_npu_dtype_cast(indices_cast, at::kInt);
    }
    at::Tensor offsets_cast = offsets;
    if (offsets.dtype() == at::kLong) {
        offsets_cast = at_npu::native::custom_ops::_npu_dtype_cast(offsets_cast, at::kInt);
    }

    const at::Tensor &per_sample_weights_core = c10::value_or_else(per_sample_weights, [] { return at::Tensor(); });

    c10::SmallVector<int64_t, SIZE> result_size =
        _embedding_bag_npu_output_size(weight, indices_cast, offsets_cast, include_last_offset);
    at::Tensor output_tensor = npu_preparation::apply_tensor(weight, result_size);
    at::Tensor offset2bag = npu_preparation::apply_tensor(indices_cast);

    // Allocate each auxiliary output exactly once, at its final size.
    at::Tensor bag_size = include_last_offset ?
        npu_preparation::apply_tensor(offsets_cast, offsets_cast.size(0) - 1) :
        npu_preparation::apply_tensor(offsets_cast);

    at::Tensor max_indices;
    if (get_mode_str(mode) == "max") {
        // Max mode uses the same shape as the result tensor.
        max_indices = npu_preparation::apply_tensor(offsets_cast, result_size);
    } else if (include_last_offset) {
        max_indices = npu_preparation::apply_tensor(offsets_cast, offsets_cast.size(0) - 1);
    } else {
        max_indices = npu_preparation::apply_tensor(offsets_cast);
    }

    // The original tensors are forwarded as well so the callee can inspect their dtypes.
    return _embedding_bag_out_npu_nocheck(weight, indices_cast, offsets_cast, scale_grad_by_freq, mode, sparse,
                                          per_sample_weights_core, include_last_offset, padding_idx, output_tensor,
                                          offset2bag, bag_size, max_indices, indices, offsets);
}
/**
 * NPU entry point for `_embedding_bag_forward_only`. It performs the same steps as `_embedding_bag` in this
 * file: rank validation, int32 index preparation, output allocation, then a call to
 * `_embedding_bag_out_npu_nocheck`.
 *
 * Output allocation rules:
 *   - result: shape from `_embedding_bag_npu_output_size`, options taken from `weight`;
 *   - offset2bag: same shape as the (cast) indices;
 *   - bag_size: same shape as the (cast) offsets, or `offsets.size(0) - 1` entries when
 *     `include_last_offset` is set;
 *   - max_indices: the result shape in max mode, otherwise the same rule as bag_size.
 *
 * @param per_sample_weights  optional; an absent value is forwarded as an undefined tensor.
 * @return (output, offset2bag, bag_size, max_indices) as produced by `_embedding_bag_out_npu_nocheck`.
 * @throws via TORCH_CHECK when `indices`, `weight` or `offsets` is 0-dimensional.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _embedding_bag_forward_only(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const c10::optional<at::Tensor> &per_sample_weights, bool include_last_offset,
    int64_t padding_idx)
{
    TORCH_CHECK((indices.dim() > 0), "indices.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((weight.dim() > 0), "weight.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((offsets.dim() > 0), "offsets.dim() must be greater than 0" + OPS_ERROR(ErrCode::PARAM));

    // int64 index tensors are replaced by int32 copies before reaching the operator
    // (presumably a kernel requirement — TODO confirm).
    at::Tensor indices_cast = indices;
    if (indices.dtype() == at::kLong) {
        indices_cast = at_npu::native::custom_ops::_npu_dtype_cast(indices_cast, at::kInt);
    }
    at::Tensor offsets_cast = offsets;
    if (offsets.dtype() == at::kLong) {
        offsets_cast = at_npu::native::custom_ops::_npu_dtype_cast(offsets_cast, at::kInt);
    }

    const at::Tensor &per_sample_weights_core = c10::value_or_else(per_sample_weights, [] { return at::Tensor(); });

    c10::SmallVector<int64_t, SIZE> result_size =
        _embedding_bag_npu_output_size(weight, indices_cast, offsets_cast, include_last_offset);
    at::Tensor output_tensor = npu_preparation::apply_tensor(weight, result_size);
    at::Tensor offset2bag = npu_preparation::apply_tensor(indices_cast);

    // Each auxiliary output is allocated a single time, directly at the size it will be used with.
    at::Tensor bag_size = include_last_offset ?
        npu_preparation::apply_tensor(offsets_cast, offsets_cast.size(0) - 1) :
        npu_preparation::apply_tensor(offsets_cast);

    at::Tensor max_indices;
    if (get_mode_str(mode) == "max") {
        // In max mode the tensor has the same shape as the result.
        max_indices = npu_preparation::apply_tensor(offsets_cast, result_size);
    } else if (include_last_offset) {
        max_indices = npu_preparation::apply_tensor(offsets_cast, offsets_cast.size(0) - 1);
    } else {
        max_indices = npu_preparation::apply_tensor(offsets_cast);
    }

    // Pass the untouched originals too; the callee reads their dtypes.
    return _embedding_bag_out_npu_nocheck(weight, indices_cast, offsets_cast, scale_grad_by_freq, mode, sparse,
                                          per_sample_weights_core, include_last_offset, padding_idx, output_tensor,
                                          offset2bag, bag_size, max_indices, indices, offsets);
}
}