* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file transpose_tiling_base.cpp
* \brief transpose
*/
#include "transpose_tiling_arch35.h"
#include "transpose_tiling_base.h"
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <iostream>
#include <math.h>
#include <climits>
using namespace std;
using namespace gert;
namespace optiling {
static void DuplicateArray(const int64_t* src, vector<int64_t>& dst, int64_t len)
{
for (int64_t i = 0; i < len; i++) {
dst[i] = src[i];
}
}
static void DuplicateArray(const vector<int64_t>& src, vector<int64_t>& dst, int64_t len)
{
for (int64_t i = 0; i < len; i++) {
dst[i] = src[i];
}
}
static void DuplicateArray(const vector<int64_t>& src, int64_t* dst, int64_t len)
{
for (int64_t i = 0; i < len; i++) {
dst[i] = src[i];
}
}
static bool IsAllOne(const ShapeInfo& shapeInfo)
{
return std::all_of(shapeInfo.inShape.begin(), shapeInfo.inShape.begin() + shapeInfo.dim, [](const int64_t& item) {
return item == 1;
});
}
static int DecreaseCompare(const void* a, const void* b)
{
return (*(int64_t*)b - *(int64_t*)a);
}
static void CalcOutShape(ShapeInfo& shapeInfo)
{
const vector<int64_t>& inShape = shapeInfo.reducedInShape;
const vector<int64_t>& perm = shapeInfo.reducedPerm;
vector<int64_t>& outShape = shapeInfo.reducedOutShape;
for (int64_t i = 0; i < shapeInfo.dim; i++) {
outShape[i] = inShape[perm[i]];
}
}
void RemoveAxisV2(ShapeInfo& shapeInfo)
{
int64_t dim = shapeInfo.dim;
if (dim == 1) {
DuplicateArray(shapeInfo.inShape, shapeInfo.reducedInShape, dim);
DuplicateArray(shapeInfo.perm, shapeInfo.reducedPerm, dim);
DuplicateArray(shapeInfo.outShape, shapeInfo.reducedOutShape, dim);
return;
}
if (IsAllOne(shapeInfo)) {
shapeInfo.reducedInShape[0] = 1;
shapeInfo.reducedPerm[0] = 0;
shapeInfo.reducedOutShape[0] = 1;
shapeInfo.dim = 1;
return;
}
vector<int64_t>& shape = shapeInfo.reducedInShape;
int64_t delPerm[TRANSPOSE_MAX_AXIS_NUM];
int64_t newPerm[TRANSPOSE_MAX_AXIS_NUM];
int64_t shapeSize = 0;
int64_t delPermSize = 0;
int64_t newPermSize = 0;
for (int64_t i = 0; i < dim; i++) {
if (shapeInfo.inShape[i] != 1) {
shape[shapeSize++] = shapeInfo.inShape[i];
} else {
for (int64_t j = 0; j < dim; j++) {
if (shapeInfo.perm[j] == i) {
delPerm[delPermSize++] = shapeInfo.perm[j];
}
}
}
}
qsort(reinterpret_cast<void*>(&delPerm[0]), delPermSize, sizeof(int64_t), DecreaseCompare);
for (int64_t i = 0; i < dim; i++) {
bool delFlag = false;
for (int64_t j = 0; j < delPermSize; j++) {
if (shapeInfo.perm[i] == delPerm[j]) {
delFlag = true;
}
}
if (!delFlag) {
newPerm[newPermSize++] = shapeInfo.perm[i];
}
}
for (int64_t i = 0; i < delPermSize; i++) {
for (int64_t j = 0; j < newPermSize; j++) {
if (newPerm[j] > delPerm[i]) {
newPerm[j] = newPerm[j] - 1;
}
}
}
DuplicateArray(newPerm, shapeInfo.reducedPerm, newPermSize);
shapeInfo.dim = newPermSize;
CalcOutShape(shapeInfo);
}
void MergeAxisV2(ShapeInfo& shapeInfo)
{
int64_t dim = shapeInfo.dim;
if (dim == 1) {
return;
}
int64_t perm[TRANSPOSE_MAX_AXIS_NUM];
int64_t shape[TRANSPOSE_MAX_AXIS_NUM];
int64_t newPerm[TRANSPOSE_MAX_AXIS_NUM];
int64_t newShape[TRANSPOSE_MAX_AXIS_NUM];
int64_t newDimPosition[TRANSPOSE_MAX_AXIS_NUM];
int64_t mergedShape[TRANSPOSE_MAX_AXIS_NUM] = {0};
DuplicateArray(shapeInfo.reducedPerm, perm, dim);
DuplicateArray(shapeInfo.reducedInShape, shape, dim);
for (int i = 0; i < TRANSPOSE_MAX_AXIS_NUM; i++) {
newDimPosition[i] = -1;
}
int64_t curHead = shapeInfo.reducedPerm[0];
newDimPosition[curHead] = 0;
mergedShape[0] = shape[curHead];
int dimIndex = 0;
for (int permIndex = 1; permIndex < dim; ++permIndex) {
if (curHead + 1 == perm[permIndex]) {
curHead = perm[permIndex];
mergedShape[dimIndex] *= shape[curHead];
} else {
curHead = perm[permIndex];
dimIndex++;
newDimPosition[curHead] = dimIndex;
mergedShape[dimIndex] = shape[curHead];
}
}
shapeInfo.dim = dimIndex + 1;
dimIndex = 0;
for (int i = 0; i < dim; i++) {
if (newDimPosition[i] >= 0) {
newDimPosition[dimIndex++] = newDimPosition[i];
}
}
dimIndex = 0;
for (int64_t i = 0; i < dim; ++i) {
if (newDimPosition[i] >= 0) {
int64_t newPermIndex = newDimPosition[i];
for (int64_t j = 0; j < dim; j++) {
if (newDimPosition[j] == i) {
newPerm[dimIndex] = j;
break;
}
}
newShape[dimIndex] = mergedShape[newPermIndex];
dimIndex++;
}
}
DuplicateArray(newShape, shapeInfo.reducedInShape, dimIndex);
DuplicateArray(newPerm, shapeInfo.reducedPerm, dimIndex);
shapeInfo.lastAxisLen = shapeInfo.reducedInShape[shapeInfo.dim - 1];
shapeInfo.lastAxisBurstLen = static_cast<int64_t>(ceil(shapeInfo.lastAxisLen * 1.0 / shapeInfo.elePerBlock));
CalcOutShape(shapeInfo);
}
}