* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CATLASS_LIBRARY_GEMM_OPERATION_H
#define CATLASS_LIBRARY_GEMM_OPERATION_H
#include <type_traits>
#include "catlass/library/operation.h"
#include "tla/layout.hpp"
#include "library_utils.h"
namespace Catlass {
namespace Library {
template <typename Operator_, typename Description_>
class GemmOperationBase : public Operation {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using OperatorKernel = typename Operator::Kernel;
using ElementA = typename OperatorKernel::ElementA;
using ElementB = typename OperatorKernel::ElementB;
using ElementC = typename OperatorKernel::ElementC;
using LayoutA = typename OperatorKernel::LayoutA;
using LayoutB = typename OperatorKernel::LayoutB;
using LayoutC = typename OperatorKernel::LayoutC;
using BlockMmad = typename OperatorKernel::BlockMmad;
using ArchTag = typename OperatorKernel::ArchTag;
using L1TileShape = typename BlockMmad::L1TileShape;
using L0TileShape = typename BlockMmad::L0TileShape;
using BlockScheduler = typename OperatorKernel::BlockScheduler;
GemmOperationBase(char const *name = "")
{
this->description_.name = name;
this->description_.kind = OperationKind::Gemm;
if constexpr (std::is_same<LayoutA, Catlass::layout::RowMajor>::value ||
std::is_same<LayoutA, Catlass::layout::ColumnMajor>::value ||
std::is_same<LayoutA, Catlass::layout::nZ>::value ||
std::is_same<LayoutA, Catlass::layout::zN>::value ||
std::is_same<LayoutA, Catlass::layout::zZ>::value ||
std::is_same<LayoutA, Catlass::layout::PaddingRowMajor>::value ||
std::is_same<LayoutA, Catlass::layout::PaddingColumnMajor>::value ||
std::is_same<LayoutA, Catlass::layout::nN>::value ||
std::is_same<LayoutA, Catlass::layout::NDC1HWC0>::value ||
std::is_same<LayoutA, Catlass::layout::KDC1KHKWN1N0C0>::value ||
std::is_same<LayoutA, Catlass::layout::VectorLayout>::value ||
std::is_same<LayoutA, Catlass::layout::PaddingRowMajor>::value) {
this->description_.A = MakeTensorDescription<ElementA, LayoutA>();
this->description_.B = MakeTensorDescription<ElementB, LayoutB>();
this->description_.C = MakeTensorDescription<ElementC, LayoutC>();
}
if constexpr (std::is_same<L1TileShape, Catlass::GemmCoord>::value) {
this->description_.tileDescription.L1TileShape =
GemmShapeDescription(L1TileShape::M, L1TileShape::N, L1TileShape::K);
this->description_.tileDescription.L0TileShape =
GemmShapeDescription(L0TileShape::M, L0TileShape::N, L0TileShape::K);
}
}
virtual OperationDescription const &GetDescription() const override
{
return this->description_;
}
virtual Status CanImplement(void *argsPtr, void *configPtr) override
{
BuildArgs(argsPtr, configPtr);
return op_.CanImplement(this->args_);
}
virtual size_t GetWorkspaceSize(void *argsPtr, void *configPtr) override
{
BuildArgs(argsPtr, configPtr);
return op_.GetWorkspaceSize(this->args_);
}
virtual Status Initialize(
void *argsPtr,
void *configPtr,
uint8_t *workspace,
aclrtStream stream
) override
{
BuildArgs(argsPtr, configPtr);
return op_.Initialize(this->args_, workspace, stream);
}
virtual Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) override
{
return op_.Run(stream, blockDim, fftsAddr);
}
protected:
virtual void BuildArgs(void *argsPtr, void *configPtr) = 0;
Description_ description_;
OperatorArguments args_{};
Operator op_;
};
/// @brief GEMM operation whose description is tagged GemmKind::BasicMatmul.
///
/// BuildArgs() reads the problem shape (m, n, k) from a
/// BasicMatmulGemmConfiguration and the A/B/C pointers from a
/// BasicMatmulGemmArguments.
template <typename Operator_>
class BasicMatmulGemmOperation : public GemmOperationBase<Operator_, GemmOperationDescription> {
    using Base = GemmOperationBase<Operator_, GemmOperationDescription>;

public:
    BasicMatmulGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::BasicMatmul;
    }

private:
    /// @param argsPtr   Treated as a BasicMatmulGemmArguments pointer.
    /// @param configPtr Treated as a BasicMatmulGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *tensors = static_cast<BasicMatmulGemmArguments *>(argsPtr);
        auto *shape = static_cast<BasicMatmulGemmConfiguration *>(configPtr);

        auto &target = this->args_;
        target.problemShape = GemmCoord{shape->m, shape->n, shape->k};
        target.ptrA = tensors->A;
        target.ptrB = tensors->B;
        target.ptrC = tensors->C;
    }
};
/// @brief GEMM operation whose description is tagged GemmKind::GroupedMatmul.
///
/// BuildArgs() takes the group count from a GroupedMatmulGemmConfiguration and
/// the problem-shape list, tensor pointers and per-tensor layout lists from a
/// GroupedMatmulGemmArguments.
template <typename Operator_>
class GroupedMatmulGemmOperation : public GemmOperationBase<Operator_, GemmOperationDescription> {
    using Base = GemmOperationBase<Operator_, GemmOperationDescription>;

public:
    GroupedMatmulGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::GroupedMatmul;
    }

private:
    /// @param argsPtr   Treated as a GroupedMatmulGemmArguments pointer.
    /// @param configPtr Treated as a GroupedMatmulGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *groupArgs = static_cast<GroupedMatmulGemmArguments *>(argsPtr);
        auto *groupCfg = static_cast<GroupedMatmulGemmConfiguration *>(configPtr);

        auto &target = this->args_;
        target.problemCount = groupCfg->groupCount;
        target.ptrProblemShape = groupArgs->problemShapeList;

        // Each tensor pointer is paired with its layout list.
        target.ptrA = groupArgs->A;
        target.ptrLayoutA = groupArgs->layoutAList;
        target.ptrB = groupArgs->B;
        target.ptrLayoutB = groupArgs->layoutBList;
        target.ptrC = groupArgs->C;
        target.ptrLayoutC = groupArgs->layoutCList;
    }
};
/// @brief GEMM operation whose description is tagged GemmKind::GroupedMatmulSliceM.
///
/// BuildArgs() takes the problem shape (m, n, k) and group count from a
/// GroupedMatmulSliceMGemmConfiguration, and the group list plus A/B/C
/// pointers from a GroupedMatmulSliceMGemmArguments.
template <typename Operator_>
class GroupedMatmulSliceMGemmOperation : public GemmOperationBase<Operator_, GemmOperationDescription> {
    using Base = GemmOperationBase<Operator_, GemmOperationDescription>;

public:
    GroupedMatmulSliceMGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::GroupedMatmulSliceM;
    }

private:
    /// @param argsPtr   Treated as a GroupedMatmulSliceMGemmArguments pointer.
    /// @param configPtr Treated as a GroupedMatmulSliceMGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *sliceArgs = static_cast<GroupedMatmulSliceMGemmArguments *>(argsPtr);
        auto *sliceCfg = static_cast<GroupedMatmulSliceMGemmConfiguration *>(configPtr);

        auto &target = this->args_;
        target.problemShape = GemmCoord{sliceCfg->m, sliceCfg->n, sliceCfg->k};
        target.problemCount = sliceCfg->groupCount;
        target.ptrGroupList = sliceArgs->deviceGroupList;
        target.ptrA = sliceArgs->A;
        target.ptrB = sliceArgs->B;
        target.ptrC = sliceArgs->C;
    }
};
/// @brief GEMM operation whose description is tagged GemmKind::OptimizedMatmul.
///
/// Besides building the operator arguments, BuildArgs() records two facts per
/// input matrix (A and B):
///  - whether this kernel is a padding variant (its LayoutWA/LayoutWB differs
///    from LayoutA/LayoutB), and
///  - whether the given problem size needs padding according to IsNeedPadding().
/// CanImplement() rejects the kernel when those two facts disagree for A or B.
template <typename Operator_>
class OptimizedMatmulGemmOperation : public GemmOperationBase<Operator_, GemmOperationDescription> {
    using Operator = Operator_;
    using OperatorKernel = typename Operator::Kernel;
    using ElementA = typename OperatorKernel::ElementA;
    using ElementB = typename OperatorKernel::ElementB;
    using LayoutA = typename OperatorKernel::LayoutA;
    using LayoutB = typename OperatorKernel::LayoutB;
    using L1TileShape = typename OperatorKernel::BlockMmad::L1TileShape;

public:
    OptimizedMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_, GemmOperationDescription>(name)
    {
        this->description_.gemmKind = GemmKind::OptimizedMatmul;
    }

private:
    /// @param argsPtr   Treated as a BasicMatmulGemmArguments pointer.
    /// @param configPtr Treated as a BasicMatmulGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *arguments = static_cast<BasicMatmulGemmArguments *>(argsPtr);
        auto *config = static_cast<BasicMatmulGemmConfiguration *>(configPtr);

        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};
        this->args_.ptrA = arguments->A;
        this->args_.ptrB = arguments->B;
        this->args_.ptrC = arguments->C;

        constexpr uint32_t alignByByte = 512;
        // NOTE(review): the element size is taken from `half`, not from
        // ElementA/ElementB — confirm this is intended for non-half kernels.
        constexpr uint32_t alignByElement = alignByByte / sizeof(half);

        // A kernel counts as a padding variant when its working layout differs from the input layout.
        isThisKernelPaddingA_ =
            !std::is_same<typename OperatorKernel::LayoutWA, typename OperatorKernel::LayoutA>::value;
        isThisKernelPaddingB_ =
            !std::is_same<typename OperatorKernel::LayoutWB, typename OperatorKernel::LayoutB>::value;

        LayoutA layoutA = LayoutA::template MakeLayout<ElementA>(config->m, config->k);
        isNeedPaddingA_ = IsNeedPadding(layoutA, alignByElement);

        // Fixed: layoutB was declared with type LayoutA although it is built from LayoutB.
        LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(config->k, config->n);
        isNeedPaddingB_ = IsNeedPadding(layoutB, alignByElement);
    }

    /// @brief Returns kInvalid when the kernel's padding variant does not match what the
    ///        problem needs for A or B; otherwise defers to the operator's CanImplement.
    Status CanImplement(void *argsPtr, void *configPtr) override
    {
        BuildArgs(argsPtr, configPtr);

        const bool paddingMatchesA = (isThisKernelPaddingA_ == isNeedPaddingA_);
        const bool paddingMatchesB = (isThisKernelPaddingB_ == isNeedPaddingB_);
        if (!paddingMatchesA || !paddingMatchesB) {
            return Catlass::Status::kInvalid;
        }
        return this->op_.CanImplement(this->args_);
    }

    /// @brief Row-major: padding is needed when stride(0) is not a multiple of `align`,
    ///        or when stride(0) is 65536 or larger.
    bool IsNeedPadding(Catlass::layout::RowMajor layout, uint32_t align)
    {
        if (layout.stride(0) >= 65536) {
            return true;
        }
        return layout.stride(0) % align != 0;
    }

    /// @brief Column-major: padding is needed when stride(1) is not a multiple of `align`,
    ///        or when stride(1) is 65536 or larger.
    bool IsNeedPadding(Catlass::layout::ColumnMajor layout, uint32_t align)
    {
        if (layout.stride(1) >= 65536) {
            return true;
        }
        return layout.stride(1) % align != 0;
    }

    /// @brief zN layouts are never reported as needing padding.
    bool IsNeedPadding(Catlass::layout::zN /*layout*/, uint32_t /*align*/)
    {
        return false;
    }

    /// @brief nZ layouts are never reported as needing padding.
    bool IsNeedPadding(Catlass::layout::nZ /*layout*/, uint32_t /*align*/)
    {
        return false;
    }

    bool isThisKernelPaddingA_{false};  // kernel's LayoutWA differs from LayoutA
    bool isThisKernelPaddingB_{false};  // kernel's LayoutWB differs from LayoutB
    bool isNeedPaddingA_{false};        // IsNeedPadding() result for the current A layout
    bool isNeedPaddingB_{false};        // IsNeedPadding() result for the current B layout
};
/// @brief GEMM operation whose description is tagged GemmKind::QuantMatmul.
///
/// In addition to what the base class fills in, the constructor adds tensor
/// descriptions for D, Scale and PerTokenScale from the kernel's element and
/// layout types. BuildArgs() copies every field from a
/// QuantMatmulGemmArguments; the configuration pointer is not consulted.
template <typename Operator_>
class QuantMatmulGemmOperation : public GemmOperationBase<Operator_, QuantMatmulGemmOperationDescription> {
public:
    using Operator = Operator_;
    using OperatorArguments = typename Operator::Arguments;
    using OperatorKernel = typename Operator::Kernel;

    using ElementD = typename OperatorKernel::ElementD;
    using ElementScale = typename OperatorKernel::ElementScale;
    using ElementPerTokenScale = typename OperatorKernel::ElementPerTokenScale;

    using LayoutD = typename OperatorKernel::LayoutD;
    using LayoutScale = typename OperatorKernel::LayoutScale;
    using LayoutPerTokenScale = typename OperatorKernel::LayoutPerTokenScale;

    QuantMatmulGemmOperation(char const *name = "") : GemmOperationBase<Operator_, QuantMatmulGemmOperationDescription>(name)
    {
        this->description_.gemmKind = GemmKind::QuantMatmul;
        this->description_.D = MakeTensorDescription<ElementD, LayoutD>();
        this->description_.Scale = MakeTensorDescription<ElementScale, LayoutScale>();
        this->description_.PerTokenScale = MakeTensorDescription<ElementPerTokenScale, LayoutPerTokenScale>();
    }

private:
    /// @param argsPtr Treated as a QuantMatmulGemmArguments pointer; supplies the problem
    ///                shape, core count and all tensor pointers.
    /// The second (configuration) parameter is unused: the original cast it to
    /// QuantMatmulGemmConfiguration but never read from it.
    void BuildArgs(void *argsPtr, void * /*configPtr*/) override
    {
        auto *arguments = static_cast<QuantMatmulGemmArguments *>(argsPtr);

        this->args_.problemShape = arguments->problemShape;
        this->args_.aicCoreNum = arguments->aicCoreNum;
        this->args_.ptrA = arguments->ptrA;
        this->args_.ptrB = arguments->ptrB;
        this->args_.ptrD = arguments->ptrD;
        this->args_.ptrScale = arguments->ptrScale;
        this->args_.ptrPerTokenScale = arguments->ptrPerTokenScale;
    }
};
/// @brief GEMM operation whose description is tagged GemmKind::BasicMatmulTla950.
///
/// The constructor overwrites the A/B/C tensor descriptions using the layout
/// tags from the kernel's BlockMmad::TileCopy, and fills the tile description
/// from the first three entries of the L1/L0 tile shapes via tla::get.
/// BuildArgs() additionally builds tla layouts for A, B and C.
template <typename Operator_>
class BasicMatmul950GemmOperation : public GemmOperationBase<Operator_, GemmOperationDescription> {
    using Operator = Operator_;
    using OperatorKernel = typename Operator::Kernel;
    using ElementA = typename OperatorKernel::ElementA;
    using ElementB = typename OperatorKernel::ElementB;
    using ElementC = typename OperatorKernel::ElementC;
    using LayoutTagA = typename OperatorKernel::BlockMmad::TileCopy::LayoutTagA;
    using LayoutTagB = typename OperatorKernel::BlockMmad::TileCopy::LayoutTagB;
    using LayoutTagC = typename OperatorKernel::BlockMmad::TileCopy::LayoutTagC;
    using L1TileShape = typename OperatorKernel::BlockMmad::L1TileShape;
    using L0TileShape = typename OperatorKernel::BlockMmad::L0TileShape;

public:
    BasicMatmul950GemmOperation(char const *name = "") : GemmOperationBase<Operator_, GemmOperationDescription>(name)
    {
        this->description_.gemmKind = GemmKind::BasicMatmulTla950;

        this->description_.A = MakeTensorDescription<ElementA, LayoutTagA>();
        this->description_.B = MakeTensorDescription<ElementB, LayoutTagB>();
        this->description_.C = MakeTensorDescription<ElementC, LayoutTagC>();

        // Tile extents are read at compile time from the tile-shape types.
        static constexpr uint32_t L1_TILE_M = tla::get<0>(L1TileShape{});
        static constexpr uint32_t L1_TILE_N = tla::get<1>(L1TileShape{});
        static constexpr uint32_t L1_TILE_K = tla::get<2>(L1TileShape{});
        static constexpr uint32_t L0_TILE_M = tla::get<0>(L0TileShape{});
        static constexpr uint32_t L0_TILE_N = tla::get<1>(L0TileShape{});
        static constexpr uint32_t L0_TILE_K = tla::get<2>(L0TileShape{});

        this->description_.tileDescription.L1TileShape =
            GemmShapeDescription(L1_TILE_M, L1_TILE_N, L1_TILE_K);
        this->description_.tileDescription.L0TileShape =
            GemmShapeDescription(L0_TILE_M, L0_TILE_N, L0_TILE_K);
    }

private:
    /// @param argsPtr   Treated as a BasicMatmulGemmArguments pointer.
    /// @param configPtr Treated as a BasicMatmulGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *arguments = static_cast<BasicMatmulGemmArguments *>(argsPtr);
        auto *config = static_cast<BasicMatmulGemmConfiguration *>(configPtr);

        this->args_.problemShape = GemmCoord{config->m, config->n, config->k};
        this->args_.ptrA = arguments->A;
        this->args_.ptrB = arguments->B;
        this->args_.ptrC = arguments->C;
        this->args_.ptrBias = nullptr;  // this operation never supplies a bias pointer

        auto m = config->m;
        auto n = config->n;
        auto k = config->k;

        // Fixed: the layouts now use the kernel-derived ElementA/B/C and LayoutTagA/B/C
        // aliases of this class. The original re-declared local aliases here
        // (RowMajor tags, float32_t elements) that shadowed them, so the layouts could
        // disagree with the tensor descriptions published by the constructor.
        this->args_.layoutA = tla::MakeLayout<ElementA, LayoutTagA>(m, k);
        this->args_.layoutB = tla::MakeLayout<ElementB, LayoutTagB>(k, n);
        this->args_.layoutC = tla::MakeLayout<ElementC, LayoutTagC>(m, n);
    }

    /// @brief Runs the operator; the `fftsAddr` argument is ignored and 0 is always passed on.
    Status Run(aclrtStream stream, uint32_t blockDim, uint64_t /*fftsAddr*/) override
    {
        return this->op_.Run(stream, blockDim, 0U);
    }
};
/// @brief GEMM operation whose description is tagged GemmKind::MatmulGelu.
///
/// The constructor adds a tensor description for D from the kernel's
/// ElementD/LayoutD. BuildArgs() takes the problem shape and element size from
/// a MatmulGeluGemmConfiguration and the A/B/D pointers from a
/// MatmulGeluGemmArguments.
template <typename Operator_>
class MatmulGeluGemmOperation : public GemmOperationBase<Operator_, MatmulGeluGemmOperationDescription> {
    using Base = GemmOperationBase<Operator_, MatmulGeluGemmOperationDescription>;

public:
    using Operator = Operator_;
    using OperatorArguments = typename Operator::Arguments;
    using OperatorKernel = typename Operator::Kernel;
    using ElementD = typename OperatorKernel::ElementD;
    using LayoutD = typename OperatorKernel::LayoutD;

    MatmulGeluGemmOperation(char const *name = "") : Base(name)
    {
        this->description_.gemmKind = GemmKind::MatmulGelu;
        this->description_.D = MakeTensorDescription<ElementD, LayoutD>();
    }

private:
    /// @param argsPtr   Treated as a MatmulGeluGemmArguments pointer.
    /// @param configPtr Treated as a MatmulGeluGemmConfiguration pointer.
    void BuildArgs(void *argsPtr, void *configPtr) override
    {
        auto *geluArgs = static_cast<MatmulGeluGemmArguments *>(argsPtr);
        auto *geluCfg = static_cast<MatmulGeluGemmConfiguration *>(configPtr);

        auto &target = this->args_;
        target.problemShape = GemmCoord{geluCfg->m, geluCfg->n, geluCfg->k};
        target.elementSize = geluCfg->elementSize;
        target.ptrA = geluArgs->ptrA;
        target.ptrB = geluArgs->ptrB;
        target.ptrD = geluArgs->ptrD;
    }
};
}
}
#endif