/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file rfft1d_tiling_align.cpp
 * \brief
 */
#include "platform/platform_info.h"
#include "register/op_def_registry.h"
#include "exe_graph/runtime/shape.h"
#include "rfft1_d_tiling_base.h"
#include <iostream>
#include <cmath>
#include <numeric>
#include <vector>

static const uint8_t MAX_FACTORS_LEN = 3;
static const uint32_t USER_WORKSPACE_SIZE = 2147483648;
static const uint32_t LAST_FACTOR = 64;
static const uint32_t COMPLEX_PART = 2;
static const uint32_t MATMUL_SIZE_MULTIPLIER = 24;
static const uint32_t SIZE_PER_BATCH_MULTIPLIER = 4;
static const uint32_t FIRST_FACTOR = 2;
static const uint32_t BYTES_ALIGN = 8;
static const uint32_t ROW_PAD = 16;
static const uint32_t COL_PAD = 8;
static const uint32_t TWIDDLE_MATRICES_AMOUNT = 2;
static const gert::Shape g_vec_1_shape = {1};

static const uint32_t DFT_BORDER_VALUE = 4096;

using namespace AscendC;
using namespace matmul_tiling;

namespace optiling {

class Rfft1DTiling : public Rfft1DBaseTiling {
public:
    explicit Rfft1DTiling(gert::TilingContext* context) : Rfft1DBaseTiling(context)
    {}

protected:
    ge::graphStatus DoOpTiling() override;
    ge::graphStatus PostTiling() override;
    bool IsCapable() override;
    ge::graphStatus DoLibApiTiling() override;
    uint64_t GetTilingKey() const override;
    ge::graphStatus GetWorkspaceSize() override;
    const gert::Shape& EnsureNotScalar(const gert::Shape& inShape);

private:
    uint32_t batches = 1;
    uint32_t dftRealOffsets[3] = {0, 1, 1};
    uint32_t dftImagOffsets[3] = {0, 1, 1};
    uint32_t twiddleOffsets[3] = {0, 0, 1};
    Rfft1DTilingData tiling;
    void CalcDftSizes(const uint32_t factors[], const bool isBluestein, const uint32_t len);
};

const gert::Shape& Rfft1DTiling::EnsureNotScalar(const gert::Shape& inShape)
{
    if (inShape.IsScalar()) {
        return g_vec_1_shape;
    }
    return inShape;
}

bool Rfft1DTiling::IsCapable()
{
    return true;
}

ge::graphStatus Rfft1DTiling::DoLibApiTiling()
{
    return ge::GRAPH_SUCCESS;
}

uint64_t Rfft1DTiling::GetTilingKey() const
{
    return 0;
}

ge::graphStatus Rfft1DTiling::GetWorkspaceSize()
{
    return ge::GRAPH_SUCCESS;
}

static void CalcColleyTukeyFactors(uint32_t factors[], std::vector<uint32_t> availableFactors, const uint32_t len)
{
    std::vector<uint32_t> factorsTmp;
    int curFactorIndex = availableFactors.size() - 1;
    uint32_t tmpN = len;

    if (tmpN == LAST_FACTOR * LAST_FACTOR * LAST_FACTOR * COMPLEX_PART) {
        for (size_t i = 0; i < MAX_FACTORS_LEN; i++) {
            factors[i] = LAST_FACTOR;
        }
        factors[0] = LAST_FACTOR * COMPLEX_PART;
    } else if (tmpN > DFT_BORDER_VALUE) {
        while (curFactorIndex >= 0) {
            uint32_t curFactor = availableFactors[curFactorIndex];

            while (tmpN % curFactor == 0) {
                tmpN /= curFactor;
                factorsTmp.emplace_back(curFactor);
            }
            curFactorIndex -= 1;
        }

        while (factorsTmp.size() < MAX_FACTORS_LEN) {
            factorsTmp.emplace_back(1);
        }
        for (size_t i = 0; i < MAX_FACTORS_LEN; i++) {
            factors[i] = factorsTmp[i];
        }
    }
}

static void CalcBluesteinFactors(uint32_t factors[], std::vector<uint32_t> availableFactors, const uint32_t pow2)
{
    std::vector<uint32_t> factorsTmpBluestein;
    if (pow2 == LAST_FACTOR * LAST_FACTOR * LAST_FACTOR * COMPLEX_PART) {
        for (size_t i = 0; i < MAX_FACTORS_LEN; i++) {
            factors[i] = LAST_FACTOR;
        }
        factors[0] = LAST_FACTOR * COMPLEX_PART;
    } else {
        uint32_t tmpN = pow2;
        int curFactorIndex = availableFactors.size() - 1;

        while (curFactorIndex >= 0) {
            uint32_t curFactor = availableFactors[curFactorIndex];
            while (tmpN % curFactor == 0) {
                tmpN /= curFactor;
                factorsTmpBluestein.emplace_back(curFactor);
            }
            curFactorIndex -= 1;
        }

        while (factorsTmpBluestein.size() < MAX_FACTORS_LEN) {
            factorsTmpBluestein.emplace_back(1);
        }

        for (size_t i = 0; i < MAX_FACTORS_LEN; ++i) {
            factors[i] = factorsTmpBluestein[i];
        }
    }
}

void Rfft1DTiling::CalcDftSizes(const uint32_t factors[], const bool isBluestein, const uint32_t len)
{
    uint32_t dftRealOverallSize = 0;
    uint32_t dftImagOverallSize = 0;
    uint32_t twiddleOverallSize = 0;

    size_t prevFactors = 1;

    for (size_t curIndex = 0; curIndex < MAX_FACTORS_LEN; ++curIndex) {
        size_t curFactor = factors[curIndex];
        size_t rowsNum = curFactor * (1 + static_cast<size_t>(curIndex == 0 && isBluestein));
        size_t colsNum = curFactor * (2 - static_cast<size_t>(curIndex != 0));
        size_t dftCurSize = rowsNum * colsNum;
        size_t twiddleCurSize = rowsNum * prevFactors * COMPLEX_PART;

        if (curIndex != 0) {
            dftRealOffsets[curIndex] = dftRealOverallSize;
            dftImagOffsets[curIndex] = dftImagOverallSize;
            if (curIndex != 1) {
                twiddleOffsets[curIndex] = twiddleOverallSize;
            }
        }

        dftRealOverallSize += dftCurSize;
        if (curIndex != 0) {
            dftImagOverallSize += dftCurSize;
            twiddleOverallSize += twiddleCurSize;
        }

        prevFactors *= curFactor;
    }
    uint32_t fftPadRow = len + (COL_PAD - len % COL_PAD);
    uint32_t fftPadCol = ((len % ROW_PAD) != 0) ? len + (ROW_PAD - len % ROW_PAD) : len;
    dftRealOverallSize = len <= DFT_BORDER_VALUE ? fftPadRow * fftPadCol : dftRealOverallSize;

    tiling.set_dftRealOverallSize(dftRealOverallSize);
    tiling.set_dftImagOverallSize(dftImagOverallSize);
    tiling.set_twiddleOverallSize(twiddleOverallSize);
    tiling.set_fftMatrOverallSize(
        dftRealOverallSize + dftImagOverallSize + TWIDDLE_MATRICES_AMOUNT * twiddleOverallSize);

    tiling.set_dftRealOffsets(dftRealOffsets);
    tiling.set_dftImagOffsets(dftImagOffsets);
    tiling.set_twiddleOffsets(twiddleOffsets);
}

ge::graphStatus Rfft1DTiling::DoOpTiling()
{
    auto inputX = context_->GetInputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context_, inputX);
    auto xShape = EnsureNotScalar(inputX->GetStorageShape());

    // Calculate number of batches
    for (size_t i = 0; i < xShape.GetDimNum() - 1; i++) {
        batches *= xShape.GetDim(i);
    }

    auto runtimeAttrs = context_->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context_, runtimeAttrs);
    const uint32_t len = *runtimeAttrs->GetAttrPointer<uint32_t>(0);
    const uint32_t norm = *runtimeAttrs->GetAttrPointer<uint32_t>(1);

    uint32_t factors[MAX_FACTORS_LEN] = {1, 1, 1};
    uint32_t prevRadices[MAX_FACTORS_LEN] = {1, 1, 1};
    uint32_t nextRadices[MAX_FACTORS_LEN] = {len / factors[0], 1, 1};
    uint8_t prevRadicesAlign[MAX_FACTORS_LEN] = {0, 1, 1};

    // Calculate factors
    std::vector<uint32_t> availableFactors(LAST_FACTOR - 1);
    std::iota(availableFactors.begin(), availableFactors.end(), COMPLEX_PART);

    CalcColleyTukeyFactors(factors, availableFactors, len);

    const bool isBluestein = (len % LAST_FACTOR != 0) || (factors[0] * factors[1] * factors[2] != len);
    const uint32_t pow2 = COMPLEX_PART * uint32_t(std::pow(2, std::ceil(std::log2(double(len)))));
    const uint32_t lengthPad = isBluestein ? pow2 : len;

    if (isBluestein) {
        CalcBluesteinFactors(factors, availableFactors, pow2);
    }

    for (uint8_t i = 1; i < MAX_FACTORS_LEN; ++i) {
        prevRadices[i] = prevRadices[i - 1] * factors[i - 1];
        nextRadices[i] = nextRadices[i - 1] / factors[i];
        prevRadicesAlign[i] = (COMPLEX_PART * prevRadices[i] % BYTES_ALIGN) == 0;
    }

    auto roundUpBlock = [](const uint32_t& src, const uint32_t blockLen) {
        return src != 0 ? src + (blockLen - src % blockLen) % blockLen : blockLen;
    };

    const uint32_t tailSize =
        COMPLEX_PART * (((len / COMPLEX_PART) + 1) - (factors[2] / COMPLEX_PART) * (len / factors[2]));
    const uint32_t tmpLenPerBatch = 3 * roundUpBlock(
                                            COMPLEX_PART * (isBluestein ? lengthPad : len) + factors[2] * tailSize + 1,
                                            BYTES_ALIGN * SIZE_PER_BATCH_MULTIPLIER);

    tiling.set_length(len);
    tiling.set_isBluestein(isBluestein);
    tiling.set_lengthPad(lengthPad);
    tiling.set_outLength((len / COMPLEX_PART) + 1);
    tiling.set_batchesPerCore(batches / coreNum);
    tiling.set_leftOverBatches(batches % coreNum);
    tiling.set_normal(norm);
    tiling.set_factors(factors);
    tiling.set_prevRadices(prevRadices);
    tiling.set_nextRadices(nextRadices);
    tiling.set_prevRadicesAlign(prevRadicesAlign);
    tiling.set_tailSize(tailSize);
    tiling.set_tmpLenPerBatch(tmpLenPerBatch);
    tiling.set_tmpSizePerBatch(tmpLenPerBatch * SIZE_PER_BATCH_MULTIPLIER);
    tiling.set_matmulTmpsLen(MATMUL_SIZE_MULTIPLIER * tmpLenPerBatch);
    tiling.set_matmulTmpsSize(MATMUL_SIZE_MULTIPLIER * tmpLenPerBatch * sizeof(float));

    CalcDftSizes(factors, isBluestein, len);

    return ge::GRAPH_SUCCESS;
}

ge::graphStatus Rfft1DTiling::PostTiling()
{
    context_->SetBlockDim(coreNum);

    tiling.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity());
    context_->GetRawTilingData()->SetDataSize(tiling.GetDataSize());

    auto userWorkspace = USER_WORKSPACE_SIZE;
    auto sysWorkspace = 16 * 1024 * 1024;
    size_t* currentWorkspace = context_->GetWorkspaceSizes(1);
    OP_CHECK_NULL_WITH_CONTEXT(context_, currentWorkspace);
    currentWorkspace[0] = userWorkspace + sysWorkspace;
    return ge::GRAPH_SUCCESS;
}

REGISTER_OPS_TILING_TEMPLATE(Rfft1D, Rfft1DTiling, 10000);

} // namespace optiling