#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
// Forward pass of the CTC loss on NPU, dispatched to the "CTCLossV2" operator.
//
// log_probs:           2-D or 3-D tensor (any other rank is rejected). Its last dimension is
//                      treated as the number of classes. 2-D inputs are reshaped to 3-D by
//                      inserting a size-1 dimension at position 1 before the kernel call.
//                      Half inputs are computed in float and the outputs are cast back to half.
// targets:             forwarded unchanged as the second kernel input.
// input_lengths_list:  forwarded unchanged as a host-side list input.
// target_lengths_list: forwarded unchanged as a host-side list input. Its maximum is also
//                      packed into the "blank" attribute (see below).
// blank:               blank label index; must satisfy 0 <= blank < number of classes.
// zero_infinity:       forwarded as the "zero_infinity" attribute.
//
// Returns (neg_log_likelihood, log_alpha), both with the same dtype as log_probs.
// Their shapes come from op_infer::ctc_loss_npu_output_size.
std::tuple<at::Tensor, at::Tensor> _ctc_loss(const at::Tensor &log_probs, const at::Tensor &targets,
    at::IntArrayRef input_lengths_list, at::IntArrayRef target_lengths_list,
    int64_t blank, bool zero_infinity)
{
    TORCH_CHECK(log_probs.dim() == 2 || log_probs.dim() == 3,
        "log_probs has to be a 2D or 3D Tensor, but got Tensor of dimension ", log_probs.dim(),
        OPS_ERROR(ErrCode::PARAM));

    // The class count is always the last dimension. sizes()[2] does not exist for 2-D inputs,
    // and ArrayRef::operator[] is unchecked, so it must not be indexed directly.
    const int64_t num_classes = log_probs.size(-1);
    // Reject out-of-range blanks: the packed attribute below is only unambiguous when blank < num_classes.
    TORCH_CHECK(blank >= 0 && blank < num_classes,
        "blank must be in label range [0, ", num_classes, "), but got blank = ", blank,
        OPS_ERROR(ErrCode::PARAM));

    // Half inputs are computed in float; the results are cast back to half at the end.
    const bool is_half = log_probs.scalar_type() == at::kHalf;
    at::Tensor log_probs_cast = log_probs;
    if (is_half) {
        log_probs_cast = at_npu::native::custom_ops::_npu_dtype_cast(log_probs_cast, at::kFloat);
    }

    // Longest target length; stays 0 when target_lengths_list is empty.
    int64_t max_length = 0;
    for (const int64_t target_length : target_lengths_list) {
        if (target_length > max_length) {
            max_length = target_length;
        }
    }

    // CTCLossV2 receives the max target length packed into its "blank" attribute as
    // blank + max_length * num_classes.
    // NOTE(review): presumably the kernel recovers both values by div/mod with num_classes -- confirm.
    const int64_t packed_blank = blank + max_length * num_classes;

    auto output_sizes = op_infer::ctc_loss_npu_output_size(log_probs, max_length);
    at::Tensor neg_log_likelihood = npu_preparation::apply_tensor_with_format(
        std::get<0>(output_sizes), log_probs_cast.options(), npu_preparation::get_tensor_npu_format(log_probs_cast));
    at::Tensor log_alpha = npu_preparation::apply_tensor_with_format(
        std::get<1>(output_sizes), log_probs_cast.options(), npu_preparation::get_tensor_npu_format(log_probs_cast));

    // The kernel is always fed a 3-D tensor: (d0, d1) becomes (d0, 1, d1).
    if (log_probs.dim() == 2) {
        c10::SmallVector<int64_t, N> log_probs_shape = op_infer::array_to_small_vector(log_probs.sizes());
        c10::SmallVector<int64_t, N> log_probs_shape_3d = {log_probs_shape[0], 1, log_probs_shape[1]};
        log_probs_cast = log_probs_cast.reshape(log_probs_shape_3d);
    }

    at_npu::native::OpCommand cmd;
    cmd.Name("CTCLossV2")
        .Input(log_probs_cast)
        .Input(targets)
        .Input(input_lengths_list)
        .Input(target_lengths_list)
        .Output(neg_log_likelihood)
        .Output(log_alpha)
        .Attr("blank", packed_blank)
        .Attr("zero_infinity", zero_infinity)
        .Run();

    if (is_half) {
        neg_log_likelihood = at_npu::native::custom_ops::_npu_dtype_cast(neg_log_likelihood, at::kHalf);
        log_alpha = at_npu::native::custom_ops::_npu_dtype_cast(log_alpha, at::kHalf);
    }
    return std::make_tuple(neg_log_likelihood, log_alpha);
}
}