Block MMAD 代码开发详解
1. Block MMAD 概述
Block MMAD(Block Matrix Multiply-Add)是CATLASS模板库中负责块级矩阵乘法的核心组件,位于计算架构的中间层。它上接Kernel层,下接Tile层,负责将全局内存(GM)中的数据高效地加载到本地内存(L1/L0),并调度Tile级别的矩阵乘法运算。
Block MMAD采用高度模块化和模板化的设计,支持多种调度策略、Tile形状和数据类型,能够灵活适应不同的硬件架构和计算需求,本文将以BlockMmadPingpong为例详细讲解。
2. 模板组装机制
Block MMAD实现基于以下基础模板结构:
template <
class DispatchPolicy,
class L1TileShape,
class L0TileShape,
class AType,
class BType,
class CType,
class BiasType = void,
class TileCopy = Gemm::Tile::TileCopy<typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>,
class TileMmad = Gemm::Tile::TileMmad<typename DispatchPolicy::ArchTag, AType, BType, BiasType>
>
struct BlockMmad {
// 类型定义和核心实现
};
2.1 核心模板参数
| 参数名 | 描述 |
|---|---|
| DispatchPolicy | 调度策略,控制计算任务的分配和执行流程 |
| L1TileShape | L1缓存级别Tile的形状,定义M、N、K三个维度的大小 |
| L0TileShape | L0缓存级别Tile的形状,定义M、N、K三个维度的大小 |
| AType/BType/CType | 输入矩阵A、B和输出矩阵C的数据类型和布局信息 |
| BiasType | 偏置数据类型(可选) |
| TileCopy | Tile级别的数据拷贝组件,负责不同内存层级间的数据传输 |
| TileMmad | Tile级别的矩阵乘法组件,负责实际的计算操作 |
2.2 类型导出
Block MMAD通过类型导出系统建立统一的类型接口,方便上层组件使用:
public:
// Type Aliases
using DispatchPolicy = MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using ElementA = typename AType_::Element;
using LayoutA = typename AType_::Layout;
using ElementB = typename BType_::Element;
using LayoutB = typename BType_::Layout;
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using TileMmad = TileMmad_;
using CopyGmToL1A = typename TileCopy_::CopyGmToL1A;
using CopyGmToL1B = typename TileCopy_::CopyGmToL1B;
using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A;
using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B;
using CopyL0CToGm = typename TileCopy_::CopyL0CToGm;
using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;
using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc;
using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc;
using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst;
using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst;
using LayoutCInL0 = layout::zN;
3. 内存管理与缓存设计
Block MMAD负责管理不同层级的内存,包括全局内存(GM)、L1缓存和L0缓存:
3.1 静态常量定义
static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG;
static constexpr uint32_t STAGES = DispatchPolicy::STAGES;
static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA);
static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB);
static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE;
static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE;
static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE;
static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES;
static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES;
3.2 内存检查
// Check LayoutC
static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!");
// Check L1TileShape
static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!");
// Check L0TileShape
static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA);
static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB);
static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator);
static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!");
static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!");
static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!");
3.3 多阶段缓存设计
Block MMAD采用多阶段流水线设计,使用pingpong技术隐藏内存访问延迟:
protected:
// Multi-stage tensors list
AscendC::LocalTensor<ElementA> l1ATensorList[STAGES];
AscendC::LocalTensor<ElementB> l1BTensorList[STAGES];
AscendC::LocalTensor<ElementA> l0ATensorList[STAGES];
AscendC::LocalTensor<ElementB> l0BTensorList[STAGES];
AscendC::LocalTensor<ElementAccumulator> l0CTensor;
// Multi-stage event id list
int32_t l1AEventList[STAGES];
int32_t l1BEventList[STAGES];
int32_t l0AEventList[STAGES];
int32_t l0BEventList[STAGES];
// The id of current stage
uint32_t l1ListId{0};
uint32_t l0AListId{0};
uint32_t l0BListId{0};
4. 核心接口实现
4.1 构造函数
CATLASS_DEVICE
BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0)
{
uint32_t l1AOffset = l1BufAddrStart;
uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES;
// Init buffers
for (uint32_t i = 0; i < STAGES; i++) {
// Assign L1/L0A/L0B space for each stages
l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_SIZE * i);
l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_SIZE * i);
l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_PINGPONG_BUF_SIZE * i);
l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_PINGPONG_BUF_SIZE * i);
// Assign event ID for each stages
l1AEventList[i] = i;
l1BEventList[i] = i + STAGES;
l0AEventList[i] = i;
l0BEventList[i] = i + STAGES;
// The event id that needs to be set before the loop
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
}
l0CTensor = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(0);
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
}
4.2 析构函数
CATLASS_DEVICE
~BlockMmad()
{
for (uint32_t i = 0; i < STAGES; i++) {
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]);
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]);
AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]);
}
AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
}
4.3 核心计算接口 operator()
operator()是Block MMAD的核心接口,负责执行块级矩阵乘法:
CATLASS_DEVICE
void operator()(
AscendC::GlobalTensor<ElementA> const &gmA, LayoutA const &layoutA,
AscendC::GlobalTensor<ElementB> const &gmB, LayoutB const &layoutB,
AscendC::GlobalTensor<ElementC> const &gmC, LayoutC const &layoutC,
GemmCoord const &actualShape)
{
// 1. 初始化参数和布局
// 2. 预加载第一批数据到L1缓存
// 3. 主循环:
// a. 预加载下一批数据到L1缓存
// b. 将当前L1缓存中的数据加载到L0缓存
// c. 执行Tile级矩阵乘法
// d. 管理多阶段流水线
// 4. 将结果从L0缓存写回全局内存
}
5. 执行流程分析
以BlockMmadPingpong为例,其执行流程如下:
5.1 数据预加载
// load first matrix A tile from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual));
copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
// load first matrix B tile from GM to L1
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]);
auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n()));
copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB);
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);
5.2 主循环(多阶段流水线)
// main loop
uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k());
for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) {
uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0;
uint32_t kActualNext{0};
// 预加载下一批数据到L1缓存
if (kLoopIdx < kTileCount - 1) {
// ... 预加载逻辑 ...
}
// 处理当前L1缓存中的数据
// Get L1 tensor for current stage
auto l1ATensor = l1ATensorList[l1ListId];
auto l1BTensor = l1BTensorList[l1ListId];
// 将L1缓存中的数据加载到L0缓存并执行计算
// ... L0级处理逻辑 ...
// 切换到下一个阶段
l1ListId = l1ListIdNext;
kActual = kActualNext;
}
5.3 L0级处理与计算
for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) {
// ... M维度循环 ...
for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) {
// ... K维度循环 ...
for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) {
// ... N维度循环 ...
// 执行Tile级矩阵乘法
bool initC = ((kLoopIdx == 0) && (kPartIdx == 0));
uint8_t unitFlag = 0b00;
// ... unitFlag设置 ...
tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag);
}
}
}
5.4 结果写回
// copy block out
LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN());
if constexpr (!ENABLE_UNIT_FLAG) {
AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0);
copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C);
AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0);
} else {
copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11);
}
6. 多阶段流水线与事件同步
Block MMAD使用事件驱动的同步机制确保多阶段流水线的正确执行:
// 等待事件
AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]);
// 执行操作
copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA);
// 触发事件
AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);
7. 不同Block MMAD实现的共性与差异
CATLASS提供了多种Block MMAD实现,它们具有以下共性:
- 统一的模板接口:所有实现都基于相同的模板参数结构
- 多阶段流水线设计:都采用多阶段设计来隐藏内存访问延迟
- 事件驱动同步:都使用事件机制确保数据传输和计算的正确顺序
- 内存层级管理:都负责管理GM、L1和L0之间的数据传输
不同实现的主要差异在于:
- 调度策略:如pingpong、preload、streamk等不同的调度方式
- 硬件适配:针对不同硬件架构的优化实现
- 特殊功能支持:如量化、稀疏计算、偏置处理等
- 性能优化:不同的流水线深度和内存访问模式
8. 总结
Block MMAD是CATLASS模板库中连接Kernel层和Tile层的关键组件,它通过以下设计实现了高效的块级矩阵乘法:
- 高度模块化:通过模板参数组装不同的功能组件
- 多阶段流水线:使用pingpong技术隐藏内存访问延迟
- 事件驱动同步:确保数据传输和计算的正确顺序
- 自动内存管理:高效管理不同层级的内存资源
- 灵活适配:支持不同的硬件架构和计算需求
Block MMAD的设计体现了现代高性能计算库的核心思想,通过精心的流水线设计和内存管理,充分发挥硬件的计算能力,实现高效的矩阵乘法运算。