/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
 * the software repository for the full text of the License.
 */

#ifndef OPTEST_MLA_H
#define OPTEST_MLA_H

#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

#include <torch/torch.h>
#include <tiling/platform/platform_ascendc.h>

#include "catlass_kernel_prebuilt.h"
#include "common/run_npu_func.h"
#include "torch_utils.h"

namespace CatlassKernelWrapper {

/// @brief Host-side preparation for the prebuilt CATLASS MLA kernel.
///
/// GetKernelInfo validates the torch inputs and packs device addresses plus shape metadata into a
/// CatlassKernel::MlaParams; AllocOutput allocates the output tensor and records its address in the
/// same params object.
struct MlaHost {
    using OutputType = at::Tensor;

    /// @brief Validate the MLA inputs and fill @p params for a kernel launch.
    ///
    /// Layouts implied by the indexing below. The kernel only receives raw base addresses and the
    /// scalar sizes stored in @p params (no strides), so every tensor passed by address must be
    /// contiguous:
    ///  - query_nope:     [qNtokens, num_heads, embeddingSize]
    ///  - query_rope:     [qNtokens, num_heads, qRopeHeadDim]
    ///  - key_cache:      [numBlocks, blockSize, num_key_value_heads, ...]
    ///  - key_rope_cache: [numBlocks, blockSize, num_key_value_heads, kvRopeHeadDim]
    ///  - actual_seq_lengths / actual_seq_lengths_kv: `batch` int32/int64 entries each; they are
    ///    copied to host and narrowed to int32.
    ///  - block_table: forwarded by address only; its dtype and shape are not validated here
    ///    (NOTE(review): confirm the dtype/layout the kernel expects).
    ///
    /// @param num_heads           Expected size of dim 1 of both query tensors.
    /// @param num_key_value_heads Expected size of dim 2 of both key caches.
    /// @param sparse_mode         Must be 0 or 1; forwarded unchanged as params.maskType.
    /// @param params              Output; populated in place.
    /// @throws c10::Error on dtype, rank, shape, contiguity or sequence-length violations.
    /// @throws std::runtime_error if sparse_mode is not 0 or 1.
    static void GetKernelInfo(
        const at::Tensor& query_nope,
        const at::Tensor& query_rope,
        const at::Tensor& key_cache,
        const at::Tensor& key_rope_cache,
        const at::Tensor& actual_seq_lengths,
        const at::Tensor& actual_seq_lengths_kv,
        const at::Tensor& block_table,
        int64_t num_heads,
        int64_t num_key_value_heads,
        int64_t sparse_mode,
        CatlassKernel::MlaParams& params)
    {
        // ---- dtypes: a single params.dataType describes all four activation/cache tensors ----
        aclDataType queryDtype = TorchDtypeToAclDtype(query_nope.scalar_type());
        aclDataType keyDtype = TorchDtypeToAclDtype(key_cache.scalar_type());
        TORCH_CHECK(
            queryDtype == keyDtype,
            "query and key cache must have the same dtype");
        TORCH_CHECK(
            queryDtype == ACL_FLOAT16 || queryDtype == ACL_BF16,
            "mla supports float16 and bfloat16 only");
        TORCH_CHECK(
            query_rope.scalar_type() == query_nope.scalar_type(),
            "query_rope must have the same dtype as query_nope");
        TORCH_CHECK(
            key_rope_cache.scalar_type() == key_cache.scalar_type(),
            "key_rope_cache must have the same dtype as key_cache");

        TORCH_CHECK(
            actual_seq_lengths.scalar_type() == at::kInt || actual_seq_lengths.scalar_type() == at::kLong,
            "actual_seq_lengths must be int32 or int64");
        TORCH_CHECK(
            actual_seq_lengths_kv.scalar_type() == at::kInt || actual_seq_lengths_kv.scalar_type() == at::kLong,
            "actual_seq_lengths_kv must be int32 or int64");

        // ---- ranks: checked first so the size(i) calls below cannot fail with an opaque index error ----
        TORCH_CHECK(query_nope.dim() == 3, "query_nope must be 3-D, got ", query_nope.dim(), "-D");
        TORCH_CHECK(query_rope.dim() == 3, "query_rope must be 3-D, got ", query_rope.dim(), "-D");
        TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4-D, got ", key_cache.dim(), "-D");
        TORCH_CHECK(key_rope_cache.dim() == 4, "key_rope_cache must be 4-D, got ", key_rope_cache.dim(), "-D");

        // ---- shapes ----
        int64_t qNtokens = query_nope.size(0);
        int64_t embeddingSize = query_nope.size(2);
        int64_t qRopeHeadDim = query_rope.size(2);
        int64_t kvRopeHeadDim = key_rope_cache.size(3);
        int64_t batch = actual_seq_lengths.numel();
        TORCH_CHECK(
            batch == actual_seq_lengths_kv.numel(),
            "actual_seq_lengths and actual_seq_lengths_kv must have the same size");
        TORCH_CHECK(batch > 0, "actual_seq_lengths must not be empty");
        TORCH_CHECK(
            query_nope.size(1) == num_heads,
            "query_nope num_heads mismatch");
        TORCH_CHECK(
            query_rope.size(0) == qNtokens && query_rope.size(1) == num_heads,
            "query_rope shape mismatch");
        TORCH_CHECK(
            key_cache.size(2) == num_key_value_heads,
            "key_cache kv_heads mismatch");
        // numBlocks/blockSize/kvHeads are taken from key_cache only, so the rope cache must agree.
        TORCH_CHECK(
            key_rope_cache.size(0) == key_cache.size(0) && key_rope_cache.size(1) == key_cache.size(1) &&
                key_rope_cache.size(2) == key_cache.size(2),
            "key_rope_cache must share numBlocks, blockSize and kv_heads with key_cache");

        // ---- memory layout: only base addresses reach the kernel, so strides must be trivial ----
        const auto requireContiguous = [](const at::Tensor& t, const char* name) {
            TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
        };
        requireContiguous(query_nope, "query_nope");
        requireContiguous(query_rope, "query_rope");
        requireContiguous(key_cache, "key_cache");
        requireContiguous(key_rope_cache, "key_rope_cache");
        requireContiguous(block_table, "block_table");

        if (sparse_mode != 0 && sparse_mode != 1) {
            throw std::runtime_error("sparse_mode of mla should be 0 or 1");
        }
        // sparse_mode is forwarded unchanged as the kernel's maskType.
        const uint32_t maskType = static_cast<uint32_t>(sparse_mode);

        // data_ptr() honours storage_offset(), unlike storage().data(), so sliced views are addressed correctly.
        const auto deviceAddr = [](const at::Tensor& t) { return static_cast<uint8_t*>(t.data_ptr()); };
        params.inputAddr.resize(5);
        params.inputAddr[0] = deviceAddr(query_nope);
        params.inputAddr[1] = deviceAddr(query_rope);
        params.inputAddr[2] = deviceAddr(key_cache);
        params.inputAddr[3] = deviceAddr(key_rope_cache);
        params.inputAddr[4] = deviceAddr(block_table);

        // One host copy per length tensor; the maxima are computed from it to avoid extra device syncs.
        const at::Tensor qSeqCpu = SeqLengthsToHost(actual_seq_lengths, "actual_seq_lengths");
        const at::Tensor kvSeqCpu = SeqLengthsToHost(actual_seq_lengths_kv, "actual_seq_lengths_kv");
        const int64_t* qLens = qSeqCpu.data_ptr<int64_t>();
        const int64_t* kvLens = kvSeqCpu.data_ptr<int64_t>();
        const int64_t qSeqlen = *std::max_element(qLens, qLens + batch);    // batch > 0 checked above
        const int64_t kvSeqlen = *std::max_element(kvLens, kvLens + batch);

        params.qSeqHost.resize(static_cast<size_t>(batch));
        params.kvSeqHost.resize(static_cast<size_t>(batch));
        for (int64_t i = 0; i < batch; ++i) {
            params.qSeqHost[static_cast<size_t>(i)] = static_cast<int32_t>(qLens[i]);
            params.kvSeqHost[static_cast<size_t>(i)] = static_cast<int32_t>(kvLens[i]);
        }

        params.qNtokens = static_cast<uint32_t>(qNtokens);
        params.batch = static_cast<uint32_t>(batch);
        params.qSeqlen = static_cast<uint32_t>(qSeqlen);
        params.kvSeqlen = static_cast<uint32_t>(kvSeqlen);
        params.numHeads = static_cast<uint32_t>(num_heads);
        params.kvHeads = static_cast<uint32_t>(num_key_value_heads);
        params.embeddingSize = static_cast<uint32_t>(embeddingSize);
        params.qRopeHeadDim = static_cast<uint32_t>(qRopeHeadDim);
        params.kvRopeHeadDim = static_cast<uint32_t>(kvRopeHeadDim);
        params.numBlocks = static_cast<uint32_t>(key_cache.size(0));
        params.blockSize = static_cast<uint32_t>(key_cache.size(1));
        params.maskType = maskType;
        params.dataType = queryDtype;
    }

    /// @brief Allocate the [qNtokens, numHeads, embeddingSize] output tensor with dtype
    ///        params.dataType and record its device address as params.outputAddr[0].
    /// @return The newly allocated output tensor; the caller keeps it alive for the launch.
    static OutputType AllocOutput(CatlassKernel::MlaParams& params)
    {
        OutputType output = GetOutputTensor(
            {params.qNtokens, params.numHeads, params.embeddingSize},
            AclDtypeToTorchDtype(params.dataType));
        params.outputAddr.resize(1);
        params.outputAddr[0] = static_cast<uint8_t*>(output.data_ptr());
        return output;
    }

private:
    // Copies a sequence-length tensor to a contiguous host int64 tensor and checks that every entry
    // lies in [0, INT32_MAX], because GetKernelInfo narrows each entry to int32 for the kernel.
    static at::Tensor SeqLengthsToHost(const at::Tensor& lengths, const char* name)
    {
        at::Tensor host = lengths.contiguous().cpu().to(at::kLong).contiguous();
        const int64_t* values = host.data_ptr<int64_t>();
        const int64_t count = host.numel();
        for (int64_t i = 0; i < count; ++i) {
            TORCH_CHECK(
                values[i] >= 0 && values[i] <= std::numeric_limits<int32_t>::max(),
                name, "[", i, "] = ", values[i], " is outside the supported range [0, 2147483647]");
        }
        return host;
    }
};

/// @brief Executable MLA operator built on MlaHost.
///
/// Run() prepares kernel parameters with MlaHost::GetKernelInfo (same arguments, same exceptions),
/// allocates the output with MlaHost::AllocOutput, and then launches CatlassKernel::Mla through
/// RUN_NPU_FUNC. The launch uses the current NPU stream and the AIC core count reported by the
/// platform manager.
struct MlaOp : MlaHost {
    /// @return The output tensor that the launched kernel writes into.
    static OutputType Run(
        const at::Tensor& query_nope,
        const at::Tensor& query_rope,
        const at::Tensor& key_cache,
        const at::Tensor& key_rope_cache,
        const at::Tensor& actual_seq_lengths,
        const at::Tensor& actual_seq_lengths_kv,
        const at::Tensor& block_table,
        int64_t num_heads,
        int64_t num_key_value_heads,
        int64_t sparse_mode)
    {
        // Step 1: host-side validation and parameter packing (throws on invalid input).
        CatlassKernel::MlaParams params;
        GetKernelInfo(query_nope, query_rope, key_cache, key_rope_cache, actual_seq_lengths,
                      actual_seq_lengths_kv, block_table, num_heads, num_key_value_heads,
                      sparse_mode, params);

        // Step 2: output buffer; its address is stored into params by AllocOutput.
        OutputType result = AllocOutput(params);

        // Step 3: launch on the current stream across all AIC cores.
        aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
        auto platform = platform_ascendc::PlatformAscendCManager::GetInstance();
        uint32_t aicCoreNum = platform->GetCoreNumAic();
        RUN_NPU_FUNC(CatlassKernel::Mla, aicCoreNum, stream, params);

        return result;
    }
};

} // namespace CatlassKernelWrapper

#endif