/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file test_batchmatmul_operation.cpp
 * \brief
 */

#include "test_operation.h"

using namespace tile_fwk::test_operation;
namespace {
struct BatchMatmulOpFuncArgs : public OpFuncArgs {
    BatchMatmulOpFuncArgs(
        const std::vector<int64_t>& viewShape, const std::vector<std::vector<int64_t>>& tileShape,
        const MatmulTestCaseParam& param)
        : viewShape_(viewShape), tileShape_(tileShape), param_(param)
    {}

    std::vector<int64_t> viewShape_;
    std::vector<std::vector<int64_t>> tileShape_;
    MatmulTestCaseParam param_;
};

struct BatchMatmulOpMetaData {
    explicit BatchMatmulOpMetaData(const OpFunc& opFunc, const nlohmann::json& test_data)
        : opFunc_(opFunc), test_data_(test_data)
    {}

    OpFunc opFunc_;
    nlohmann::json test_data_;
};

struct BatchMatmulTileParam {
    bool transA;
    bool transB;
    int mDim;
    int kDim;
    int nDim;
    int mView;
    int nView;

    std::vector<int64_t> aViewShape;
    std::vector<int64_t> bViewShape;
    std::vector<SymbolicScalar> aValidShape;
    std::vector<SymbolicScalar> bValidShape;
    std::vector<SymbolicScalar> aOffset;
    std::vector<SymbolicScalar> bOffset;
    std::vector<SymbolicScalar> cOffset;

    std::vector<int64_t> vecTileShape;
};

static void GetBatchMatmulTileParam(
    const std::vector<Tensor>& inputs, const OpFuncArgs* opArgs, BatchMatmulTileParam& tileParam)
{
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    tileParam.transA = args->param_.transA;
    tileParam.transB = args->param_.transB;
    size_t inputDim = inputs[0].GetShape().size();
    const size_t DIM_OFFSET_2 = 2;
    tileParam.mDim =
        tileParam.transA ? inputs[0].GetShape()[inputDim - 1] : inputs[0].GetShape()[inputDim - DIM_OFFSET_2];
    tileParam.kDim =
        tileParam.transA ? inputs[0].GetShape()[inputDim - DIM_OFFSET_2] : inputs[0].GetShape()[inputDim - 1];
    tileParam.nDim =
        tileParam.transB ? inputs[1].GetShape()[inputDim - DIM_OFFSET_2] : inputs[1].GetShape()[inputDim - 1];
    tileParam.mView = args->viewShape_[inputDim - DIM_OFFSET_2];
    tileParam.nView = args->viewShape_[inputDim - 1UL];

    tileParam.aViewShape = {inputs[0].GetShape().begin(), inputs[0].GetShape().end() - DIM_OFFSET_2};
    tileParam.bViewShape = {inputs[1].GetShape().begin(), inputs[1].GetShape().end() - DIM_OFFSET_2};

    tileParam.aValidShape = {inputs[0].GetShape().begin(), inputs[0].GetShape().end() - DIM_OFFSET_2};
    tileParam.bValidShape = {inputs[1].GetShape().begin(), inputs[1].GetShape().end() - DIM_OFFSET_2};

    tileParam.aOffset = std::vector<SymbolicScalar>(inputDim - DIM_OFFSET_2, 0);
    tileParam.bOffset = std::vector<SymbolicScalar>(inputDim - DIM_OFFSET_2, 0);
    tileParam.cOffset = std::vector<SymbolicScalar>(inputDim - DIM_OFFSET_2, 0);

    tileParam.vecTileShape = std::vector<int64_t>(inputDim - DIM_OFFSET_2, 1);
    tileParam.vecTileShape.insert(tileParam.vecTileShape.end(), {args->tileShape_[0][1], args->tileShape_[1][1]});
}

static Tensor CallBatchMatmulOp(const Tensor& tensorA, const Tensor& tensorB, const MatmulTestCaseParam& param)
{
    if (!param.transA && !param.transB && !param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, false, false, false);
    } else if (!param.transA && !param.transB && param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, false, false, true);
    } else if (!param.transA && param.transB && !param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, false, true, false);
    } else if (!param.transA && param.transB && param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, false, true, true);
    } else if (param.transA && !param.transB && !param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, true, false, false);
    } else if (param.transA && !param.transB && param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, true, false, true);
    } else if (param.transA && param.transB && !param.isCMatrixNz) {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, true, true, false);
    } else {
        return Matrix::BatchMatmul(param.outDtype, tensorA, tensorB, true, true, true);
    }
}

static void BatchMatmulOperationExeFuncNoSplit(
    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    BatchMatmulTileParam tileParam;
    GetBatchMatmulTileParam(inputs, opArgs, tileParam);
    size_t inputDim = inputs[0].GetShape().size();
    tileParam.aOffset = std::vector<SymbolicScalar>(inputDim, 0);
    tileParam.bOffset = std::vector<SymbolicScalar>(inputDim, 0);
    tileParam.aValidShape = {inputs[0].GetShape().begin(), inputs[0].GetShape().end()};
    tileParam.bValidShape = {inputs[1].GetShape().begin(), inputs[1].GetShape().end()};

    FUNCTION("testNoSplit", {inputs[0], inputs[1]}, {outputs[0]})
    {
        LOOP("mLoop", FunctionType::DYNAMIC_LOOP, mIdx, LoopRange(1))
        {
            tileParam.aOffset[inputDim - 1] = mIdx;
            Tensor tensorA = View(inputs[0], inputs[0].GetShape(), tileParam.aValidShape, tileParam.aOffset);
            Tensor tensorB = View(inputs[1], inputs[1].GetShape(), tileParam.bValidShape, tileParam.bOffset);
            TileShape::Current().SetCubeTile(
                {args->tileShape_[0][0], args->tileShape_[0][1]}, {args->tileShape_[1][0], args->tileShape_[1][1]},
                {args->tileShape_[2][0], args->tileShape_[2][1]});
            if (args->param_.isAMatrixNz || args->param_.isBMatrixNz || args->param_.isCMatrixNz) {
                TileShape::Current().SetMatrixSize({tileParam.mDim, tileParam.kDim, tileParam.nDim});
            }
            outputs[0] = CallBatchMatmulOp(tensorA, tensorB, args->param_);
        }
    }
}

static void BatchMatmulOperationExeFuncSplitM(
    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    BatchMatmulTileParam tileParam;
    GetBatchMatmulTileParam(inputs, opArgs, tileParam);
    tileParam.bValidShape = {inputs[1].GetShape().begin(), inputs[1].GetShape().end()};

    FUNCTION("testMSplit", {inputs[0], inputs[1]}, {outputs[0]})
    {
        LOOP(
            "mLoop", FunctionType::DYNAMIC_LOOP, mIdx,
            LoopRange(0, CeilDivSymbolicScalar(tileParam.mDim, tileParam.mView), 1))
        {
            if (tileParam.transA) {
                tileParam.aOffset.insert(tileParam.aOffset.end(), {0, mIdx * tileParam.mView});
                tileParam.aViewShape.insert(tileParam.aViewShape.end(), {tileParam.kDim, tileParam.mView});
                tileParam.aValidShape.insert(
                    tileParam.aValidShape.end(),
                    {tileParam.kDim, std::min(tileParam.mDim - tileParam.mView * mIdx, tileParam.mView)});
            } else {
                tileParam.aOffset.insert(tileParam.aOffset.end(), {mIdx * tileParam.mView, 0});
                tileParam.aViewShape.insert(tileParam.aViewShape.end(), {tileParam.mView, tileParam.kDim});
                tileParam.aValidShape.insert(
                    tileParam.aValidShape.end(),
                    {std::min(tileParam.mDim - tileParam.mView * mIdx, tileParam.mView), tileParam.kDim});
            }
            Tensor tensorA = View(inputs[0], tileParam.aViewShape, tileParam.aValidShape, tileParam.aOffset);

            tileParam.bOffset.insert(tileParam.bOffset.end(), {0, 0});
            Tensor tensorB = View(inputs[1], inputs[1].GetShape(), tileParam.bValidShape, tileParam.bOffset);

            TileShape::Current().SetVecTile(tileParam.vecTileShape);
            TileShape::Current().SetCubeTile(
                {args->tileShape_[0][0], args->tileShape_[0][1]}, {args->tileShape_[1][0], args->tileShape_[1][1]},
                {args->tileShape_[2][0], args->tileShape_[2][1]});
            if (args->param_.isAMatrixNz || args->param_.isBMatrixNz || args->param_.isCMatrixNz) {
                TileShape::Current().SetMatrixSize({tileParam.mDim, tileParam.kDim, tileParam.nDim});
            }
            Tensor tensorC = CallBatchMatmulOp(tensorA, tensorB, args->param_);
            tileParam.cOffset.insert(tileParam.cOffset.end(), {mIdx * tileParam.mView, 0});
            Assemble(tensorC, tileParam.cOffset, outputs[0]);
        }
    }
}

static void BatchMatmulOperationExeFuncSplitN(
    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    BatchMatmulTileParam tileParam;
    GetBatchMatmulTileParam(inputs, opArgs, tileParam);
    tileParam.aValidShape = {inputs[0].GetShape().begin(), inputs[0].GetShape().end()};

    FUNCTION("testNSplit", {inputs[0], inputs[1]}, {outputs[0]})
    {
        LOOP(
            "nLoop", FunctionType::DYNAMIC_LOOP, nIdx,
            LoopRange(0, CeilDivSymbolicScalar(tileParam.nDim, tileParam.nView), 1))
        {
            tileParam.aOffset.insert(tileParam.aOffset.end(), {0, 0});
            Tensor tensorA = View(inputs[0], inputs[0].GetShape(), tileParam.aValidShape, tileParam.aOffset);
            if (tileParam.transB) {
                tileParam.bOffset.insert(tileParam.bOffset.end(), {nIdx * tileParam.nView, 0});
                tileParam.bViewShape.insert(tileParam.bViewShape.end(), {tileParam.nView, tileParam.kDim});
                tileParam.bValidShape.insert(
                    tileParam.bValidShape.end(),
                    {std::min(tileParam.nDim - nIdx * tileParam.nView, tileParam.nView), tileParam.kDim});
            } else {
                tileParam.bOffset.insert(tileParam.bOffset.end(), {0, nIdx * tileParam.nView});
                tileParam.bViewShape.insert(tileParam.bViewShape.end(), {tileParam.kDim, tileParam.nView});
                tileParam.bValidShape.insert(
                    tileParam.bValidShape.end(),
                    {tileParam.kDim, std::min(tileParam.nDim - nIdx * tileParam.nView, tileParam.nView)});
            }
            Tensor tensorB = View(inputs[1], tileParam.bViewShape, tileParam.bValidShape, tileParam.bOffset);
            TileShape::Current().SetCubeTile(
                {args->tileShape_[0][0], args->tileShape_[0][1]}, {args->tileShape_[1][0], args->tileShape_[1][1]},
                {args->tileShape_[2][0], args->tileShape_[2][1]});
            TileShape::Current().SetVecTile(tileParam.vecTileShape);
            if (args->param_.isAMatrixNz || args->param_.isBMatrixNz || args->param_.isCMatrixNz) {
                TileShape::Current().SetMatrixSize({tileParam.mDim, tileParam.kDim, tileParam.nDim});
            }
            Tensor tensorC = CallBatchMatmulOp(tensorA, tensorB, args->param_);
            tileParam.cOffset.insert(tileParam.cOffset.end(), {0, nIdx * tileParam.nView});
            Assemble(tensorC, tileParam.cOffset, outputs[0]);
        }
    }
}

static void BatchMatmulOperationExeFuncSplitMN(
    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    BatchMatmulTileParam tileParam;
    GetBatchMatmulTileParam(inputs, opArgs, tileParam);

    FUNCTION("testMNSplit", {inputs[0], inputs[1]}, {outputs[0]})
    {
        LOOP(
            "mLoop", FunctionType::DYNAMIC_LOOP, mIdx,
            LoopRange(0, CeilDivSymbolicScalar(tileParam.mDim, tileParam.mView), 1))
        {
            LOOP(
                "nLoop", FunctionType::DYNAMIC_LOOP, nIdx,
                LoopRange(0, CeilDivSymbolicScalar(tileParam.nDim, tileParam.nView), 1))
            {
                if (tileParam.transA) {
                    tileParam.aViewShape.insert(tileParam.aViewShape.end(), {tileParam.kDim, tileParam.mView});
                    tileParam.aValidShape.insert(
                        tileParam.aValidShape.end(),
                        {tileParam.kDim, std::min(tileParam.mDim - tileParam.mView * mIdx, tileParam.mView)});
                    tileParam.aOffset.insert(tileParam.aOffset.end(), {0, mIdx * tileParam.mView});
                } else {
                    tileParam.aViewShape.insert(tileParam.aViewShape.end(), {tileParam.mView, tileParam.kDim});
                    tileParam.aValidShape.insert(
                        tileParam.aValidShape.end(),
                        {std::min(tileParam.mDim - tileParam.mView * mIdx, tileParam.mView), tileParam.kDim});
                    tileParam.aOffset.insert(tileParam.aOffset.end(), {mIdx * tileParam.mView, 0});
                }
                Tensor tensorA = View(inputs[0], tileParam.aViewShape, tileParam.aValidShape, tileParam.aOffset);

                if (tileParam.transB) {
                    tileParam.bViewShape.insert(tileParam.bViewShape.end(), {tileParam.nView, tileParam.kDim});
                    tileParam.bValidShape.insert(
                        tileParam.bValidShape.end(),
                        {std::min(tileParam.nDim - nIdx * tileParam.nView, tileParam.nView), tileParam.kDim});
                    tileParam.bOffset.insert(tileParam.bOffset.end(), {nIdx * tileParam.nView, 0});
                } else {
                    tileParam.bViewShape.insert(tileParam.bViewShape.end(), {tileParam.kDim, tileParam.nView});
                    tileParam.bValidShape.insert(
                        tileParam.bValidShape.end(),
                        {tileParam.kDim, std::min(tileParam.nDim - nIdx * tileParam.nView, tileParam.nView)});
                    tileParam.bOffset.insert(tileParam.bOffset.end(), {0, nIdx * tileParam.nView});
                }
                Tensor tensorB = View(inputs[1], tileParam.bViewShape, tileParam.bValidShape, tileParam.bOffset);

                TileShape::Current().SetVecTile(tileParam.vecTileShape);
                TileShape::Current().SetCubeTile(
                    {args->tileShape_[0][0], args->tileShape_[0][1]}, {args->tileShape_[1][0], args->tileShape_[1][1]},
                    {args->tileShape_[2][0], args->tileShape_[2][1]});
                if (args->param_.isAMatrixNz || args->param_.isBMatrixNz || args->param_.isCMatrixNz) {
                    TileShape::Current().SetMatrixSize({tileParam.mDim, tileParam.kDim, tileParam.nDim});
                }
                Tensor tensorC = CallBatchMatmulOp(tensorA, tensorB, args->param_);
                tileParam.cOffset.insert(tileParam.cOffset.end(), {mIdx * tileParam.mView, nIdx * tileParam.nView});
                Assemble(tensorC, tileParam.cOffset, outputs[0]);
            }
        }
    }
}

static void BatchMatmulOperationExeFunc(
    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, const OpFuncArgs* opArgs)
{
    ASSERT(inputs[0].GetShape().size() == inputs[1].GetShape().size());
    auto args = static_cast<const BatchMatmulOpFuncArgs*>(opArgs);
    size_t inputDim = inputs[0].GetShape().size();
    ASSERT(args->viewShape_.size() == inputDim);
    const size_t DIM_OFFSET_2 = 2;
    const int mView = args->viewShape_[inputDim - DIM_OFFSET_2];
    const int nView = args->viewShape_[inputDim - 1UL];

    if (mView > 0 && nView > 0) {
        return BatchMatmulOperationExeFuncSplitMN(inputs, outputs, opArgs);
    } else if (mView > 0) {
        return BatchMatmulOperationExeFuncSplitM(inputs, outputs, opArgs);
    } else if (nView > 0) {
        return BatchMatmulOperationExeFuncSplitN(inputs, outputs, opArgs);
    } else {
        return BatchMatmulOperationExeFuncNoSplit(inputs, outputs, opArgs);
    }
}

class BatchMatmulOperationTest : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac_param<BatchMatmulOpMetaData> {};

INSTANTIATE_TEST_SUITE_P(
    TestBatchMatmul, BatchMatmulOperationTest,
    ::testing::ValuesIn(GetOpMetaData<BatchMatmulOpMetaData>({BatchMatmulOperationExeFunc}, "BatchMatmul")));

TEST_P(BatchMatmulOperationTest, TestBatchMatmul)
{
    TestCaseDesc testCase;
    auto test_data = GetParam().test_data_;
    testCase.inputTensors = GetMatmulTensors(test_data, "input_tensors");
    testCase.outputTensors = GetMatmulTensors(test_data, "output_tensors");
    auto args =
        BatchMatmulOpFuncArgs(GetViewShape(test_data), GetMatmulTileShape(test_data), GetMatmulParam(test_data));
    testCase.args = &args;
    testCase.opFunc = GetParam().opFunc_;
    testCase.inputPaths = {
        GetGoldenDir() + "/" + testCase.inputTensors[0].GetStorage()->Symbol() + ".bin",
        GetGoldenDir() + "/" + testCase.inputTensors[1].GetStorage()->Symbol() + ".bin"};
    testCase.goldenPaths = {GetGoldenDir() + "/" + testCase.outputTensors[0].GetStorage()->Symbol() + ".bin"};
    TestExecutor::runTest(testCase);
}

class BatchMatmulVerifyOperationTest
    : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac_param<BatchMatmulOpMetaData> {};

INSTANTIATE_TEST_SUITE_P(
    TestBatchMatmulVerify, BatchMatmulVerifyOperationTest,
    ::testing::ValuesIn(GetOpMetaData<BatchMatmulOpMetaData>({BatchMatmulOperationExeFunc}, "BatchMatmulVerify")));

TEST_P(BatchMatmulVerifyOperationTest, TestBatchMatmulVerify)
{
    TestCaseDesc testCase;
    auto test_data = GetParam().test_data_;
    testCase.outputTensors = GetMatmulTensors(test_data, "output_tensors");
    testCase.inputTensors = GetMatmulTensors(test_data, "input_tensors");
    auto args =
        BatchMatmulOpFuncArgs(GetViewShape(test_data), GetMatmulTileShape(test_data), GetMatmulParam(test_data));
    testCase.args = &args;
    testCase.opFunc = GetParam().opFunc_;
    testCase.inputPaths = {
        GetGoldenDir() + "/" + testCase.inputTensors[0].GetStorage()->Symbol() + ".bin",
        GetGoldenDir() + "/" + testCase.inputTensors[1].GetStorage()->Symbol() + ".bin"};
    testCase.goldenPaths = {GetGoldenDir() + "/" + testCase.outputTensors[0].GetStorage()->Symbol() + ".bin"};
    TestFlowVerifier::runTest(testCase);
}
} // namespace