#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// NOTE(review): not referenced in this namespace block; kept in case other code in this TU relies on it.
using calcu_op_util = at_npu::native::CalcuOpUtil;

/**
 * Make `out` hold a reshaped view/copy of `self`, writing into the caller-supplied tensor.
 *
 * Two modes, selected by `can_refresh`:
 *   - false: the device bytes of `self` are copied into `out` via copy_d2d_by_memcpy.
 *            The copy length is taken from `out`'s storage size (get_storage_size(out)), not from `self`.
 *   - true:  no data is copied and `self` is not read at all. Only `out`'s storage descriptor
 *            is rewritten from `out`'s own current sizes and strides.
 *
 * @param self        source tensor (read only when can_refresh is false)
 * @param shape       not read by this function; presumably already reflected in out.sizes() — confirm with callers
 * @param can_refresh true to only refresh out's descriptor, false to memcpy self's data into out
 * @param out         destination tensor, modified in place
 * @return            a reference to `out`
 */
at::Tensor& npu_reshape_out(
    const at::Tensor& self,
    at::IntArrayRef shape,
    bool can_refresh,
    at::Tensor& out)
{
    namespace npu_native = at_npu::native;

    if (!can_refresh) {
        // Length comes from the destination's storage, so `self` is assumed to
        // provide at least that much device data — TODO confirm at call sites.
        const auto dst_storage_size = npu_native::NPUNativeFunctions::get_storage_size(out);
        npu_native::copy_d2d_by_memcpy(out, self, dst_storage_size);
        return out;
    }

    // Refresh path: rebuild the storage descriptor from out's logical sizes/strides.
    npu_native::StorageDescHelper::SetDesc(
        out,
        op_infer::array_to_small_vector(out.sizes()),
        op_infer::array_to_small_vector(out.strides()));
    return out;
}

/**
 * Functional form of npu_reshape_out: allocate a destination of the requested shape and fill it.
 *
 * The destination comes from npu_preparation::apply_tensor(self, shape). Presumably it takes its
 * options from `self` — TODO confirm.
 * NOTE(review): with can_refresh == true, npu_reshape_out does not read `self`, so the returned
 * tensor's contents are whatever apply_tensor produced. Confirm callers expect this.
 *
 * @param self        source tensor
 * @param shape       shape of the tensor to allocate and return
 * @param can_refresh forwarded unchanged to npu_reshape_out
 * @return            the newly allocated tensor
 */
at::Tensor npu_reshape(const at::Tensor& self, at::IntArrayRef shape, bool can_refresh)
{
    at::Tensor reshaped = npu_preparation::apply_tensor(self, shape);
    op_api::npu_reshape_out(self, shape, can_refresh, reshaped);
    // Return the named local (not npu_reshape_out's reference) so copy elision applies.
    return reshaped;
}
}  // namespace op_api