* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#ifndef CATLASS_LIBRARY_GEMM_OPERATION_H
#define CATLASS_LIBRARY_GEMM_OPERATION_H
#include <type_traits>
#include "catlass/library/operation.h"
#include "library_utils.h"
namespace Catlass {
namespace Library {
template <typename Operator_>
class GemmOperationBase : public Operation {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using OperatorKernel = typename Operator::Kernel;
using ElementA = typename OperatorKernel::ElementA;
using ElementB = typename OperatorKernel::ElementB;
using ElementC = typename OperatorKernel::ElementC;
using LayoutA = typename OperatorKernel::LayoutA;
using LayoutB = typename OperatorKernel::LayoutB;
using LayoutC = typename OperatorKernel::LayoutC;
using BlockMmad = typename OperatorKernel::BlockMmad;
using ArchTag = typename OperatorKernel::ArchTag;
using L1TileShape = typename BlockMmad::L1TileShape;
using L0TileShape = typename BlockMmad::L0TileShape;
using BlockScheduler = typename OperatorKernel::BlockScheduler;
GemmOperationBase(char const *name = "")
{
this->description_.name = name;
this->description_.kind = OperationKind::Gemm;
this->description_.A = MakeTensorDescription<ElementA, LayoutA>();
this->description_.B = MakeTensorDescription<ElementB, LayoutB>();
this->description_.C = MakeTensorDescription<ElementC, LayoutC>();
this->description_.tileDescription.L1TileShape =
GemmShapeDescription(L1TileShape::M, L1TileShape::N, L1TileShape::K);
this->description_.tileDescription.L0TileShape =
GemmShapeDescription(L0TileShape::M, L0TileShape::N, L0TileShape::K);
}
virtual OperationDescription const &GetDescription() const override
{
return this->description_;
}
virtual Status CanImplement(void *argsPtr, void *configPtr) override
{
BuildArgs(argsPtr, configPtr);
return op_.CanImplement(this->args_);
}
virtual size_t GetWorkspaceSize(void *argsPtr, void *configPtr) override
{
BuildArgs(argsPtr, configPtr);
return op_.GetWorkspaceSize(this->args_);
}
virtual Status Initialize(
void *argsPtr,
void *configPtr,
uint8_t *workspace,
aclrtStream stream
) override
{
BuildArgs(argsPtr, configPtr);
return op_.Initialize(this->args_, workspace, stream);
}
virtual Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) override
{
return op_.Run(stream, blockDim, fftsAddr);
}
protected:
virtual void BuildArgs(void *argsPtr, void *configPtr) = 0;
GemmOperationDescription description_;
OperatorArguments args_{};
Operator op_;
};
/// Library operation for kernels of kind GemmKind::BasicMatmul.
///
/// BuildArgs() interprets the argument pointer as BasicMatmulGemmArguments and the
/// configuration pointer as BasicMatmulGemmConfiguration.
template <typename Operator_>
class BasicMatmulGemmOperation : public GemmOperationBase<Operator_> {
    using Base = GemmOperationBase<Operator_>;

public:
    /// @param name Name stored in the operation description.
    BasicMatmulGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::BasicMatmul;
    }

private:
    /// Copies the problem shape (m, n, k) and the A/B/C pointers into the operator arguments.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *tensors = static_cast<BasicMatmulGemmArguments *>(argsPtr);
        auto *shape = static_cast<BasicMatmulGemmConfiguration *>(configPtr);

        auto &opArgs = this->args_;
        opArgs.problemShape = GemmCoord{shape->m, shape->n, shape->k};
        opArgs.ptrA = tensors->A;
        opArgs.ptrB = tensors->B;
        opArgs.ptrC = tensors->C;
    }
};
/// Library operation for kernels of kind GemmKind::GroupedMatmul.
///
/// BuildArgs() interprets the argument pointer as GroupedMatmulGemmArguments and the
/// configuration pointer as GroupedMatmulGemmConfiguration.
template <typename Operator_>
class GroupedMatmulGemmOperation : public GemmOperationBase<Operator_> {
    using Base = GemmOperationBase<Operator_>;

public:
    /// @param name Name stored in the operation description.
    GroupedMatmulGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::GroupedMatmul;
    }

private:
    /// Copies the group count, the problem-shape list and, for each of A/B/C,
    /// the data pointer and the layout list into the operator arguments.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *tensors = static_cast<GroupedMatmulGemmArguments *>(argsPtr);
        auto *groupCfg = static_cast<GroupedMatmulGemmConfiguration *>(configPtr);

        auto &opArgs = this->args_;
        opArgs.problemCount = groupCfg->groupCount;
        opArgs.ptrProblemShape = tensors->problemShapeList;

        // Operand A
        opArgs.ptrA = tensors->A;
        opArgs.ptrLayoutA = tensors->layoutAList;
        // Operand B
        opArgs.ptrB = tensors->B;
        opArgs.ptrLayoutB = tensors->layoutBList;
        // Operand C
        opArgs.ptrC = tensors->C;
        opArgs.ptrLayoutC = tensors->layoutCList;
    }
};
/// Library operation for kernels of kind GemmKind::GroupedMatmulSliceM.
///
/// BuildArgs() interprets the argument pointer as GroupedMatmulSliceMGemmArguments and
/// the configuration pointer as GroupedMatmulSliceMGemmConfiguration.
template <typename Operator_>
class GroupedMatmulSliceMGemmOperation : public GemmOperationBase<Operator_> {
    using Base = GemmOperationBase<Operator_>;

public:
    /// @param name Name stored in the operation description.
    GroupedMatmulSliceMGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::GroupedMatmulSliceM;
    }

private:
    /// Copies the problem shape (m, n, k), the group count, the group-list pointer
    /// and the A/B/C pointers into the operator arguments.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *tensors = static_cast<GroupedMatmulSliceMGemmArguments *>(argsPtr);
        auto *groupCfg = static_cast<GroupedMatmulSliceMGemmConfiguration *>(configPtr);

        auto &opArgs = this->args_;
        opArgs.problemShape = GemmCoord{groupCfg->m, groupCfg->n, groupCfg->k};
        opArgs.problemCount = groupCfg->groupCount;
        opArgs.ptrGroupList = tensors->deviceGroupList;

        opArgs.ptrA = tensors->A;
        opArgs.ptrB = tensors->B;
        opArgs.ptrC = tensors->C;
    }
};
/// Library operation for kernels of kind GemmKind::OptimizedMatmul.
///
/// Besides building the operator arguments, BuildArgs() records two facts per operand:
///  - whether this kernel instantiation uses a working layout different from the input
///    layout (OperatorKernel::LayoutWA / LayoutWB vs. LayoutA / LayoutB), and
///  - whether the configured problem shape needs padding according to IsNeedPadding().
/// CanImplement() returns Status::kInvalid when these disagree for A or for B.
template <typename Operator_>
class OptimizedMatmulGemmOperation : public GemmOperationBase<Operator_> {
    using Operator = Operator_;
    using OperatorKernel = typename Operator::Kernel;
    using ElementA = typename OperatorKernel::ElementA;
    using ElementB = typename OperatorKernel::ElementB;
    using LayoutA = typename OperatorKernel::LayoutA;
    using LayoutB = typename OperatorKernel::LayoutB;
    using L1TileShape = typename OperatorKernel::BlockMmad::L1TileShape;

public:
    /// @param name Name stored in the operation description.
    OptimizedMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_>(name)
    {
        this->description_.gemmKind = GemmKind::OptimizedMatmul;
    }

private:
    /// Strides at or above this value are always reported as needing padding.
    /// NOTE(review): presumably a hardware stride limit — confirm against the kernel.
    static constexpr int kStrideLimit = 65536;

    /// Fills the operator arguments from BasicMatmulGemmArguments / BasicMatmulGemmConfiguration
    /// and refreshes the padding flags consulted by CanImplement().
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *arguments = static_cast<BasicMatmulGemmArguments *>(argsPtr);
        auto *config = static_cast<BasicMatmulGemmConfiguration *>(configPtr);

        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};
        this->args_.ptrA = arguments->A;
        this->args_.ptrB = arguments->B;
        this->args_.ptrC = arguments->C;

        constexpr uint32_t alignByByte = 512;
        // NOTE(review): the element size is taken from `half`, independent of
        // ElementA/ElementB — confirm this is intended for non-half kernels.
        constexpr uint32_t alignByElement = alignByByte / sizeof(half);

        // The kernel pads an operand exactly when its working layout differs from the input layout.
        isThisKernelPaddingA_ =
            !std::is_same<typename OperatorKernel::LayoutWA, typename OperatorKernel::LayoutA>::value;
        isThisKernelPaddingB_ =
            !std::is_same<typename OperatorKernel::LayoutWB, typename OperatorKernel::LayoutB>::value;

        LayoutA layoutA = LayoutA::template MakeLayout<ElementA>(config->m, config->k);
        isNeedPaddingA_ = IsNeedPadding(layoutA, alignByElement);

        // B must be described with its own layout type so the matching IsNeedPadding overload is chosen.
        LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(config->k, config->n);
        isNeedPaddingB_ = IsNeedPadding(layoutB, alignByElement);
    }

    /// Rebuilds the arguments, rejects the kernel if its padding behaviour does not match
    /// what the problem shape requires, and otherwise defers to the operator.
    /// @return Status::kInvalid on a padding mismatch, else the operator's own verdict.
    Status CanImplement(void *argsPtr, void *configPtr) override
    {
        BuildArgs(argsPtr, configPtr);

        bool const paddingMatches =
            (isThisKernelPaddingA_ == isNeedPaddingA_) && (isThisKernelPaddingB_ == isNeedPaddingB_);
        if (!paddingMatches) {
            return Catlass::Status::kInvalid;
        }
        return this->op_.CanImplement(this->args_);
    }

    /// Row-major: padding is needed when stride(0) reaches kStrideLimit or is not a multiple of `align`.
    bool IsNeedPadding(Catlass::layout::RowMajor layout, uint32_t align)
    {
        return !(layout.stride(0) < kStrideLimit) || (layout.stride(0) % align != 0);
    }

    /// Column-major: padding is needed when stride(1) reaches kStrideLimit or is not a multiple of `align`.
    bool IsNeedPadding(Catlass::layout::ColumnMajor layout, uint32_t align)
    {
        return !(layout.stride(1) < kStrideLimit) || (layout.stride(1) % align != 0);
    }

    /// zN layouts are never reported as needing padding.
    bool IsNeedPadding(Catlass::layout::zN, uint32_t)
    {
        return false;
    }

    /// nZ layouts are never reported as needing padding.
    bool IsNeedPadding(Catlass::layout::nZ, uint32_t)
    {
        return false;
    }

    bool isThisKernelPaddingA_{false};  // kernel uses a working layout different from LayoutA
    bool isThisKernelPaddingB_{false};  // kernel uses a working layout different from LayoutB
    bool isNeedPaddingA_{false};        // last IsNeedPadding() result for A
    bool isNeedPaddingB_{false};        // last IsNeedPadding() result for B
};
}
}
#endif