使用方式
-
编译项目
在shmem/根目录下执行编译脚本:bash scripts/build.sh -examples -
运行KV_Shuffle示例程序
进入示例目录并执行运行脚本:cd examples/kv_shuffle bash scripts/run.sh [pe_size]- 参数说明:
pe_size:指定算子运行的pe个数。- 示例:使用第0和第1个NPU设备运行2卡kv_shuffle示例:
bash scripts/run.sh 2
- 参数说明:
算子介绍
KV Shuffle 算子核心功能是实现 KV Cache(键值缓存)的跨设备 / 跨 PE 数据重排与远程拷贝,适配大模型训练 / 推理中 KV Cache 的分布式调度需求。
在大模型分布式训练 / 推理场景中,KV Cache 会按 Block(块)管理,不同计算 PE 之间需要根据调度策略(如 shuffle table)对 KV Cache 的 Block 进行重排、迁移,本算子为该场景提供高效 KV Block 跨 PE 拷贝与重映射,相比传统主机侧调度,大幅降低 KV Cache 迁移的延迟和带宽开销。
cpp接口
class KVShuffleOps {
public:
// 默认构造函数
KVShuffleOps(uint32_t block_dims, void* stream);
~KVShuffleOps();
// 接受 Tensor 的函数
void compute(
uint8_t* k_cache,
uint8_t* v_cache,
uint8_t* global_shuffle_table,
uint8_t* src_block_table,
uint8_t* dst_block_table,
int64_t block_nums,
int64_t kv_head_num, int64_t page_size, int64_t head_dim);
private:
void* sync_ptr_;
int32_t count_;
uint32_t block_dims_;
void* stream_;
uint64_t fftsAddr_;
};
接口参数说明
| 参数名 | 输入/输出 | 描述 |
|---|---|---|
| uint8_t* k_cache | 输入/输出 | 指向键缓存全局内存的指针,存储需要进行shuffle操作的键数据块,连续内存,按块组织,每个块的大小为 kv_head_num * page_size * head_dim * sizeof(data_type) |
| uint8_t* v_cache | 输入/输出 | 指向值缓存全局内存的指针,存储需要进行shuffle操作的值数据块,与k_cache相同的连续内存布局 |
| uint8_t* global_shuffle_table | 输入 | 全局shuffle表,存储每个进程的配对信息和操作类型,实际存储int64_t类型数据,内存布局 :数组结构,每个PE对应2个int64_t条目:[pair_rank_0, operation_0, pair_rank_1, operation_1, ..., pair_rank_n, operation_n] 数据限制 :大小必须为 2 * n_pes * sizeof(int64_t) ,其中n_pes是进程总数,operation只能是0或1(0表示发送,1表示接收)配对关系必须是双向的(A的pair_rank是B,则B的pair_rank必须是A) |
| uint8_t* src_block_table | 输入 | 源块索引表,指示每个shuffle操作的源块ID,实际存储int64_t类型数据,一维数组,长度为block_nums,每个元素的值必须是有效的块ID(0 ≤ src_block_id < block_nums) |
| uint8_t* dst_block_table | 输入 | 目标块索引表,指示每个shuffle操作的目标块ID,实际存储int64_t类型数据,每个元素的值必须是有效的块ID(0 ≤ dst_block_id < block_nums) |
| int64_t block_nums | 输入 | 需要进行shuffle操作的块数量 |
| int64_t kv_head_num | 输入 | 键值数据的头数量 |
| int64_t page_size | 输入 | KV缓存中每个页面的大小 |
| int64_t head_dim | 输入 | 每个头的维度 |
torch接口
# 创建算子
kv_shuffle = torch.classes.ShmemOps.KVShuffle()
# 计算
kv_shuffle.compute(global_shuffle_tensor, aclshmem_k_cache_tensor,
aclshmem_v_cache_tensor, src_block_tensor, dst_block_tensor)
接口参数说明
1. global_shuffle_tensor
- 含义 :全局shuffle表,存储每个进程的配对信息和操作类型
- 数据类型 :PyTorch张量, torch.int64 类型
- 形状 :二维数组,形状为 [n_pes, 2] ,其中n_pes是进程总数
- 内容结构 :
[ [pair_rank_0, operation_0], [pair_rank_1, operation_1], ..., [pair_rank_n, operation_n] ] - 数据限制 :
- 必须在NPU设备上(使用 .npu() 方法转换)
- operation 只能是0或1(0表示发送,1表示接收)
- 配对关系必须是双向的(A的pair_rank是B,则B的pair_rank必须是A)
- 数据类型必须是 int64
2. aclshmem_k_cache_tensor
- 含义 :指向键缓存全局内存的张量,存储需要进行shuffle操作的键数据块
- 数据类型 :PyTorch张量,torch.int8 类型
- 形状 :四维数组,形状为 [block_nums, kv_head_num, page_size, head_dim]
- 数据限制 :
- 通过 aclshmem_common.malloc_like() 创建的ACL SHMEM共享内存张量
- 维度顺序必须严格为 [块数, 头数, 页大小, 头维度]
- 块数必须与 block_nums 参数匹配
3. aclshmem_v_cache_tensor
- 含义 :指向值缓存全局内存的张量,存储需要进行shuffle操作的值数据块
- 数据类型 :PyTorch张量,与 aclshmem_k_cache_tensor 相同
- 形状 :四维数组,形状与 aclshmem_k_cache_tensor 相同: [block_nums, kv_head_num, page_size, head_dim]
- 数据限制 :
- 通过 aclshmem_common.malloc_like() 创建的ACL SHMEM共享内存张量
- 在NPU设备上
- 数据类型必须与 aclshmem_k_cache_tensor 一致
- 形状必须与 aclshmem_k_cache_tensor 完全匹配
4. src_block_tensor
- 含义 :源块索引表,指示每个shuffle操作的源块ID
- 数据类型 :PyTorch张量, torch.int64 类型
- 形状 :一维数组,长度为 block_nums
- 数据限制 :
- 必须在NPU设备上
- 数据类型必须是 int64
- 每个元素的值必须是有效的块ID(0 ≤ src_block_id < block_nums)
- 数组长度必须与当前进程需要处理的块数匹配
5. dst_block_tensor
- 含义 :目标块索引表,指示每个shuffle操作的目标块ID
- 数据类型 :PyTorch张量, torch.int64 类型
- 形状 :一维数组,长度为 block_nums
- 数据限制 :
- 必须在NPU设备上
- 数据类型必须是 int64
- 每个元素的值必须是有效的块ID(0 ≤ dst_block_id < block_nums)
- 数组长度必须与 src_block_tensor 完全匹配
关键参数推导
在PyTorch扩展的C++实现中,以下参数从输入张量中推导出来:
- block_nums :从 dst_block_tensor.size(0) 获取,表示需要处理的块数量
- kv_head_num :从 KeyCache.size(1) 获取,表示键值数据的头数量
- page_size :从 KeyCache.size(2) 获取,表示每个页面的token数量
- head_dim :从 KeyCache.size(3) 获取,表示每个头的维度
KVShuffle算子数据流转效果说明
本文档通过具体示例展示KVShuffle算子的数据流转过程,包括初始数据状态、传输策略计算、块表生成和最终数据变换结果。
1. 示例配置
为了清晰展示数据流转,我们使用以下简化配置:
| 参数 | 值 | 说明 |
|---|---|---|
| RANKS | 2 | 进程总数 |
| INIT_BATCH | 2 | 每个进程的初始批次数 |
| PAGE_SIZE | 4 | 页面大小(每个页面包含的token数量) |
| KV_HEAD_NUM | 1 | 头数量(简化为1便于展示) |
| HEAD_DIM | 2 | 头维度(简化为2便于展示) |
| MAX_SEQLEN | 8 | 最大序列长度 |
2. Batch Token与KV缓存的关系
在理解数据流转之前,需要先明确Batch Token与KV缓存之间的核心关系:
2.1 基本概念
| 概念 | 含义 |
|---|---|
| Batch Token | 一个batch中包含的token数量,即序列长度(seqlen) |
| KV缓存 | 存储键值对的缓存结构,用于注意力机制的高效计算 |
| Page | KV缓存的基本存储单位,每个page包含固定数量的token(PAGE_SIZE) |
| Block | 一个batch在KV缓存中占用的连续pages集合 |
2.2 关系公式
-
块数计算:一个batch需要的块数 = batch token数 ÷ 页面大小 + 1(向上取整)
block_num = seqlen // PAGE_SIZE + 1 -
KV缓存总大小:
total_cache_size = max_block_nums × kv_head_num × page_size × head_dim × data_type_size
2.3 映射关系示例
以进程0的Batch 0为例:
- Batch Token数:6
- 页面大小(PAGE_SIZE):4
- 需要的块数:6 ÷ 4 + 1 = 2块
- 这2块对应KV缓存中的Block 0和Block 1
2.4 直观理解
Batch Token (seqlen=6) → 映射到 → KV缓存的2个块
┌─────────────────┐ ┌────────────────────────────────────────────┐
│ Token 0-5 │ │ Block 0 (PAGE_SIZE=4): 存储Token 0-3 │
│ (6个token) │ │ Block 1 (PAGE_SIZE=4): 存储Token 4-5 │
└─────────────────┘ └────────────────────────────────────────────┘
这种映射关系确保了即使不同batch的token长度不同,也能在KV缓存中高效存储和访问。
3. 初始数据状态
3.1 进程0的初始数据
KV缓存形状:(block_num, kv_head_num, page_size, head_dim) = (4, 1, 4, 2)
K缓存数据:
# Block 0
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]
# Block 1
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]
# Block 2
[[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]]
# Block 3
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]
V缓存数据:
# Block 0
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]
# Block 1
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]
# Block 2
[[[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]]]
# Block 3
[[[2.5, 2.6], [2.7, 2.8], [2.9, 3.0], [3.1, 3.2]]]
3.2 进程1的初始数据
KV缓存形状:(block_num, kv_head_num, page_size, head_dim) = (2, 1, 4, 2)
K缓存数据:
# Block 0
[[[5.1, 5.2], [5.3, 5.4], [5.5, 5.6], [5.7, 5.8]]]
# Block 1
[[[6.1, 6.2], [6.3, 6.4], [6.5, 6.6], [6.7, 6.8]]]
V缓存数据:
# Block 0
[[[3.3, 3.4], [3.5, 3.6], [3.7, 3.8], [3.9, 4.0]]]
# Block 1
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]
3. 负载均衡与传输策略计算
3.1 Batch token长度
假设生成的batch token长度如下:
| 进程 | Batch 0 | Batch 1 | 总token数 |
|---|---|---|---|
| 0 | 6 | 7 | 13 |
| 1 | 3 | 3 | 6 |
3.2 块数计算
每个batch的块数计算公式:block_num = seqlen // PAGE_SIZE + 1
| 进程 | Batch 0 | Batch 1 | 总块数 |
|---|---|---|---|
| 0 | 2 (6//4+1) | 2 (7//4+1) | 4 |
| 1 | 1 (3//4+1) | 1 (3//4+1) | 2 |
注意:由于 token 是按块管理的,最后一个块可能没有填满(例如 Batch 0 的 Block 1 只有 2 个 token)。
3.3 Batch块映射
进程0的batch_blocks_list:
[ # batch_blocks_list[0]
(0, [0, 1]), # Batch 0使用块0和块1
(1, [2, 3]) # Batch 1使用块2和块3
]
进程1的batch_blocks_list:
[ # batch_blocks_list[1]
(0, [0]), # Batch 0使用块0
(1, [1]) # Batch 1使用块1
]
3.4 负载均衡计算
- 平均token数:
(13 + 6) / 2 = 9.5 - 进程0需要传输的token数:
13 - 9.5 = 3.5(取整为3) - 进程1需要接收的token数:
9.5 - 6 = 3.5(取整为3)
3.5 传输batch选择
由于token是按块管理的,选择传输Batch 0(6个token)来尽可能接近理想负载:
transfer_tokens_list:
[ # transfer_tokens_list
(6, [0]), # 进程0传输Batch 0的6个token
(-1, []) # 进程1不需要传输
]
4. 块表生成
4.1 src_block_table
进程0需要传输Batch 0对应的块ID [0, 1]:
src_block_table:
[ # src_block_table
[0, 1], # 进程0的源块表
[] # 进程1的源块表
]
4.2 dst_block_table
进程1当前使用了2个块(0和1),因此目标块ID从2开始:
dst_block_table:
[ # dst_block_table
[2, 3], # 进程0的目标块表(传输到进程1的块2和3)
[] # 进程1的目标块表
]
4.3 配对关系
pair_list:
[ # pair_list
[1, 0], # 进程0与进程1配对,角色为发送方(0)
[0, 1] # 进程1与进程0配对,角色为接收方(1)
]
5. KVShuffle数据变换
5.1 数据传输过程
| 源进程 | 源块ID | 目标进程 | 目标块ID | 传输的数据 |
|---|---|---|---|---|
| 0 | 0 | 1 | 2 | 进程0的K块0、V块0 |
| 0 | 1 | 1 | 3 | 进程0的K块1、V块1 |
5.2 关于源进程块的清理说明
为什么进程0的块没有被清理?
- KVShuffle算子默认执行的是数据复制而非数据移动
- 这是因为在分布式训练场景中,源进程可能仍然需要这些数据用于后续的计算或其他batch处理
- 如果应用层确实需要清理源进程的数据,可以在KVShuffle操作完成后,手动释放或标记这些块为可用
- 清理操作通常由应用层根据具体业务逻辑决定,而不是由KVShuffle算子自动执行
5.3 变换后数据状态
进程0的最终数据(不变)
K缓存:
# Block 0
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]
# Block 1
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]
# Block 2
[[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]]
# Block 3
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]
V缓存:
# Block 0
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]
# Block 1
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]
# Block 2
[[[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]]]
# Block 3
[[[2.5, 2.6], [2.7, 2.8], [2.9, 3.0], [3.1, 3.2]]]
进程1的最终数据(新增块2和3)
K缓存:
# Block 0(原有)
[[[5.1, 5.2], [5.3, 5.4], [5.5, 5.6], [5.7, 5.8]]]
# Block 1(原有)
[[[6.1, 6.2], [6.3, 6.4], [6.5, 6.6], [6.7, 6.8]]]
# Block 2(新增,来自进程0的块0)
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]
# Block 3(新增,来自进程0的块1)
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]
V缓存:
# Block 0(原有)
[[[3.3, 3.4], [3.5, 3.6], [3.7, 3.8], [3.9, 4.0]]]
# Block 1(原有)
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]
# Block 2(新增,来自进程0的块0)
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]
# Block 3(新增,来自进程0的块1)
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]
6. 数据验证
6.1 传输前后数据一致性
- 进程1的K块2与进程0的K块0完全相同
- 进程1的K块3与进程0的K块1完全相同
- 进程1的V块2与进程0的V块0完全相同
- 进程1的V块3与进程0的V块1完全相同
6.2 负载均衡效果
传输前后的token分布:
| 进程 | 传输前token数 | 传输后token数 | 均衡度 |
|---|---|---|---|
| 0 | 13 | 13 - 6 = 7 | 更接近平均值9.5 |
| 1 | 6 | 6 + 6 = 12 | 更接近平均值9.5 |
7. 数据流转总结
┌─────────────────────────┐ ┌────────────────────────┐
│ 进程0初始数据 │ │ 进程1初始数据 │
│ K块0: [1.1, 1.2, ...] │ │ K块0: [5.1, 5.2, ...] │
│ K块1: [2.1, 2.2, ...] │ │ K块1: [6.1, 6.2, ...] │
│ K块2: [3.1, 3.2, ...] │ │ V块0: [3.3, 3.4, ...] │
│ K块3: [4.1, 4.2, ...] │ │ V块1: [4.1, 4.2, ...] │
│ V块0: [0.1, 0.2, ...] │ └─────────────────────────┘
│ V块1: [0.9, 1.0, ...] │ ▲
│ V块2: [1.7, 1.8, ...] │ │
│ V块3: [2.5, 2.6, ...] │ │
└────────────┬────────────┘ │
│ │
│ 传输Batch 0的块0和块1 │
│ │
▼ │
┌─────────────────────────┐ ┌─────────┴─────────┐
│ 生成块表和策略 │ │ 数据传输 │
│ src_block_table: [0, 1]│───▶│ K块0 → 进程1的K块2 │
│ dst_block_table: [2, 3]│ │ K块1 → 进程1的K块3 │
│ pair_list: [1, 0] │ │ V块0 → 进程1的V块2 │
└─────────────────────────┘ │ V块1 → 进程1的V块3 │
└───────────────────┘
│
▼
┌─────────────────────────┐ ┌────────────────────────┐
│ 进程0最终数据 │ │ 进程1最终数据 │
│ K块0: [1.1, 1.2, ...] │ │ K块0: [5.1, 5.2, ...] │
│ K块1: [2.1, 2.2, ...] │ │ K块1: [6.1, 6.2, ...] │
│ K块2: [3.1, 3.2, ...] │ │ K块2: [1.1, 1.2, ...] │
│ K块3: [4.1, 4.2, ...] │ │ K块3: [2.1, 2.2, ...] │
│ V块0: [0.1, 0.2, ...] │ │ V块0: [3.3, 3.4, ...] │
│ V块1: [0.9, 1.0, ...] │ │ V块1: [4.1, 4.2, ...] │
│ V块2: [1.7, 1.8, ...] │ │ V块2: [0.1, 0.2, ...] │
│ V块3: [2.5, 2.6, ...] │ │ V块3: [0.9, 1.0, ...] │
└─────────────────────────┘ └─────────────────────────┘
8. 关键数据结构示例
8.1 global_shuffle_tensor
[[1, 0], # 进程0与进程1配对,角色为发送方
[0, 1]] # 进程1与进程0配对,角色为接收方
8.2 aclshmem_k_cache_tensor(进程0)
# 形状: (4, 1, 4, 2)
[[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]], # Block 0
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]], # Block 1
[[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]], # Block 2
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]] # Block 3
8.3 src_block_tensor(进程0)
[0, 1] # 要传输的源块ID
8.4 dst_block_tensor(进程0)
[2, 3] # 传输到目标进程的块ID
9. 性能指标示例
| 指标 | 值 | 说明 |
|---|---|---|
| 传输块数 | 2 | 本次传输了2个块 |
| 传输token数 | 6 | 从进程0传输2个块(共8个token存储空间,实际有用6个token:1个满块4个token+1个半块2个token)到进程1 |
| 传输数据量 | 2×1×4×2×2=32字节 | K和V各16字节(float16类型) |
| 负载均衡度 | 从13:6变为7:12 | 更接近理想的9.5:9.5 |
10. 应用场景说明
通过这个具体的数据流示例,我们可以看到KVShuffle算子:
- 解决了负载不均衡问题:将数据从负载高的进程传输到负载低的进程
- 保持了数据完整性:传输前后的数据内容完全一致
- 高效利用了内存:只传输必要的块,避免了不必要的数据移动
- 支持动态批处理:可以根据实际batch大小动态调整传输策略
这种数据流转机制特别适合于分布式训练中的KV缓存管理,可以有效提高训练效率和资源利用率。