Block Scheduler 公共框架
基础接口定义,所有 BlockScheduler 实现需遵循此框架
概述
BlockScheduler 负责 tile 切分、Block 分配、Z 型扫描、尾块处理等调度任务。不同 Kernel 类型使用不同的 Scheduler 实现: 详见:README.md 查看 API 清单和实现对比。
模板参数
通用模板参数
template <class ProblemShape_, ...>
class BlockScheduler;
| 参数 | 类型 | 说明 |
|---|---|---|
| ProblemShape_ | Shape<int64_t, int64_t, int64_t, int64_t> | 问题规模 (m, n, k, batch) |
特有模板参数
| Scheduler | 特有参数 | 说明 |
|---|---|---|
| BlockSchedulerMatmulBasic | FullLoadMode_ | 全载模式(0/1/2) |
| BlockSchedulerStreamK | - | - |
| BlockSchedulerQuantBatchMatmulV3 | FullLoadMode_, LayoutA_, LayoutB_, AType_ | 全载模式、布局、数据类型 |
类型别名
通用类型
using BlockShape = Shape<int64_t, int64_t, int64_t, int64_t>;
using ProblemShape = ProblemShape_;
// BlockSchedulerMatmulBasic / BlockSchedulerStreamK
using BlockCoord = Coord<int64_t, int64_t, int64_t, int64_t>;
// BlockSchedulerQuantBatchMatmulV3
using BlockCoord = Coord<int64_t, int64_t>;
| 类型 | 说明 |
|---|---|
| BlockShape | Block 形状 (m, n, k, batch) |
| BlockCoord | Block 坐标,Matmul Basic/StreamK 为 (mTileIdx, nTileIdx, kTileIdx, batchIdx);QuantBatchMatmulV3 为 Coord<int64_t, int64_t>,仅包含 (mTileIdx, nTileIdx) |
| ProblemShape | 问题规模类型 |
Params 结构体
BlockSchedulerMatmulBasic::Params
详见 BlockSchedulerMatmulBasic 参数详解
构造函数
通用构造函数接口
__aicore__ inline BlockScheduler(const ProblemShape& shape, const Params& params)
| 参数 | 类型 | 说明 |
|---|---|---|
| shape | ProblemShape | 问题规模 (m, n, k, batch) |
| params | Params | 调度参数 |
通用成员方法
GetTileNum / GetTotalTileNum
__aicore__ inline int64_t GetTileNum(); // BlockSchedulerMatmulBasic
__aicore__ inline int64_t GetTotalTileNum(); // BlockSchedulerStreamK
功能:返回总 tile 数量(含 batch)。
GetBlockNum
__aicore__ inline int64_t GetBlockNum(ProblemShape shape, int64_t blockNum)
功能:返回实际使用的 Block 数量(不超过 tile 总数)。
GetBlockShape
__aicore__ inline BlockShape GetBlockShape(int64_t tileIdx, ...); // BlockSchedulerMatmulBasic
__aicore__ inline BlockShape GetBlockShape(int64_t tileIdx); // BlockSchedulerStreamK
template <QuantMode aQuantMode, QuantMode bQuantMode, bool weightNz = false>
__aicore__ inline BlockShape GetBlockShape(BlockCoord blockCoord); // BlockSchedulerQuantBatchMatmulV3
功能:返回当前 tile 的 Block 形状。
QuantBatchMatmulV3 的 BlockShape 第 3、4 个字段用于携带 M/N 尾块切分偏移。
GetBlockCoord / GetTileIdx
__aicore__ inline BlockCoord GetBlockCoord(int64_t tileIdx); // BlockSchedulerMatmulBasic
__aicore__ inline BlockCoord GetBlockCoord(int64_t tileIdx); // BlockSchedulerStreamK
__aicore__ inline bool GetTileIdx(BlockCoord& blockCoord); // BlockSchedulerQuantBatchMatmulV3
功能:返回当前 tile 的 Block 坐标。
GetTileCoord
__aicore__ inline void GetTileCoord(BlockCoord blockCoord, int64_t& mPos, int64_t& nPos);
功能:QuantBatchMatmulV3 根据 2D Block 坐标和 Scheduler 内部记录的尾块切分偏移计算 GM 地址偏移。
Z 型扫描
扫描策略
所有 Scheduler 使用 Z 型扫描策略:
- WINDOW_LEN = 4:扫描窗口大小
- 正向扫描:偶数行(rowIdx % 2 == 0)
- 反向扫描:奇数行(rowIdx % 2 != 0)
扫描示意
Z 型扫描示意(mTileNum=4, nTileNum=4)
N轴 →
+--+--+--+--+
|0 |1 |2 |3 | Row 0(正向)
M +--+--+--+--+
轴 |7 |6 |5 |4 | Row 1(反向)
↓ +--+--+--+--+
|8 |9 |10|11| Row 2(正向)
+--+--+--+--+
|15|14|13|12| Row 3(反向)
+--+--+--+--+
扫描顺序:0→1→2→3→7→6→5→4→8→9→10→11→15→14→13→12
扫描实现
// 奇数行反向扫描
if (rowIdx % 2 != 0) {
nTileIdx = nTileNum - 1 - nTileIdx;
}
尾块处理
尾块判断
| Scheduler | 尾块判断 |
|---|---|
| BlockSchedulerMatmulBasic | mTileIdx >= mL1NormCnt_ 或 nTileIdx >= nL1NormCnt_ |
| BlockSchedulerStreamK | mTileIdx == (mTileNum - 1) 或 nTileIdx == (nTileNum - 1) |
| BlockSchedulerQuantBatchMatmulV3 | mTileIdx >= mBaseNormCnt_ 或 nTileIdx >= nBaseNormCnt_ |
尾块切分
| Scheduler | 尾块切分支持 |
|---|---|
| BlockSchedulerMatmulBasic | 支持(mTailCnt, nTailCnt) |
| BlockSchedulerStreamK | 支持(K 轴尾块) |
| BlockSchedulerQuantBatchMatmulV3 | 支持(mTailTile, nTailTile) |
数据流
通用调度流程
问题规模 (m, n, k, batch)
↓
tile 切分 (mTileNum, nTileNum, kTileNum)
↓
尾块参数计算
↓
Block 分配 (blockIdx, blockNum)
↓
Z 型扫描 (rowIdx, mTileIdx, nTileIdx)
↓
Block 形状/坐标 (BlockShape, BlockCoord)
↓
BlockMmad 执行
调用示例
Matmul Basic Scheduler
using ProblemShape = Shape<int64_t, int64_t, int64_t, int64_t>;
using BlockScheduler = Blaze::Gemm::Block::BlockSchedulerMatmulBasic<ProblemShape, 0>;
BlockScheduler::Params params = {
.mL1 = 256, .nL1 = 256, .kL1 = 128,
.baseM = 128, .baseN = 128, .baseK = 64,
// ...
};
ProblemShape shape{m, n, k, batch};
BlockScheduler scheduler(shape, params);
int64_t blockIdx = AscendC::GetBlockIdx();
int64_t blockNum = scheduler.GetBlockNum(shape);
for (int64_t tileIdx = blockIdx; tileIdx < scheduler.GetTileNum(); tileIdx += blockNum) {
auto blockShape = scheduler.GetBlockShape<transB, BType>(tileIdx);
auto blockCoord = scheduler.GetBlockCoord(tileIdx);
// BlockMmad 计算
}