#include <tuple>
#include <utility>

#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpAdapter.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// Computes the shape of the embedding_bag output tensor: {num_bags, weight.size(1)}, where
// num_bags is offsets.size(0), or offsets.size(0) - 1 when include_last_offset is set.
// `indices` is not read; it stays in the signature for existing callers.
// Throws (TORCH_CHECK) if weight is not 2-D, or if include_last_offset is set with empty offsets.
// NOTE(review): this function has external linkage under a generic name inside op_api; if no other
// translation unit relies on it, consider giving it internal linkage to avoid ODR clashes.
c10::SmallVector<int64_t, SIZE> get_output_size(const at::Tensor &weight, const at::Tensor &indices,
    const at::Tensor &offsets, bool include_last_offset)
{
    TORCH_CHECK(weight.dim() == 2, "weight has to be a 2D Tensor, but got Tensor of dimension ", weight.dim());
    // Without this check, include_last_offset with empty offsets gives num_bags == -1, i.e. a negative
    // dimension for every tensor allocated from this size.
    TORCH_CHECK(!include_last_offset || offsets.size(0) >= 1,
        "include_last_offset: number of offset should be at least 1");
    const int64_t num_bags = include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
    c10::SmallVector<int64_t, SIZE> output_size{num_bags, weight.size(1)};
    return output_size;
}

namespace {
// Shared aclnn path of _embedding_bag and _embedding_bag_forward_only (their bodies were identical).
// Validates the inputs, allocates the four outputs (output, offset2bag, bag_size, max_indices),
// launches aclnnEmbeddingBag and returns the outputs in that order.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> embedding_bag_aclnn(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const c10::optional<at::Tensor> &per_sample_weights, bool include_last_offset,
    int64_t padding_idx)
{
    TORCH_CHECK((indices.dim() > 0), "indices.dim() must be greater than 0");
    TORCH_CHECK((weight.dim() > 0), "weight.dim() must be greater than 0");
    TORCH_CHECK((offsets.dim() > 0), "offsets.dim() must be greater than 0");
    // Also rejects include_last_offset with empty offsets, so the "size(0) - 1" below is never negative.
    const c10::SmallVector<int64_t, SIZE> result_size =
        get_output_size(weight, indices, offsets, include_last_offset);

    // Allocates a tensor like `offsets`, dropping one entry along dim 0 when include_last_offset is set.
    auto alloc_per_bag = [&offsets, include_last_offset]() {
        return include_last_offset ?
            npu_preparation::apply_tensor_without_format(offsets, offsets.size(0) - 1) :
            npu_preparation::apply_tensor_without_format(offsets);
    };

    at::Tensor output_tensor = npu_preparation::apply_tensor_without_format(weight, result_size);
    at::Tensor offset2bag = npu_preparation::apply_tensor_without_format(indices, indices.size(0));
    at::Tensor bag_size = alloc_per_bag();
    // For modes 0/1 max_indices is shaped like bag_size; otherwise it is shaped like the output.
    // NOTE(review): mode presumably follows PyTorch's embedding_bag convention (0=sum, 1=mean, 2=max).
    at::Tensor max_indices = (mode == 0 || mode == 1) ?
        alloc_per_bag() :
        npu_preparation::apply_tensor_without_format(offsets, result_size);

    // Index-valued outputs become int64 if either indices or offsets is int64, else int32 if either
    // is int32; any other combination keeps the dtype they were allocated with.
    c10::optional<at::ScalarType> index_dtype;
    if (indices.dtype() == at::kLong || offsets.dtype() == at::kLong) {
        index_dtype = at::kLong;
    } else if (indices.dtype() == at::kInt || offsets.dtype() == at::kInt) {
        index_dtype = at::kInt;
    }
    if (index_dtype.has_value()) {
        offset2bag = offset2bag.to(*index_dtype);
        bag_size = bag_size.to(*index_dtype);
        max_indices = max_indices.to(*index_dtype);
    }

    if (mode == 0 && padding_idx < 0) {
        // offset2bag is passed to the kernel as a zero-element tensor in this case.
        // NOTE(review): presumably because it is not needed for this mode without padding —
        // confirm against the aclnnEmbeddingBag contract.
        offset2bag = npu_preparation::apply_tensor_without_format(offset2bag, 0);
    }

    EXEC_NPU_CMD(aclnnEmbeddingBag, weight, indices, offsets, scale_grad_by_freq,
        mode, sparse, per_sample_weights, include_last_offset, padding_idx,
        output_tensor, offset2bag, bag_size, max_indices);
    // Move the locals into the result instead of copying them through std::tie.
    return std::make_tuple(std::move(output_tensor), std::move(offset2bag), std::move(bag_size),
        std::move(max_indices));
}
} // namespace

// op_api implementation of _embedding_bag; returns (output, offset2bag, bag_size, max_indices).
// DO_COMPATIBILITY presumably routes to acl_op::_embedding_bag when aclnnEmbeddingBag is unavailable.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _embedding_bag(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const c10::optional<at::Tensor> &per_sample_weights, bool include_last_offset,
    int64_t padding_idx)
{
    DO_COMPATIBILITY(aclnnEmbeddingBag, acl_op::_embedding_bag(weight, indices, offsets, scale_grad_by_freq,
        mode, sparse, per_sample_weights, include_last_offset, padding_idx));
    return embedding_bag_aclnn(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights,
        include_last_offset, padding_idx);
}

// op_api implementation of _embedding_bag_forward_only; same aclnn path and outputs as _embedding_bag,
// with its own fallback (presumably used when aclnnEmbeddingBag is unavailable).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _embedding_bag_forward_only(
    const at::Tensor &weight, const at::Tensor &indices, const at::Tensor &offsets, bool scale_grad_by_freq,
    int64_t mode, bool sparse, const c10::optional<at::Tensor> &per_sample_weights, bool include_last_offset,
    int64_t padding_idx)
{
    DO_COMPATIBILITY(aclnnEmbeddingBag, acl_op::_embedding_bag_forward_only(weight, indices, offsets,
        scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx));
    return embedding_bag_aclnn(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights,
        include_last_offset, padding_idx);
}
} // namespace op_api