* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "fftoperation/transpose.h"
#include "base_inner_api.h"
#include "utils/ops_base.h"
using namespace AsdSip;
constexpr int DIM_TWO = 2;
Transpose::Transpose(int axis0, int axis1, Mki::SVector<int64_t> dims) : axis0_{axis0}, axis1_{axis1}, dims_{dims} {}
bool Transpose::init() { return true; }
void Transpose::Run(Tensor &input, Tensor &output, void *stream, workspace::Workspace &workspace)
{
input.desc.dtype = Mki::TensorDType::TENSOR_DTYPE_FLOAT;
input.desc.dims.push_back(DIM_TWO);
int axis0 = axis0_;
if (axis0_ < 0) {
axis0 = static_cast<int>(input.desc.dims.size()) - 1 + axis0_;
}
int axis1 = axis1_;
if (axis1_ < 0) {
axis1 = static_cast<int>(input.desc.dims.size()) - 1 + axis1_;
}
uint8_t *deviceBuffer = (uint8_t *)workspace.allocate(ASYNC_WORKSPACE_SIZE);
Mki::SVector<int64_t> shapes = input.desc.dims;
Mki::SVector<int64_t> strides(shapes.size(), 1);
for (int64_t i = static_cast<int64_t>(shapes.size()) - 2; i >= 0; i--) {
strides[i] = shapes[i + 1] * strides[i + 1];
}
int64_t shapeTmp = shapes[axis0];
int64_t stridesTmp = strides[axis0];
shapes[axis0] = shapes[axis1];
strides[axis0] = strides[axis1];
shapes[axis1] = shapeTmp;
strides[axis1] = stridesTmp;
AsStrided(input, output, shapes, strides, stream, 0, deviceBuffer);
workspace.recycleLast();
output.desc.dtype = Mki::TensorDType::TENSOR_DTYPE_COMPLEX64;
output.desc.dims.erase(output.desc.dims.end() - 1);
}
size_t Transpose::EstimateWorkspaceSize()
{
return getAlignedSize(ASYNC_WORKSPACE_SIZE);
}
void Transpose::Run(void *input, void *output, void *stream, workspace::Workspace &workspace)
{
Mki::SVector<int64_t> tmpDims = dims_;
tmpDims.push_back(DIM_TWO);
int axis0 = axis0_;
if (axis0_ < 0) {
axis0 = static_cast<int>(tmpDims.size()) - 1 + axis0_;
}
int axis1 = axis1_;
if (axis1_ < 0) {
axis1 = static_cast<int>(tmpDims.size()) - 1 + axis1_;
}
uint8_t *deviceBuffer = (uint8_t *)workspace.allocate(ASYNC_WORKSPACE_SIZE);
Tensor tensorIn;
Tensor tensorOut;
tensorIn.desc = {Mki::TensorDType::TENSOR_DTYPE_FLOAT, Mki::TensorFormat::TENSOR_FORMAT_ND, tmpDims, {}, 0};
tensorIn.data = input;
tensorOut.data = output;
Mki::SVector<int64_t> strides(tmpDims.size(), 1);
for (int64_t i = static_cast<int64_t>(tmpDims.size()) - 2; i >= 0; i--) {
strides[i] = tmpDims[i + 1] * strides[i + 1];
}
int64_t shapeTmp = tmpDims[axis0];
int64_t stridesTmp = strides[axis0];
tmpDims[axis0] = tmpDims[axis1];
strides[axis0] = strides[axis1];
tmpDims[axis1] = shapeTmp;
strides[axis1] = stridesTmp;
AsStrided(tensorIn, tensorOut, tmpDims, strides, stream, 0, deviceBuffer);
workspace.recycleLast();
}