#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Builds a vector with one entry per element of `chunk_result`:
//   - an undefined tensor is mapped to a default-constructed at::Tensor;
//   - a defined tensor that has at least one dimension and whose dimension 0
//     has size 1 is replaced by `squeeze(0)` of itself;
//   - any other defined tensor is passed through unchanged.
// The output has the same length and ordering as the input.
std::vector<at::Tensor> squeeze_chunk_result(const at::TensorList& chunk_result)
{
    std::vector<at::Tensor> result;
    result.reserve(chunk_result.size());
    for (const at::Tensor& piece : chunk_result) {
        if (!piece.defined()) {
            // Keep the slot so output positions stay aligned with the input.
            result.emplace_back();
            continue;
        }
        const bool has_unit_leading_dim = piece.dim() > 0 && piece.size(0) == 1;
        result.push_back(has_unit_leading_dim ? piece.squeeze(0) : piece);
    }
    return result;
}
// Backward wrapper that dispatches to the `aclnnLstmBackward` kernel.
//
// Steps performed here:
//   1. Allocate the outputs:
//        - `out0` with the shape from op_infer::lstm_backward_npu_output_size;
//        - `out_h_prev` / `out_c_prev` with the shape from
//          op_infer::lstm_backward_npu_hc_prev_output_size;
//        - one tensor per entry of `params`, with that entry's sizes and
//          ACL_FORMAT_ND.
//      All of them are allocated through OpPreparation using `input` as the
//      source tensor (presumably inheriting its dtype/device - confirm against
//      OpPreparation).
//   2. Split each of `i`, `j`, `f`, `o`, `h`, `c`, `tanhc` along dimension 0
//      into (bidirectional ? 2 : 1) * num_layers chunks and squeeze a unit
//      leading dimension off every chunk (see squeeze_chunk_result).
//      NOTE(review): these look like the per-layer/per-direction intermediates
//      saved by the forward pass - confirm with the forward implementation.
//   3. Invoke aclnnLstmBackward through EXEC_NPU_CMD.
//
// Returns a tuple of
//   (out0, {out_h_prev, out_c_prev}, tensors allocated for each entry of `params`).
// `batch_first` defaults to false when it carries no value.
std::tuple<at::Tensor, std::vector<at::Tensor>, std::vector<at::Tensor>> _lstm_npu_backward(
    const at::Tensor &grad_y,
    const at::Tensor &grad_hy,
    const at::Tensor &grad_cy,
    const at::Tensor &input,
    const at::TensorList hx,
    const at::TensorList params,
    const at::Tensor &i,
    const at::Tensor &j,
    const at::Tensor &f,
    const at::Tensor &o,
    const at::Tensor &h,
    const at::Tensor &c,
    const at::Tensor &tanhc,
    bool has_biases,
    int64_t num_layers,
    double dropout,
    bool train,
    bool bidirectional,
    c10::optional<bool> batch_first,
    const c10::optional<at::Tensor> &batch_sizes)
{
    const bool batch_first_1 = batch_first.value_or(false);

    // Output shapes are delegated to the shared shape-inference helpers.
    auto out0_shape = op_infer::lstm_backward_npu_output_size(input, batch_first_1, batch_sizes);
    auto out1_shape = op_infer::lstm_backward_npu_hc_prev_output_size(
        input, params, num_layers, bidirectional, batch_first_1, batch_sizes);

    int64_t output_format = ACL_FORMAT_ND;
    at::Tensor out0 = npu_preparation::apply_tensor(input, out0_shape);
    at::Tensor out_h_prev = npu_preparation::apply_tensor(input, out1_shape);
    at::Tensor out_c_prev = npu_preparation::apply_tensor(input, out1_shape);

    // One output tensor per parameter, mirroring the parameter's own sizes.
    std::vector<at::Tensor> param_list;
    param_list.reserve(params.size());
    for (const at::Tensor &param : params) {
        param_list.emplace_back(
            npu_preparation::apply_tensor_with_format(input, param.sizes(), output_format));
    }

    // Number of (layer, direction) pairs; every stacked intermediate is split
    // into this many pieces along dimension 0.
    const int64_t D = bidirectional ? 2 : 1;
    const int64_t list_length = D * num_layers;
    const int64_t split_dim = 0;
    const auto split_per_layer = [list_length, split_dim](const at::Tensor &stacked) {
        // Keep the chunk vector named so the TensorList view handed to
        // squeeze_chunk_result refers to storage that outlives the call.
        auto chunks = at::chunk(stacked, list_length, split_dim);
        return squeeze_chunk_result(chunks);
    };

    std::vector<at::Tensor> i_chunk = split_per_layer(i);
    std::vector<at::Tensor> j_chunk = split_per_layer(j);
    std::vector<at::Tensor> f_chunk = split_per_layer(f);
    std::vector<at::Tensor> o_chunk = split_per_layer(o);
    std::vector<at::Tensor> h_chunk = split_per_layer(h);
    std::vector<at::Tensor> c_chunk = split_per_layer(c);
    std::vector<at::Tensor> tanhc_chunk = split_per_layer(tanhc);

    // Non-owning views over the vectors above; the vectors must stay alive
    // (and unmodified) until EXEC_NPU_CMD has consumed these views.
    at::TensorList i_list = at::TensorList(i_chunk);
    at::TensorList j_list = at::TensorList(j_chunk);
    at::TensorList f_list = at::TensorList(f_chunk);
    at::TensorList o_list = at::TensorList(o_chunk);
    at::TensorList h_list = at::TensorList(h_chunk);
    at::TensorList c_list = at::TensorList(c_chunk);
    at::TensorList tanhc_list = at::TensorList(tanhc_chunk);
    at::TensorList param_list_ = at::TensorList(param_list);

    // Placeholder for a kernel argument this wrapper does not supply.
    // NOTE(review): confirm which aclnnLstmBackward parameter this slot is.
    std::nullptr_t temp_nullptr = nullptr;

    EXEC_NPU_CMD(
        aclnnLstmBackward,
        input,
        hx,
        params,
        grad_y,
        grad_hy,
        grad_cy,
        i_list,
        j_list,
        f_list,
        o_list,
        h_list,
        c_list,
        tanhc_list,
        batch_sizes,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first_1,
        temp_nullptr,
        out0,
        out_h_prev,
        out_c_prev,
        param_list_);

    // The TensorList views are not used past this point, so the backing
    // vectors can be moved into the result instead of being copied via
    // TensorList::vec().
    std::vector<at::Tensor> hx_prev_vec{out_h_prev, out_c_prev};
    return std::make_tuple(
        out0,
        std::move(hx_prev_vec),
        std::move(param_list));
}
}