#include <cstdint>
#include <limits>
#include <tuple>
#include <vector>

#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/OpInterface.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Minimum number of dimensions `input` must have (see the rank check below).
const int DIMENSION_2D = 2;

/**
 * @brief Packs a padded batch of sequences into a time-major, flattened tensor
 *        and computes the per-time-step batch sizes.
 *
 * The two leading dimensions of `input` are interpreted as (batch, time) when
 * `batch_first` is true and as (time, batch) otherwise. The data is brought to
 * time-major order, truncated along the time axis to `lengths[0]` when that is
 * smaller than the time dimension, made contiguous and reshaped to
 * `{batch * time, input.size(2), ...}`.
 *
 * @param input       Tensor with at least two dimensions and at least one element.
 * @param lengths     CPU tensor read as int64 with `lengths.size(0) == batch`.
 *                    `lengths[0]` is used as the longest sequence length and
 *                    `lengths[batch - 1]` must be > 0.
 *                    NOTE(review): the code relies on `lengths` being sorted in
 *                    decreasing order but does not verify it - presumably the
 *                    caller guarantees this; confirm.
 * @param batch_first Whether dimension 0 of `input` is the batch dimension.
 *
 * @return `(output, batchsizes)`: the flattened data tensor, and a tensor created
 *         with `lengths.options()` holding one int64 entry per retained time step
 *         `t`, equal to `bi + 1` for the largest index `bi` with `lengths[bi] > t`.
 *
 * @throws c10::Error (via TORCH_CHECK) when any of the argument checks fails.
 */
std::tuple<at::Tensor, at::Tensor> _pack_padded_sequence(const at::Tensor &input, const at::Tensor &lengths,
                                                         bool batch_first)
{
    TORCH_CHECK(input.dim() >= DIMENSION_2D, "Input must have two dims.", input.dim(), OPS_ERROR(ErrCode::PARAM));
    auto batchsize = batch_first ? input.size(0) : input.size(1);
    auto timesize = batch_first ? input.size(1) : input.size(0);

    TORCH_CHECK(input.numel() > 0, "Cannot pack empty tensors." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(input.numel() < std::numeric_limits<int64_t>::max(),
                "Input tensor contain more than the max number of int64." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(lengths.size(0) == batchsize, "Expected 'len(lengths)' to be equal to batch_size, but got ",
                lengths.size(0), " (batch_size=", batchsize, ")" + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(lengths.device().type() == at::kCPU,
                "'lengths' argument should be a CPU tensor, but got ",
                lengths.device().str(), " tensor" + OPS_ERROR(ErrCode::PARAM));

    // Keep the contiguous tensor alive in a named local: taking data_ptr() from the
    // temporary returned by contiguous() leaves a dangling pointer whenever
    // `lengths` is not already contiguous (the temporary owns fresh storage).
    const at::Tensor lengths_contiguous = lengths.contiguous();
    const int64_t *lengths_vec = lengths_contiguous.data_ptr<int64_t>();
    // batchsize >= 1 here because input.numel() > 0, so index batchsize - 1 is valid.
    TORCH_CHECK(lengths_vec != nullptr && lengths_vec[batchsize - 1] > 0,
                "Length of all samples has to be greater than 0, but found an element "
                "in 'lengths' that is <= 0" + OPS_ERROR(ErrCode::PARAM));

    // Work in time-major layout from here on.
    auto output = batch_first ? input.transpose(0, 1) : input;

    // lengths[0] is treated as the longest sequence; drop time steps beyond it.
    const int64_t len = lengths_vec[0];
    if (len < timesize) {
        // Host-side int32 indices 0..len-1 selecting the retained time steps.
        std::vector<int> host_index;
        for (int64_t i = 0; i < len; ++i) {
            host_index.emplace_back(static_cast<int>(i));
        }
        // from_blob only wraps the buffer, so host_index must outlive the copy below.
        auto index = at::from_blob(host_index.data(), {len}, at::kInt);
        index = npu_preparation::copy_tensor_host_to_device(index);
        output = op_plugin::index_select(output, 0, index);
        timesize = len;
    }

    // Target shape: {batch * time, trailing feature dims...}.
    // NOTE(review): `N` is not defined in this file; presumably a project-wide
    // inline-capacity constant - confirm.
    at::SmallVector<int64_t, N> shape;
    shape.emplace_back(batchsize * timesize);
    for (int64_t i = 2; i < input.dim(); ++i) {
        shape.emplace_back(input.size(i));
    }
    output = output.contiguous();
    output = output.reshape(shape);

    at::Tensor batchsizes = at::empty({timesize}, lengths.options());
    int64_t *batchsize_vec = batchsizes.data_ptr<int64_t>();
    TORCH_CHECK(batchsize_vec != nullptr, "batchsizes is null" + OPS_ERROR(ErrCode::PARAM));

    // For each time step, scan downwards from the previous hit for the largest
    // index whose length still exceeds the step. Every ti < timesize <= lengths[0],
    // so index 0 always matches and every entry of batchsizes gets written.
    int64_t last = batchsize - 1;
    for (int64_t ti = 0; ti < timesize; ++ti) {
        for (int64_t bi = last; bi >= 0; --bi) {
            if (lengths_vec[bi] > ti) {
                batchsize_vec[ti] = (bi + 1);
                last = bi;
                break;
            }
        }
    }
    return std::tie(output, batchsizes);
}
} // namespace op_api