/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file post_utest.cpp
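 * \brief Unit-test fixture and graph-building helper for the NSA attention-post operator (AttentionPostStandalone).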
 */

#include "gtest/gtest.h"
#include "tilefwk/tilefwk_op.h"

#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/tensor/raw_tensor.h"
#include "interface/interpreter/raw_tensor_data.h"
#include "operator/models/nsa/attention_post.h"
#include "interface/configs/config_manager.h"
#include "interface/tensor/float.h"

using namespace npu::tile_fwk;

/**
 * \brief GoogleTest fixture shared by the attention-post unit tests.
 *
 * The only active hook is SetUp(), which resets the object returned by
 * Program::GetInstance() before each test. All other hooks are deliberately empty.
 */
class AttentionPostUTest : public testing::Test {
public:
    // Per-test setup: reset the Program instance before the test body runs.
    void SetUp() override
    {
        Program::GetInstance().Reset();
    }

    // Per-test teardown: nothing to release.
    void TearDown() override {}

    // Suite-level hooks: no shared state to prepare or tear down.
    static void SetUpTestCase() {}

    static void TearDownTestCase() {}
};

/**
 * \brief Problem dimensions for one attention-post test case.
 *
 * Every field defaults to zero so that a default-constructed instance never carries
 * indeterminate values. The struct remains an aggregate, so positional initialization
 * in declaration order (b, n, s, h, kvLoraRank, vHeadDim) keeps working.
 */
struct TestPostParams {
    int b = 0;          // leading dim of x {b, s, n, kvLoraRank} and of the output {b, s, h}
    int n = 0;          // third dim of x; leading dim of wUv {n, kvLoraRank, vHeadDim}
    int s = 0;          // second dim of x and of the output
    int h = 0;          // last dim of wo {n * vHeadDim, h} and of the output
    int kvLoraRank = 0; // last dim of x; middle dim of wUv
    int vHeadDim = 0;   // last dim of wUv; n * vHeadDim is the leading dim of wo
};

template <
    typename T = npu::tile_fwk::float16, bool nz = true, typename wUvDtype = int8_t, bool isSmoothWuv = false,
    typename wODtype = int8_t, bool isSmoothWo = false>
void TestAttentionPostUt(const TestPostParams& params, const PostTileConfig& tileConfig)
{
    int b = params.b;
    int n = params.n;
    int s = params.s;
    int h = params.h;
    int kvLoraRank = params.kvLoraRank;
    int vHeadDim = params.vHeadDim;

    DataType dType = (std::is_same<T, npu::tile_fwk::float16>::value) ? DT_FP16 : DT_BF16;
    bool isQuantWUv = std::is_same<wUvDtype, int8_t>::value;
    bool isQuantWo = std::is_same<wODtype, int8_t>::value;

    std::vector<int64_t> xShape = {b, s, n, kvLoraRank};
    std::vector<int64_t> wUvShape = {n, kvLoraRank, vHeadDim};
    std::vector<int64_t> wUvScaleShape = {n, 1, vHeadDim};
    std::vector<int64_t> smoothWUvShape = {1, kvLoraRank};
    std::vector<int64_t> woShape = {n * vHeadDim, h};
    std::vector<int64_t> woScaleShape = {1, h};
    std::vector<int64_t> smoothWoShape = {1, n * vHeadDim};
    std::vector<int64_t> outShape = {b, s, h};

    TileOpFormat weightFormat = nz ? TileOpFormat::TILEOP_NZ : TileOpFormat::TILEOP_ND;
    Tensor x(dType, xShape, "x");
    Tensor wUv(isQuantWUv ? DT_INT8 : dType, wUvShape, "wUv");
    Tensor wUvScale;
    Tensor smoothWUv;
    Tensor wo(isQuantWo ? DT_INT8 : dType, woShape, "wo", weightFormat);
    Tensor woScale;
    Tensor smoothWo;
    Tensor postOut(dType, outShape, "postOut");

    if (isQuantWUv) {
        Tensor scale(DT_FP32, wUvScaleShape, "wUvScale");
        wUvScale = scale;
        if (isSmoothWuv) {
            Tensor smooth(DT_FP32, smoothWUvShape, "smoothWUv");
            smoothWUv = smooth;
        }
    }
    if (isQuantWo) {
        Tensor scale(DT_FP32, woScaleShape, "woScale");
        woScale = scale;
        if (isSmoothWo) {
            Tensor smooth(DT_FP32, smoothWoShape, "smoothWo");
            smoothWo = smooth;
        }
    }

    PostTensors postTensors{wUv, wo, wUvScale, smoothWUv, woScale, smoothWo};
    AttentionPostStandalone(x, postTensors, tileConfig, postOut);
}