/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file permute.cpp
 * \brief Permute operation implementation
 */

#include "interface/utils/operator_tracer.h"
#include "interface/operation/operation_common.h"
#include "interface/function/function.h"
#include "interface/program/program.h"
#include "interface/configs/config_manager.h"
#include "tensor_transformation.h"
#include "permute.h"
#include <algorithm>
#include <sstream>

namespace npu::tile_fwk {

std::vector<int64_t> PermuteTileVector(const std::vector<int64_t>& values, const std::vector<int>& perm)
{
    std::vector<int64_t> result;
    result.reserve(perm.size());
    for (int axis : perm) {
        result.push_back(values[axis]);
    }
    return result;
}

[[maybe_unused]] std::vector<SymbolicScalar> PermuteTileVector(
    const std::vector<SymbolicScalar>& values, const std::vector<int>& perm)
{
    std::vector<SymbolicScalar> result;
    result.reserve(perm.size());
    for (int axis : perm) {
        result.push_back(values[axis]);
    }
    return result;
}

void PermuteOperationOperandCheck(
    const std::vector<LogicalTensorPtr>& iOperand, const std::vector<LogicalTensorPtr>& oOperand)
{
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, iOperand.size() == 1) << "Permute input operand count should be 1";
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, oOperand.size() == 1) << "Permute output operand count should be 1";
}

std::vector<int64_t> PermuteResultShape(const std::vector<int64_t>& inputShape, const std::vector<int>& perm)
{
    std::vector<int64_t> resultShape;
    resultShape.reserve(perm.size());
    for (int p : perm) {
        resultShape.push_back(inputShape[p]);
    }
    return resultShape;
}

bool IsIdentityPermutation(const std::vector<int>& perm)
{
    if (perm.size() <= 1) {
        return true;
    }
    for (size_t i = 0; i < perm.size(); ++i) {
        if (perm[i] != static_cast<int>(i)) {
            return false;
        }
    }
    return true;
}

void NormalizePermutation(std::vector<int>& perm, int shapeSize)
{
    for (int& p : perm) {
        if (p < 0) {
            p += shapeSize;
        }
    }
}

void ValidatePermutation(const std::vector<int>& perm, int shapeSize)
{
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, perm.size() == static_cast<size_t>(shapeSize))
        << "Permute dim num should match input dim num. Expected: " << shapeSize << ", Got: " << perm.size();

    std::vector<bool> used(shapeSize, false);
    for (int p : perm) {
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, p >= 0 && p < shapeSize)
            << "Permute dim is invalid: " << p << ". Should be in range [0, " << shapeSize << ")";
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, !used[p]) << "Permute dims contain duplicate values at index " << p;
        used[p] = true;
    }
}

static LogicalTensorPtr MakePermutedLogicalTensor(
    Function& function, const LogicalTensorPtr& self, const std::vector<int>& perm)
{
    std::vector<int64_t> resultShape = PermuteResultShape(self->shape, perm);
    std::vector<SymbolicScalar> resultValidShape;
    if (!self->GetDynValidShape().empty()) {
        for (int p : perm) {
            resultValidShape.push_back(self->GetDynValidShape()[p]);
        }
    } else {
        resultValidShape = SymbolicScalar::FromConcrete(resultShape);
    }
    return std::make_shared<LogicalTensor>(function, self->tensor->datatype, resultShape, resultValidShape);
}

void TiledPermuteOperation(
    Function& function, const TileShape& tileShape, size_t cur, Input& input, const LogicalTensorPtr& result,
    const std::vector<int>& perm);

Tensor TensorPermuteOperation(Function& function, LogicalTensorPtr self, const std::vector<int>& perm)
{
    auto result = MakePermutedLogicalTensor(function, self, perm);
    auto& op = function.AddOperation(Opcode::OP_PERMUTE, {self}, {result});
    op.SetAttribute(OpAttributeKey::perm, perm);
    function.UpdateTensorDataUsage(op);
    return result;
}

void TiledPermuteOperation(
    Function& function, const TileShape& tileShape, size_t cur, Input& input, const LogicalTensorPtr& result,
    const std::vector<int>& perm)
{
    int shapeSize = static_cast<int>(input.tensor.GetShape().size());
    if (cur == static_cast<size_t>(shapeSize)) {
        auto srcTile = input.tensor.GetStorage()->View(function, input.tileInfo.shape, input.tileInfo.offset);
        auto resultTileShape = PermuteTileVector(input.tileInfo.shape, perm);
        auto resultTileOffset = PermuteTileVector(input.tileInfo.offset, perm);
        auto resultTile = result->View(function, resultTileShape, resultTileOffset);

        auto& op = function.AddOperation(Opcode::OP_PERMUTE, {srcTile}, {resultTile});
        op.SetAttribute(OpAttributeKey::perm, perm);
        op.SetAttribute(OP_ATTR_PREFIX + "validShape", resultTile->GetDynValidShape());
        return;
    }

    auto& vecTile = tileShape.GetVecTile();
    for (int i = 0; i < input.tensor.GetShape()[cur]; i += vecTile[cur]) {
        input.tileInfo.shape[cur] = std::min(input.tensor.GetShape()[cur] - i, vecTile[cur]);
        input.tileInfo.offset[cur] = i;
        TiledPermuteOperation(function, tileShape, cur + 1, input, result, perm);
    }
}

void PermuteOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    PermuteOperationOperandCheck(iOperand, oOperand);

    std::vector<int> perm = op.GetVectorIntAttribute<int>(OpAttributeKey::perm);

    TileInfo tileInfo(iOperand[0]->shape.size(), iOperand[0]->offset.size());
    Input input{iOperand[0], tileInfo};
    TiledPermuteOperation(function, tileShape, 0, input, oOperand[0], perm);
}

std::vector<int64_t> PermuteElementTileVector(const std::vector<int64_t>& values, const std::vector<int>& perm)
{
    std::vector<int64_t> result;
    result.reserve(perm.size());
    for (int axis : perm) {
        result.push_back(values[axis]);
    }
    return result;
}

Tensor TensorElementPermuteOperation(Function& function, LogicalTensorPtr self, const std::vector<int>& perm)
{
    auto result = MakePermutedLogicalTensor(function, self, perm);
    auto& op = function.AddOperation(Opcode::OP_PERMUTE_ELEMENT, {self}, {result});
    op.SetAttribute(OpAttributeKey::perm, perm);
    function.UpdateTensorDataUsage(op);
    return result;
}

void TiledPermuteElementOperation(
    Function& function, const TileShape& tileShape, size_t cur, Input& input, const LogicalTensorPtr& result,
    const std::vector<int>& perm)
{
    int shapeSize = static_cast<int>(input.tensor.GetShape().size());
    if (cur == static_cast<size_t>(shapeSize)) {
        auto srcTile = input.tensor.GetStorage()->View(function, input.tileInfo.shape, input.tileInfo.offset);
        auto resultTileShape = PermuteElementTileVector(input.tileInfo.shape, perm);
        auto resultTileOffset = PermuteElementTileVector(input.tileInfo.offset, perm);
        auto resultTile = result->View(function, resultTileShape, resultTileOffset);

        auto& op = function.AddOperation(Opcode::OP_PERMUTE_ELEMENT, {srcTile}, {resultTile});
        op.SetAttribute(OpAttributeKey::perm, perm);
        op.SetAttribute(OP_ATTR_PREFIX + "validShape", resultTile->GetDynValidShape());
        return;
    }

    auto& vecTile = tileShape.GetVecTile();
    for (int i = 0; i < input.tensor.GetShape()[cur]; i += vecTile[cur]) {
        input.tileInfo.shape[cur] = std::min(input.tensor.GetShape()[cur] - i, vecTile[cur]);
        input.tileInfo.offset[cur] = i;
        TiledPermuteElementOperation(function, tileShape, cur + 1, input, result, perm);
    }
}

void PermuteElementOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    PermuteOperationOperandCheck(iOperand, oOperand);

    std::vector<int> perm = op.GetVectorIntAttribute<int>(OpAttributeKey::perm);

    TileInfo tileInfo(iOperand[0]->shape.size(), iOperand[0]->offset.size());
    Input input{iOperand[0], tileInfo};
    TiledPermuteElementOperation(function, tileShape, 0, input, oOperand[0], perm);
}

Tensor Permute(Function& function, const Tensor& self, std::vector<int> perm)
{
    DECLARE_TRACER();
    CheckTensorShapeSize(self.GetStorage(), "PERMUTE");
    std::unordered_set<DataType> supportedTypes = {
        DT_FP8E4M3, DT_FP8E5M2, DT_HF8, DT_FP8E8M0,
        DT_FP16, DT_BF16, DT_FP32,
        DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
        DT_INT32, DT_UINT32, DT_INT64, DT_UINT64,
        DT_BOOL
    };
    CheckTensorDataType(self.GetStorage(), supportedTypes, "PERMUTE");
    
    DataType dtype = self.GetDataType();
    if (dtype == DT_FP8E4M3 || dtype == DT_FP8E5M2 || dtype == DT_HF8 || dtype == DT_FP8E8M0) {
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, 
               Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510)
            << "PERMUTE: FP8 types (DT_FP8E4M3, DT_FP8E5M2, DT_HF8, DT_FP8E8M0) are only supported on DAV_3510 architecture.";
    }
    
    if (dtype == DT_INT64 || dtype == DT_UINT64) {
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, self.Format() != TileOpFormat::TILEOP_NZ)
            << "PERMUTE: INT64/UINT64 do not support NZ format.";
    }
    CheckTensorDimRange(self.GetStorage(), 1, 5, "PERMUTE");

    const int shapeSize = static_cast<int>(self.GetShape().size());

    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, perm.size() == static_cast<size_t>(shapeSize))
        << "Permute dim num should match input dim num. Expected: " << shapeSize << ", Got: " << perm.size();

    if (shapeSize == 1) {
        return self;
    }

    NormalizePermutation(perm, shapeSize);
    ValidatePermutation(perm, shapeSize);

    if (IsIdentityPermutation(perm)) {
        return self;
    }

    bool lastAxisInvolved = (perm[shapeSize - 1] != shapeSize - 1);
    if (lastAxisInvolved) {
        RETURN_CALL(ElementPermuteOperation, function, self.GetStorage(), perm);
    }

    RETURN_CALL(PermuteOperation, function, self.GetStorage(), perm);
}

Tensor Permute(const Tensor& self, std::vector<int> perm)
{
    DECLARE_TRACER();
    auto& function = *Program::GetInstance().GetCurrentFunction();
    return Permute(function, self, perm);
}

REGISTER_OPERATION_TILED_FUNC(OP_PERMUTE, Opcode::OP_PERMUTE, PermuteOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_PERMUTE_ELEMENT, Opcode::OP_PERMUTE_ELEMENT, PermuteElementOperationTileFunc);

} // namespace npu::tile_fwk