#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Builds the shape vector used by the confusion-transpose entry points.
//
// ori_shape:       shape to copy or permute.
// perm:            axis indices; only read when transpose_first is false.
// transpose_first: when true, the result is a plain copy of ori_shape.
//
// Returns ori_shape unchanged when transpose_first is true. Otherwise returns
// one entry per element of perm, where entry i is ori_shape[perm[i]].
// Raises a TORCH_CHECK failure if a perm entry does not index into ori_shape.
inline c10::SmallVector<int64_t, SIZE> transpose_shape(
    at::IntArrayRef ori_shape, at::IntArrayRef perm, bool transpose_first)
{
    if (transpose_first) {
        return op_infer::array_to_small_vector(ori_shape);
    }

    const size_t rank = ori_shape.size();
    c10::SmallVector<int64_t, SIZE> permuted;
    for (const int64_t axis : perm) {
        // The cast spells out the unsigned comparison: a negative axis wraps
        // to a huge value and is therefore rejected together with axis >= rank.
        TORCH_CHECK(rank > static_cast<size_t>(axis),
                    "npu_confusion_transpose forward/backward input invalid, "
                    "shape has size ",
                    rank,
                    " but perm[i] is, ",
                    axis,
                    OPS_ERROR(ErrCode::PARAM));
        permuted.emplace_back(ori_shape[axis]);
    }
    return permuted;
}
// Forward entry point for npu_confusion_transpose.
//
// On SoC versions below Ascend950 the call is forwarded unchanged to the
// acl_op implementation. Otherwise the output tensor is allocated with the
// shape returned by transpose_shape(shape, perm, transpose_first) and self's
// tensor options, and aclnnConfusionTranspose is launched to fill it.
at::Tensor npu_confusion_transpose(
    const at::Tensor &self, at::IntArrayRef perm, at::IntArrayRef shape, bool transpose_first)
{
    const bool legacy_soc = c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950;
    if (legacy_soc) {
        return acl_op::npu_confusion_transpose(self, perm, shape, transpose_first);
    }

    // NOTE(review): presumably returns the acl_op result when the aclnn kernel
    // is not available in the installed runtime -- confirm against the macro.
    DO_COMPATIBILITY(aclnnConfusionTranspose, acl_op::npu_confusion_transpose(self, perm, shape, transpose_first));

    c10::SmallVector<int64_t, SIZE> out_sizes = transpose_shape(shape, perm, transpose_first);
    at::Tensor result = npu_preparation::apply_tensor_without_format(out_sizes, self.options());
    EXEC_NPU_CMD(aclnnConfusionTranspose, self, perm, shape, transpose_first, result);
    return result;
}
// Backward entry point for npu_confusion_transpose.
//
// grad:            incoming gradient tensor.
// perm:            axis indices passed by the caller; it is inverted here
//                  before being handed to the kernel.
// shape_symint:    shape of the returned gradient (read as concrete ints).
// transpose_first: forwarded to transpose_shape and to the kernel as given.
//
// On SoC versions below Ascend950 the call is forwarded unchanged to the
// acl_op implementation. Otherwise the result is allocated with `shape` and
// grad's tensor options, and aclnnConfusionTranspose is launched with the
// inverted perm and the shape computed by transpose_shape.
//
// Raises a TORCH_CHECK failure if any perm entry is outside [0, perm.size()),
// in addition to the checks transpose_shape performs.
at::Tensor npu_confusion_transpose_backward_symint(
    const at::Tensor &grad, at::IntArrayRef perm, c10::SymIntArrayRef shape_symint, bool transpose_first)
{
    if (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950) {
        return acl_op::npu_confusion_transpose_backward_symint(grad, perm, shape_symint, transpose_first);
    }
    // NOTE(review): presumably returns the acl_op result when the aclnn kernel
    // is not available in the installed runtime -- confirm against the macro.
    DO_COMPATIBILITY(aclnnConfusionTranspose,
                     acl_op::npu_confusion_transpose_backward_symint(grad, perm, shape_symint, transpose_first));

    at::IntArrayRef shape = c10::asIntArrayRefUnchecked(shape_symint);
    c10::SmallVector<int64_t, SIZE> svec_backward_shape = transpose_shape(shape, perm, transpose_first);
    at::IntArrayRef backward_shape(svec_backward_shape);

    // Invert perm: backward_perm[perm[i]] = i.
    const int64_t perm_len = static_cast<int64_t>(perm.size());
    c10::SmallVector<int64_t, SIZE> svec_backward_perm(perm_len, 0);
    for (int64_t i = 0; i < perm_len; i++) {
        // perm[i] is used as a write index into a vector of length perm_len.
        // transpose_shape does not validate perm when transpose_first is true
        // and otherwise only bounds it by shape.size(), so check it here to
        // avoid an out-of-bounds write.
        TORCH_CHECK(perm[i] >= 0 && perm[i] < perm_len,
                    "npu_confusion_transpose forward/backward input invalid, "
                    "perm has size ",
                    perm_len,
                    " but perm[i] is, ",
                    perm[i],
                    OPS_ERROR(ErrCode::PARAM));
        svec_backward_perm[perm[i]] = i;
    }
    at::IntArrayRef backward_perm(svec_backward_perm);

    at::Tensor y = npu_preparation::apply_tensor_without_format(shape, grad.options());
    EXEC_NPU_CMD(aclnnConfusionTranspose, grad, backward_perm, backward_shape, transpose_first, y);
    return y;
}
}