/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

// # include "catlass/detail/alignment.hpp"
# include "fai_tiling.h"
# include "fai_tiling_data.h"

using namespace std;
namespace FAInferTiling {
    /**
     * Copies the scalar problem parameters of faInfo into faTilingData and derives
     * two values from them: the number of KV blocks needed by the longest KV sequence
     * and the attention scale 1 / sqrt(embeddingSize).
     *
     * @param faInfo       Problem description (shapes, block size, mask type).
     * @param faTilingData Output tiling struct; the basic shape fields, maxKvSeqlen,
     *                     maxNumBlocksPerBatch, maskType and scaleValue are written.
     * @param maxKvSeqlen  Largest KV sequence length across the batch.
     */
    void FillBasicTilingData(const FAInfo &faInfo, FATilingData &faTilingData, int64_t maxKvSeqlen)
    {
        // Derived values first: ceil(maxKvSeqlen / blockSize) blocks per batch entry.
        const uint32_t blocksForLongestKv =
            static_cast<uint32_t>((maxKvSeqlen + faInfo.blockSize - 1) / faInfo.blockSize);
        // Scale is evaluated in double precision and only then narrowed to float.
        const double headDim = static_cast<double>(faInfo.embeddingSize);
        const float attnScale = static_cast<float>(1.0 / std::sqrt(headDim));

        // Plain shape fields, narrowed to the 32-bit members of the tiling struct.
        faTilingData.batch = static_cast<uint32_t>(faInfo.batch);
        faTilingData.numHeads = static_cast<uint32_t>(faInfo.numHeads);
        faTilingData.kvHeads = static_cast<uint32_t>(faInfo.kvHeads);
        faTilingData.embeddingSize = static_cast<uint32_t>(faInfo.embeddingSize);
        faTilingData.numBlocks = static_cast<uint32_t>(faInfo.numBlocks);
        faTilingData.blockSize = static_cast<uint32_t>(faInfo.blockSize);
        faTilingData.maxKvSeqlen = static_cast<uint32_t>(maxKvSeqlen);

        faTilingData.maxNumBlocksPerBatch = blocksForLongestKv;
        faTilingData.maskType = static_cast<uint32_t>(faInfo.maskType);
        faTilingData.scaleValue = attnScale;
    }

    /**
     * Chooses how many query heads of one KV group are packed into a single task tile.
     *
     * The tile targets roughly 128 query rows (heads * qSeqlen) per task and is rounded
     * down to an even number, then clamped to [1, groupSize] (at least 1 even when
     * groupSize is 0).
     *
     * @param qSeqlen   Query sequence length of the batch. Values below 1 are treated as 1
     *                  so the division can never divide by zero (e.g. an empty query batch).
     * @param groupSize Number of query heads sharing one KV head.
     * @return Number of query heads per task tile, always >= 1.
     */
    uint32_t GetQNBlockTile(int64_t qSeqlen, uint32_t groupSize)
    {
        constexpr int64_t qRowNumCeil = 128;
        // Guard the divisor: qSeqlen == 0 previously caused an integer division by zero.
        const int64_t rowsPerHead = std::max(qSeqlen, static_cast<int64_t>(1));
        // A trick is used to ensure the qN tile is an even number,
        // thus most tasks have balanced workload between two vec cores,
        // and each vec core possesses no more than 64 rows when the all-rounded row num is no larger than 128,
        // aiding the coding of the rescale block.
        uint32_t qNBlockTile = static_cast<uint32_t>((qRowNumCeil / rowsPerHead) / 2 * 2);
        qNBlockTile = std::min(qNBlockTile, groupSize);
        // Long sequences (qSeqlen > 64) round down to 0 above; always keep at least one head.
        qNBlockTile = std::max(qNBlockTile, static_cast<uint32_t>(1));
        return qNBlockTile;
    }

    /**
     * Returns the query-sequence tile length (rows of one batch's query processed per task).
     *
     * @param kvSeqlen KV sequence length of the batch. Currently unused: the tile is fixed at 128.
     *                 NOTE(review): presumably reserved for a future length-dependent tile — confirm.
     * @return The query-sequence tile length, currently always 128.
     */
    uint32_t GetQSBlockTile(int64_t kvSeqlen)
    {
        // Explicitly mark the parameter as unused to silence -Wunused-parameter.
        (void)kvSeqlen;
        constexpr uint32_t qSBlockTile = 128;
        return qSBlockTile;
    }

    /**
     * Counts the tasks produced by every batch and records the totals in faTilingData.
     *
     * For each batch, the query heads of each KV group are split into tiles of
     * GetQNBlockTile(...) heads, and the query sequence into tiles of GetQSBlockTile(...)
     * rows. The batch contributes (head tiles per group * kvHeads) * (sequence tiles) tasks.
     *
     * @param faInfo       Problem description; qSeqlenList and kvSeqlenList must each hold
     *                     faInfo.batch entries.
     * @param faTilingData Output tiling struct. firstBatchTaskNum receives the task count of
     *                     batch 0 (0 when there are no batches); totalTaskNum the sum over all batches.
     */
    void FillSplitCoreTilingData(const FAInfo &faInfo, FATilingData &faTilingData)
    {
        // Guard the divisor: with no KV heads there is no group, and every batch yields zero tasks.
        const uint32_t groupSize = (faInfo.kvHeads > 0) ?
            static_cast<uint32_t>(faInfo.numHeads / faInfo.kvHeads) : 0U;
        uint32_t totalTaskNum = 0;
        // Assign explicitly so that batch == 0 does not leave this field unwritten.
        faTilingData.firstBatchTaskNum = 0;
        for (int32_t batchIdx = 0; batchIdx < faInfo.batch; batchIdx++) {
            const int64_t qSeqlen = faInfo.qSeqlenList[batchIdx];
            const int64_t kvSeqlen = faInfo.kvSeqlenList[batchIdx];
            // Head axis: ceil(groupSize / tile) tiles per KV group, for every KV head.
            const uint32_t curQNBlockTile = GetQNBlockTile(qSeqlen, groupSize);
            const uint32_t qNBlockNumPerGroup = (groupSize + curQNBlockTile - 1) / curQNBlockTile;
            const uint32_t curQNBlockNum = qNBlockNumPerGroup * faInfo.kvHeads;
            // Sequence axis: ceil(qSeqlen / tile) tiles.
            const uint32_t curQSBlockTile = GetQSBlockTile(kvSeqlen);
            const uint32_t curQSBlockNum = (qSeqlen + curQSBlockTile - 1) / curQSBlockTile;
            const uint32_t curTaskNum = curQNBlockNum * curQSBlockNum;
            if (batchIdx == 0) {
                faTilingData.firstBatchTaskNum = curTaskNum;
            }
            totalTaskNum += curTaskNum;
        }
        faTilingData.totalTaskNum = totalTaskNum;
    }

    /**
     * Computes the sizes of the four workspace regions (mm1 output, online-softmax output,
     * mm2 output, update buffer) and their total, and stores them in faTilingData.
     *
     * Every region scales linearly with blockDim.
     *
     * @param blockDim     Launch block count used to scale every region
     *                     (presumably the number of cores — confirm against the launcher).
     * @param faTilingData Output tiling struct; mm1OutSize, smOnlineOutSize, mm2OutSize,
     *                     UpdateSize and workSpaceSize are written.
     */
    void FillWorkSpaceTilingData(uint32_t blockDim, FATilingData &faTilingData)
    {
        // Widen before multiplying. If the constants are 32-bit, the product would
        // otherwise be computed (and could wrap) in 32-bit arithmetic before reaching uint64_t.
        const uint64_t blockDim64 = static_cast<uint64_t>(blockDim);
        const uint64_t mm1OutSize = blockDim64 * WORKSPACE_BLOCK_SIZE_DB * NUM4 * NUM3;
        const uint64_t smOnlineOutSize = blockDim64 * WORKSPACE_BLOCK_SIZE_DB * NUM2 * NUM3;
        const uint64_t mm2OutSize = blockDim64 * WORKSPACE_BLOCK_SIZE_DB * NUM4 * NUM3;
        const uint64_t updateSize = blockDim64 * WORKSPACE_BLOCK_SIZE_DB * NUM4 * NUM3;
        faTilingData.mm1OutSize = mm1OutSize;
        faTilingData.smOnlineOutSize = smOnlineOutSize;
        faTilingData.mm2OutSize = mm2OutSize;
        faTilingData.UpdateSize = updateSize;
        faTilingData.workSpaceSize = mm1OutSize + smOnlineOutSize + mm2OutSize + updateSize;
    }

    int32_t GetFATilingParam(const FAInfo &faInfo, uint32_t blockDim, FATilingData &faTilingData)
    {
        if (faInfo.qSeqlenList == nullptr || faInfo.kvSeqlenList == nullptr) {
            cerr << "[ERROR] pointer tilingData or seq is nullptr." << endl;
            return -1;
        }
        if (faInfo.blockSize != NUM128) {
            cerr << "[ERROR] blockSize != 128 is not supported." << endl;
            return -1;
        }
        int64_t maxKvSeqlen = 0;
        for (int32_t batchIdx = 0; batchIdx < faInfo.batch; batchIdx++) {
            int64_t qSeqlen = *(faInfo.qSeqlenList + batchIdx);
            int64_t kvSeqlen = *(faInfo.kvSeqlenList + batchIdx);
            maxKvSeqlen = std::max(maxKvSeqlen, kvSeqlen);
        }
        FillBasicTilingData(faInfo, faTilingData, maxKvSeqlen);
        FillSplitCoreTilingData(faInfo, faTilingData);
        FillWorkSpaceTilingData(blockDim, faTilingData);
        return 0;
    }
} // namespace FAInferTiling