/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
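/*!
 * \file block_mmad.h
 * \brief Primary BlockMmad template and the CRTP base class BlockMmadBase for block matrix multiplication.
 */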
#ifndef MATMUL_BLOCK_BLOCK_MMAD_H
#define MATMUL_BLOCK_BLOCK_MMAD_H
#include <type_traits>
#include "../utils/arch.h"
#include "../utils/integral_constant.h"
#include "../utils/matmul_layout_type.h"
#include "./block_mmad_utils.h"
namespace Cgmct {
namespace Gemm {
namespace Block {
* @class BlockMmad
* @brief Block matrix multiplication class for performing block matrix multiplication operations
*/
template <
class DispatchPolicy,
/// The shape of L1 tile
class L1TileShape,
/// The shape of L0 tile
class L0TileShape,
/// Type of matrix A
class AType,
/// Type of matrix B
class BType,
/// Type of matrix C
class CType,
/// Type of the bias term, defaulting to the same type of CType
class BiasType = CType,
class TileCopy = void,
typename = void
>
class BlockMmad {
static_assert(AscendC::Std::always_false_v<DispatchPolicy>, "BlockMmad is not implemented for this DispatchPolicy");
};
* @class BlockMmadBase
* @brief Base class of Block matrix multiplication class, serving as the base class for the CRTP pattern
*/
template <
class Derived,
/// The dispatch policy type
class DispatchPolicy_,
/// The shape of L1 tile
class L1TileShape,
/// The shape of L0 tile
class L0TileShape,
/// Type of matrix A
class AType_,
/// Type of matrix B
class BType_,
/// Type of matrix C
class CType_,
/// Type of the bias term
class BiasType_,
/// The tile copy strategy type
class TileCopy_
>
class BlockMmadBase {
public:
using DispatchPolicy = DispatchPolicy_;
using L1Shape = L1TileShape;
using L0Shape = L0TileShape;
using AType = AType_;
using BType = BType_;
using CType = CType_;
using BiasType = BiasType_;
using TileCopy = TileCopy_;
protected:
* @brief Obtain a reference to the derived class
*/
__aicore__ inline Derived& AsDerived()
{
return static_cast<Derived&>(*this);
}
public:
static_assert(IsTileShapeValid<AType, BType, L1Shape, L0Shape>(), "L1Shape or L0Shape is invalid");
};
}
}
}
#endif