/* -------------------------------------------------------------------------

 * This file is part of the MindStudio project.

 * Copyright (c) 2025 Huawei Technologies Co.,Ltd.

 *

 * MindStudio is licensed under Mulan PSL v2.

 * You can use this software according to the terms and conditions of the Mulan PSL v2.

 * You may obtain a copy of Mulan PSL v2 at:

 *

 *          http://license.coscl.org.cn/MulanPSL2

 *

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,

 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,

 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.

 * See the Mulan PSL v2 for more details.

 * ------------------------------------------------------------------------- */



#ifndef CATLASS_LIBRARY_GEMM_OPERATION_H

#define CATLASS_LIBRARY_GEMM_OPERATION_H



#include <type_traits>

#include "catlass/library/operation.h"

#include "library_utils.h"



namespace Catlass {

namespace Library {



template <typename Operator_>

class GemmOperationBase : public Operation {

public:

    using Operator = Operator_;

    using OperatorArguments = typename Operator::Arguments;

    using OperatorKernel = typename Operator::Kernel;



    using ElementA = typename OperatorKernel::ElementA;

    using ElementB = typename OperatorKernel::ElementB;

    using ElementC = typename OperatorKernel::ElementC;

    using LayoutA = typename OperatorKernel::LayoutA;

    using LayoutB = typename OperatorKernel::LayoutB;

    using LayoutC = typename OperatorKernel::LayoutC;

    using BlockMmad = typename OperatorKernel::BlockMmad;

    using ArchTag = typename OperatorKernel::ArchTag;

    using L1TileShape = typename BlockMmad::L1TileShape;

    using L0TileShape = typename BlockMmad::L0TileShape;

    using BlockScheduler = typename OperatorKernel::BlockScheduler;



    GemmOperationBase(char const *name = "")

    {

        this->description_.name = name;

        this->description_.kind = OperationKind::Gemm;



        this->description_.A = MakeTensorDescription<ElementA, LayoutA>();

        this->description_.B = MakeTensorDescription<ElementB, LayoutB>();

        this->description_.C = MakeTensorDescription<ElementC, LayoutC>();



        this->description_.tileDescription.L1TileShape =

            GemmShapeDescription(L1TileShape::M, L1TileShape::N, L1TileShape::K);

        this->description_.tileDescription.L0TileShape =

            GemmShapeDescription(L0TileShape::M, L0TileShape::N, L0TileShape::K);

    }



    virtual OperationDescription const &GetDescription() const override

    {

        return this->description_;

    }



    virtual Status CanImplement(void *argsPtr, void *configPtr) override

    {

        BuildArgs(argsPtr, configPtr);

        return op_.CanImplement(this->args_);

    }



    virtual size_t GetWorkspaceSize(void *argsPtr, void *configPtr) override

    {

        BuildArgs(argsPtr, configPtr);

        return op_.GetWorkspaceSize(this->args_);

    }



    virtual Status Initialize(

        void *argsPtr,

        void *configPtr,

        uint8_t *workspace,

        aclrtStream stream

    ) override

    {

        BuildArgs(argsPtr, configPtr);

        return op_.Initialize(this->args_, workspace, stream);

    }



    virtual Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) override

    {

        return op_.Run(stream, blockDim, fftsAddr);

    }



protected:

    virtual void BuildArgs(void *argsPtr, void *configPtr) = 0;



    GemmOperationDescription description_;

    OperatorArguments args_{};

    Operator op_;

};



/********************* basic matmul *********************/

template <typename Operator_>

class BasicMatmulGemmOperation : public GemmOperationBase<Operator_> {

public:

    BasicMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_>(name)

    {

        this->description_.gemmKind = GemmKind::BasicMatmul;

    }



private:

    virtual void BuildArgs(void *argsPtr, void *configPtr) override

    {

        BasicMatmulGemmArguments *arguments = (BasicMatmulGemmArguments *)argsPtr;

        BasicMatmulGemmConfiguration *config = (BasicMatmulGemmConfiguration *)configPtr;

        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};

        this->args_.ptrA = arguments->A;

        this->args_.ptrB = arguments->B;

        this->args_.ptrC = arguments->C;

    }

};

/********************* basic matmul end *********************/



/********************* grouped matmul *********************/

template <typename Operator_>

class GroupedMatmulGemmOperation : public GemmOperationBase<Operator_> {

public:

    GroupedMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_>(name)

    {

        this->description_.gemmKind = GemmKind::GroupedMatmul;

    }



private:

    virtual void BuildArgs(void *argsPtr, void *configPtr) override

    {

        GroupedMatmulGemmArguments *arguments = (GroupedMatmulGemmArguments *)argsPtr;

        GroupedMatmulGemmConfiguration *config = (GroupedMatmulGemmConfiguration *)configPtr;



        this->args_.problemCount = config->groupCount;

        this->args_.ptrProblemShape = arguments->problemShapeList;

        this->args_.ptrA = arguments->A;

        this->args_.ptrLayoutA = arguments->layoutAList;

        this->args_.ptrB = arguments->B;

        this->args_.ptrLayoutB = arguments->layoutBList;

        this->args_.ptrC = arguments->C;

        this->args_.ptrLayoutC = arguments->layoutCList;

    }

};

/********************* grouped matmul end *********************/



/********************* grouped matmul slice M *********************/

template <typename Operator_>

class GroupedMatmulSliceMGemmOperation : public GemmOperationBase<Operator_> {

public:

    GroupedMatmulSliceMGemmOperation(char const *name = "") : GemmOperationBase<Operator_>(name)

    {

        this->description_.gemmKind = GemmKind::GroupedMatmulSliceM;

    }



private:

    virtual void BuildArgs(void *argsPtr, void *configPtr) override

    {

        GroupedMatmulSliceMGemmArguments *arguments = (GroupedMatmulSliceMGemmArguments *)argsPtr;

        GroupedMatmulSliceMGemmConfiguration *config = (GroupedMatmulSliceMGemmConfiguration *)configPtr;



        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};

        this->args_.problemCount = config->groupCount;

        this->args_.ptrGroupList = arguments->deviceGroupList;

        this->args_.ptrA = arguments->A;

        this->args_.ptrB = arguments->B;

        this->args_.ptrC = arguments->C;

    }

};

/********************* grouped matmul slice M end *********************/



/********************* optimized matmul *********************/

template <typename Operator_>

class OptimizedMatmulGemmOperation : public GemmOperationBase<Operator_> {

    using Operator = Operator_;

    using OperatorKernel = typename Operator::Kernel;

    using ElementA = typename OperatorKernel::ElementA;

    using ElementB = typename OperatorKernel::ElementB;

    using LayoutA = typename OperatorKernel::LayoutA;

    using LayoutB = typename OperatorKernel::LayoutB;

    using L1TileShape = typename OperatorKernel::BlockMmad::L1TileShape;

public:

    OptimizedMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_>(name)

    {

        this->description_.gemmKind = GemmKind::OptimizedMatmul;

    }



private:

    virtual void BuildArgs(void *argsPtr, void *configPtr) override

    {

        BasicMatmulGemmArguments *arguments = (BasicMatmulGemmArguments *)argsPtr;

        BasicMatmulGemmConfiguration *config = (BasicMatmulGemmConfiguration *)configPtr;



        constexpr uint32_t alignByByte = 512;



        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};

        this->args_.ptrA = arguments->A;

        this->args_.ptrB = arguments->B;

        this->args_.ptrC = arguments->C;



        constexpr uint32_t alignByElement = 512 / sizeof(half);



        if constexpr (!std::is_same<typename OperatorKernel::LayoutWA, typename OperatorKernel::LayoutA>::value) {

            isThisKernelPaddingA_ = true;

        }

        if constexpr (!std::is_same<typename OperatorKernel::LayoutWB, typename OperatorKernel::LayoutB>::value) {

            isThisKernelPaddingB_ = true;

        }



        LayoutA layoutA = LayoutA::template MakeLayout<ElementA>(config->m, config->k);

        isNeedPaddingA_ = IsNeedPadding(layoutA, alignByElement);

        LayoutA layoutB = LayoutB::template MakeLayout<ElementB>(config->k, config->n);

        isNeedPaddingB_ = IsNeedPadding(layoutB, alignByElement);

    }



    virtual Status CanImplement(void *argsPtr, void *configPtr) override

    {

        // check the operator's padding configuration is correct under the given problem shape

        BuildArgs(argsPtr, configPtr);



        if ((isThisKernelPaddingA_ != isNeedPaddingA_) || (isThisKernelPaddingB_ != isNeedPaddingB_)) {

            return Catlass::Status::kInvalid;

        }



        return this->op_.CanImplement(this->args_);

    }



    inline bool IsNeedPadding(Catlass::layout::RowMajor layout, uint32_t align)

    {

        // If the stride is greater than 65536, padding is required to reduce the stride.

        if (layout.stride(0) < 65536) {

            return layout.stride(0) % align != 0;

        } else {

            return true;

        }

    }



    inline bool IsNeedPadding(Catlass::layout::ColumnMajor layout, uint32_t align)

    {

        // If the stride is greater than 65536, padding is required to reduce the stride.

        if (layout.stride(1) < 65536) {

            return layout.stride(1) % align != 0;

        } else {

            return true;

        }

    }



    inline bool IsNeedPadding(Catlass::layout::zN layout, uint32_t align) {

        return false;

    }



    inline bool IsNeedPadding(Catlass::layout::nZ layout, uint32_t align) {

        return false;

    }



    bool isThisKernelPaddingA_{false};

    bool isThisKernelPaddingB_{false};

    bool isNeedPaddingA_{false};

    bool isNeedPaddingB_{false};

};

/********************* optimized matmul end *********************/



}

}



#endif // CATLASS_LIBRARY_GEMM_OPERATION_H