Block Scheduler 代码开发详解
1. Block Scheduler 概述
Block Scheduler(块调度器)是CATLASS模板库中负责块级任务调度的核心组件,位于Block层,主要负责管理和分配矩阵乘法计算任务到不同的计算单元。它通过优化任务分配顺序和数据访问模式,提高计算资源利用率和缓存命中率,从而提升整体性能。
Block Scheduler采用模板化设计,支持多种调度策略和硬件架构,主要包括:
- 窗口化调度(Window-based Scheduling)
- 蛇形扫描(Snake Scanning)
- 动态任务分配(Dynamic Task Allocation)
本文将以GemmIdentityBlockSwizzle为例,详细讲解Block Swizzle的代码结构、主要接口和调度策略。
2. 模板组装机制
GemmIdentityBlockSwizzle实现基于以下基础模板结构:
template <uint32_t SwizzleOffset = 1, uint32_t SwizzleDirection = 0>
struct GemmIdentityBlockSwizzle {
// 核心实现
};
2.1 核心模板参数
| 参数名 | 描述 |
|---|---|
| SwizzleOffset | 窗口化调度的偏移量,默认为1 |
| SwizzleDirection | 调度方向,0表示按Zn方向,1表示按Nz方向,默认为0 |
这些模板参数与调度策略算法具体实现相关,允许用户灵活配置调度策略,以适应不同的硬件架构和性能需求。
3. 核心数据结构
GemmIdentityBlockSwizzle包含以下核心数据成员,用于维护调度状态和计算参数:
/// Data members
GemmCoord problemShape; // 问题形状,包含M、N、K三个维度的大小
MatrixCoord tileMN; // Tile形状,包含M、N两个维度的大小
MatrixCoord loopsMN; // 循环次数,包含M、N两个维度的循环次数
这些数据成员共同构成了调度器的状态,用于跟踪当前的计算进度和参数。
4. 主要接口说明
4.1 构造函数
CATLASS_DEVICE
GemmIdentityBlockSwizzle() {}
CATLASS_DEVICE
GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_)
: problemShape(problemShape_), tileMN(tileMN_)
{
loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN);
}
CATLASS_DEVICE
GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_,
MatrixCoord const &loopsMN_) : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) {}
GemmIdentityBlockSwizzle提供了三个构造函数,分别用于默认初始化、根据问题形状和Tile形状初始化,以及直接指定循环次数的初始化。
4.2 Update方法
CATLASS_DEVICE
void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_)
{
problemShape = problemShape_;
tileMN = tileMN_;
loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN);
}
CATLASS_DEVICE
void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_)
{
problemShape = problemShape_;
tileMN = tileMN_;
loopsMN = loopsMN_;
}
Update方法用于更新调度器的参数,支持动态调整问题形状、Tile形状和循环次数。
4.3 GetCoreLoops方法
CATLASS_DEVICE
uint32_t GetCoreLoops() const
{
return loopsMN.row() * loopsMN.column();
}
该方法返回核心循环次数,即M维度循环次数乘以N维度循环次数,表示总共有多少个块需要处理。
4.4 GetBatchIdx方法
CATLASS_DEVICE
uint32_t GetBatchIdx(uint32_t taskIdx)
{
return taskIdx / (GetCoreLoops());
}
该方法根据任务索引获取批次索引,用于支持批处理模式。
4.5 GetBlockCoord方法
CATLASS_DEVICE
GemmCoord GetBlockCoord(uint32_t taskIdx)
{
uint32_t innerIdx = taskIdx % GetCoreLoops();
if constexpr (SwizzleDirection == 0) { // Zn
uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset);
uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column());
uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column());
uint32_t nRow = SwizzleOffset;
if (tileBlockIdx == tileBlockLoop - 1) {
nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx;
}
uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow;
uint32_t nIdx = inTileBlockIdx / nRow;
if (tileBlockIdx % 2 == 1) {
nIdx = loopsMN.column() - nIdx - 1;
}
return GemmCoord{mIdx, nIdx, 0};
} else if constexpr (SwizzleDirection == 1) { // Nz
// 类似的Nz方向处理逻辑
}
}
GetBlockCoord是GemmIdentityBlockSwizzle的核心方法,根据任务索引计算块的坐标。它支持两种调度方向:
- Zn方向:按行分块,每块包含SwizzleOffset行
- Nz方向:按列分块,每块包含SwizzleOffset列
该方法还实现了蛇形扫描策略,当块索引为奇数时,反转N维度或M维度的索引,以优化内存访问模式。
4.6 GetActualBlockShape方法
CATLASS_DEVICE
GemmCoord GetActualBlockShape(GemmCoord blockCoord)
{
uint32_t mActual = (blockCoord.m() == (loopsMN.row() - 1)) ?
(problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row();
uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) ?
(problemShape.n() - blockCoord.n() * tileMN.column()) : tileMN.column();
uint32_t kActual = problemShape.k();
return GemmCoord{mActual, nActual, kActual};
}
该方法根据块坐标计算块的实际大小,特别是处理尾部块的边界情况。当块是最后一行或最后一列时,需要调整块的大小以适应实际的问题规模。
5. 调度策略分析
5.1 窗口化调度
GemmIdentityBlockSwizzle采用窗口化调度策略,将M维度或N维度的Tile划分为多个窗口进行处理:
uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset);
uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column());
窗口化调度可以提高缓存命中率,减少缓存抖动,特别是在处理大规模矩阵时效果显著。
5.2 蛇形扫描
在GetBlockCoord方法中实现了蛇形扫描策略:
if (tileBlockIdx % 2 == 1) {
nIdx = loopsMN.column() - nIdx - 1;
}
蛇形扫描可以优化内存访问模式,减少内存带宽压力,提高数据加载效率。
5.3 尾部块处理
调度器对尾部块进行特殊处理,确保计算的正确性:
uint32_t mActual = (blockCoord.m() == (loopsMN.row() - 1)) ?
(problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row();
对于不完整的尾部块,调度器会调整块大小,确保计算的正确性。
6. 动态调度扩展
GemmIdentityBlockSwizzle还提供了动态调度的扩展版本DynamicGemmIdentityBlockSwizzle:
struct DynamicGemmIdentityBlockSwizzle : public GemmIdentityBlockSwizzle<>
{
uint32_t swizzleOffset{1};
uint32_t swizzleDirection{0};
// 构造函数和方法
CATLASS_DEVICE
void SetSwizzleParams(uint32_t swizzleOffset_, uint32_t swizzleDirection_)
{
swizzleOffset = swizzleOffset_;
swizzleDirection = swizzleDirection_;
}
// 重写的GetBlockCoord方法
};
DynamicGemmIdentityBlockSwizzle允许在运行时动态调整SwizzleOffset和SwizzleDirection参数,提供了更灵活的调度控制。
7. 执行流程分析
Block Scheduler的典型执行流程如下:
- 初始化:创建调度器实例,设置问题形状和Tile形状
- 计算循环次数:根据问题形状和Tile形状计算M和N维度的循环次数
- 任务分配:
- 根据任务索引获取批次索引(如果支持批处理)
- 根据任务索引计算块坐标
- 根据块坐标计算块的实际大小
8. 总结
Block Scheduler是CATLASS模板库中负责块级任务调度的核心组件,通过窗口化调度、蛇形扫描、尾部块处理等策略,优化内存访问模式和缓存命中率,从而提升整体性能。
Block Scheduler采用模板化设计,支持多种调度策略和硬件架构,能够灵活适应不同的计算需求。它还提供了动态调度的扩展版本,允许在运行时调整调度参数。
通过合理配置Block Scheduler的参数,可以显著提升矩阵乘法的性能,特别是在处理大规模矩阵时效果更加明显。