文件最后提交记录最后更新时间
修复kvshuffle环境变量 Co-authored-by: zhangyunqi<zhangyunqi5@huawei.com> # message auto-generated for no-merge-commit merge: !185 merge fixkv into master 修复kvshuffle环境变量 Created-by: zhangyunqi Commit-by: zhangyunqi Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> rpath关闭环境下找不到so文件 ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #123--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> https://gitcode.com/cann/shmem/issues/151 ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ![image.png](https://raw.gitcode.com/user-images/assets/8546182/e95b3e37-a3fe-48f1-8869-87c81826c979/image.png 'image.png') ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [x] Bug修复 - [ ] 新特性 - [ ] 性能优化 - [ ] 文档更新 - [ ] 其他,请描述: See merge request: cann/shmem!1852 个月前
fix LICENSE Co-authored-by: jiang-xinyu3<jiangxinyu3@hisilicon.com> 5 个月前
-O0 -g编译QA补充 Co-authored-by: zhangyunqi<zhangyunqi5@huawei.com> # message auto-generated for no-merge-commit merge: !255 merge adddebugdoc into master -O0 -g编译QA补充 Created-by: zhangyunqi Commit-by: zhangyunqi Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> -O0 -g编译QA补充 kvshuffle算子文档说明补充 ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #123--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> https://gitcode.com/cann/shmem/issues/6 ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] Bug修复 - [ ] 新特性 - [ ] 性能优化 - [x] 文档更新 - [ ] 其他,请描述: See merge request: cann/shmem!2551 个月前
编译选项整改 xcce2xasc Co-authored-by: zhangyunqi<zhangyunqi5@huawei.com> # message auto-generated for no-merge-commit merge: !128 merge xcce2xasc into master 编译选项整改 xcce2xasc Created-by: zhangyunqi Commit-by: zhangyunqi Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> xcce2xasc ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #123--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> https://gitcode.com/cann/shmem/issues/95 ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ![image.png](https://raw.gitcode.com/user-images/assets/8546182/9e5c647c-f4ef-43e2-90bb-14f900cd48fa/image.png 'image.png') A5 ![image.png](https://raw.gitcode.com/user-images/assets/8546182/86160a4e-89c8-4b48-8fe0-992809bfba91/image.png 'image.png') ![image.png](https://raw.gitcode.com/user-images/assets/8546182/0ca6578d-9de2-4661-9099-641c2223a74b/image.png 'image.png') ![image.png](https://raw.gitcode.com/user-images/assets/8546182/5201bfb2-5503-451e-8028-f4abb1c6d21e/image.png 'image.png') rdma ![image.png](https://raw.gitcode.com/user-images/assets/8546182/5a1313b7-ea0b-45cf-b8a3-bda6c332ad3f/image.png 'image.png') ![image.png](https://raw.gitcode.com/user-images/assets/8546182/90473372-2bb5-421a-bed4-4c168901d845/image.png 'image.png') ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [x] Bug修复 - [ ] 新特性 - [ ] 性能优化 - [ ] 文档更新 - [ ] 其他,请描述: See merge request: cann/shmem!1281 个月前
fix LICENSE Co-authored-by: jiang-xinyu3<jiangxinyu3@hisilicon.com> 5 个月前
example下的pe名称统一与readme补充 Co-authored-by: dovahkiiin<haorunzhe@h-partners.com> # message auto-generated for no-merge-commit merge: !150 merge fix/fix_issue_115-117 into master example下的pe名称统一与readme补充 Created-by: dovahkiiin Commit-by: dovahkiiin Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #123--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] Bug修复 - [ ] 新特性 - [ ] 性能优化 - [x] 文档更新 - [ ] 其他,请描述: See merge request: cann/shmem!1503 个月前
README.md

使用方式

  1. 编译项目
    shmem/ 根目录下执行编译脚本:

    bash scripts/build.sh -examples
    
  2. 运行KV_Shuffle示例程序
    进入示例目录并执行运行脚本:

    cd examples/kv_shuffle
    bash scripts/run.sh [pe_size]
    
    • 参数说明
      • pe_size:指定算子运行的pe个数。
      • 示例:使用第0和第1个NPU设备运行2卡kv_shuffle示例:
        bash scripts/run.sh 2
        

算子介绍

KV Shuffle 算子核心功能是实现 KV Cache(键值缓存)的跨设备 / 跨 PE 数据重排与远程拷贝,适配大模型训练 / 推理中 KV Cache 的分布式调度需求。

在大模型分布式训练 / 推理场景中,KV Cache 会按 Block(块)管理,不同计算 PE 之间需要根据调度策略(如 shuffle table)对 KV Cache 的 Block 进行重排、迁移,本算子为该场景提供高效 KV Block 跨 PE 拷贝与重映射,相比传统主机侧调度,大幅降低 KV Cache 迁移的延迟和带宽开销。

cpp接口

class KVShuffleOps {
public:
    // 默认构造函数
    KVShuffleOps(uint32_t block_dims, void* stream);

    ~KVShuffleOps();

    // 接受 Tensor 的函数
    void compute(
        uint8_t* k_cache,
        uint8_t* v_cache,
        uint8_t* global_shuffle_table,
        uint8_t* src_block_table,
        uint8_t* dst_block_table,
        int64_t block_nums,
        int64_t kv_head_num, int64_t page_size, int64_t head_dim);
private:
    void* sync_ptr_;
    int32_t count_;
    uint32_t block_dims_;
    void* stream_;
    uint64_t fftsAddr_;
};

接口参数说明

参数名 输入/输出 描述
uint8_t* k_cache 输入/输出 指向键缓存全局内存的指针,存储需要进行shuffle操作的键数据块,连续内存,按块组织,每个块的大小为 kv_head_num * page_size * head_dim * sizeof(data_type)
uint8_t* v_cache 输入/输出 指向值缓存全局内存的指针,存储需要进行shuffle操作的值数据块,与k_cache相同的连续内存布局
uint8_t* global_shuffle_table 输入 全局shuffle表,存储每个进程的配对信息和操作类型,实际存储int64_t类型数据,内存布局 :数组结构,每个PE对应2个int64_t条目:[pair_rank_0, operation_0, pair_rank_1, operation_1, ..., pair_rank_n, operation_n] 数据限制 :大小必须为 2 * n_pes * sizeof(int64_t) ,其中n_pes是进程总数,operation只能是0或1(0表示发送,1表示接收)配对关系必须是双向的(A的pair_rank是B,则B的pair_rank必须是A)
uint8_t* src_block_table 输入 源块索引表,指示每个shuffle操作的源块ID,实际存储int64_t类型数据,一维数组,长度为block_nums,每个元素的值必须是有效的块ID(0 ≤ src_block_id < block_nums)
uint8_t* dst_block_table 输入 目标块索引表,指示每个shuffle操作的目标块ID,实际存储int64_t类型数据,每个元素的值必须是有效的块ID(0 ≤ dst_block_id < block_nums)
int64_t block_nums 输入 需要进行shuffle操作的块数量
int64_t kv_head_num 输入 键值数据的头数量
int64_t page_size 输入 KV缓存中每个页面的大小
int64_t head_dim 输入 每个头的维度

torch接口

# 创建算子
kv_shuffle = torch.classes.ShmemOps.KVShuffle()
# 计算
kv_shuffle.compute(global_shuffle_tensor, aclshmem_k_cache_tensor,
                                aclshmem_v_cache_tensor, src_block_tensor, dst_block_tensor)

接口参数说明

1. global_shuffle_tensor
  • 含义 :全局shuffle表,存储每个进程的配对信息和操作类型
  • 数据类型 :PyTorch张量, torch.int64 类型
  • 形状 :二维数组,形状为 [n_pes, 2] ,其中n_pes是进程总数
  • 内容结构 :
    [
      [pair_rank_0, operation_0],
      [pair_rank_1, operation_1],
      ...,
      [pair_rank_n, operation_n]
    ]
    
  • 数据限制 :
    • 必须在NPU设备上(使用 .npu() 方法转换)
    • operation 只能是0或1(0表示发送,1表示接收)
    • 配对关系必须是双向的(A的pair_rank是B,则B的pair_rank必须是A)
    • 数据类型必须是 int64
2. aclshmem_k_cache_tensor
  • 含义 :指向键缓存全局内存的张量,存储需要进行shuffle操作的键数据块
  • 数据类型 :PyTorch张量,torch.int8 类型
  • 形状 :四维数组,形状为 [block_nums, kv_head_num, page_size, head_dim]
  • 数据限制 :
    • 通过 aclshmem_common.malloc_like() 创建的ACL SHMEM共享内存张量
    • 维度顺序必须严格为 [块数, 头数, 页大小, 头维度]
    • 块数必须与 block_nums 参数匹配
3. aclshmem_v_cache_tensor
  • 含义 :指向值缓存全局内存的张量,存储需要进行shuffle操作的值数据块
  • 数据类型 :PyTorch张量,与 aclshmem_k_cache_tensor 相同
  • 形状 :四维数组,形状与 aclshmem_k_cache_tensor 相同: [block_nums, kv_head_num, page_size, head_dim]
  • 数据限制 :
    • 通过 aclshmem_common.malloc_like() 创建的ACL SHMEM共享内存张量
    • 在NPU设备上
    • 数据类型必须与 aclshmem_k_cache_tensor 一致
    • 形状必须与 aclshmem_k_cache_tensor 完全匹配
4. src_block_tensor
  • 含义 :源块索引表,指示每个shuffle操作的源块ID
  • 数据类型 :PyTorch张量, torch.int64 类型
  • 形状 :一维数组,长度为 block_nums
  • 数据限制 :
    • 必须在NPU设备上
    • 数据类型必须是 int64
    • 每个元素的值必须是有效的块ID(0 ≤ src_block_id < block_nums)
    • 数组长度必须与当前进程需要处理的块数匹配
5. dst_block_tensor
  • 含义 :目标块索引表,指示每个shuffle操作的目标块ID
  • 数据类型 :PyTorch张量, torch.int64 类型
  • 形状 :一维数组,长度为 block_nums
  • 数据限制 :
    • 必须在NPU设备上
    • 数据类型必须是 int64
    • 每个元素的值必须是有效的块ID(0 ≤ dst_block_id < block_nums)
    • 数组长度必须与 src_block_tensor 完全匹配
关键参数推导

在PyTorch扩展的C++实现中,以下参数从输入张量中推导出来:

  • block_nums :从 dst_block_tensor.size(0) 获取,表示需要处理的块数量
  • kv_head_num :从 KeyCache.size(1) 获取,表示键值数据的头数量
  • page_size :从 KeyCache.size(2) 获取,表示每个页面的token数量
  • head_dim :从 KeyCache.size(3) 获取,表示每个头的维度

KVShuffle算子数据流转效果说明

本文档通过具体示例展示KVShuffle算子的数据流转过程,包括初始数据状态、传输策略计算、块表生成和最终数据变换结果。

1. 示例配置

为了清晰展示数据流转,我们使用以下简化配置:

参数 说明
RANKS 2 进程总数
INIT_BATCH 2 每个进程的初始批次数
PAGE_SIZE 4 页面大小(每个页面包含的token数量)
KV_HEAD_NUM 1 头数量(简化为1便于展示)
HEAD_DIM 2 头维度(简化为2便于展示)
MAX_SEQLEN 8 最大序列长度

2. Batch Token与KV缓存的关系

在理解数据流转之前,需要先明确Batch Token与KV缓存之间的核心关系:

2.1 基本概念
概念 含义
Batch Token 一个batch中包含的token数量,即序列长度(seqlen)
KV缓存 存储键值对的缓存结构,用于注意力机制的高效计算
Page KV缓存的基本存储单位,每个page包含固定数量的token(PAGE_SIZE)
Block 一个batch在KV缓存中占用的连续pages集合
2.2 关系公式
  1. 块数计算:一个batch需要的块数 = batch token数 ÷ 页面大小 + 1(向上取整)

    block_num = seqlen // PAGE_SIZE + 1
    
  2. KV缓存总大小

    total_cache_size = max_block_nums × kv_head_num × page_size × head_dim × data_type_size
    
2.3 映射关系示例

以进程0的Batch 0为例:

  • Batch Token数:6
  • 页面大小(PAGE_SIZE):4
  • 需要的块数:6 ÷ 4 + 1 = 2块
  • 这2块对应KV缓存中的Block 0和Block 1
2.4 直观理解
Batch Token (seqlen=6) → 映射到 → KV缓存的2个块
┌─────────────────┐     ┌────────────────────────────────────────────┐
│ Token 0-5       │     │ Block 0 (PAGE_SIZE=4): 存储Token 0-3       │
│ (6个token)      │     │ Block 1 (PAGE_SIZE=4): 存储Token 4-5       │
└─────────────────┘     └────────────────────────────────────────────┘

这种映射关系确保了即使不同batch的token长度不同,也能在KV缓存中高效存储和访问。

3. 初始数据状态

3.1 进程0的初始数据

KV缓存形状:(block_num, kv_head_num, page_size, head_dim) = (4, 1, 4, 2)

K缓存数据

# Block 0
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]

# Block 1
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]

# Block 2
[[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]]

# Block 3
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]

V缓存数据

# Block 0
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]

# Block 1
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]

# Block 2
[[[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]]]

# Block 3
[[[2.5, 2.6], [2.7, 2.8], [2.9, 3.0], [3.1, 3.2]]]
3.2 进程1的初始数据

KV缓存形状:(block_num, kv_head_num, page_size, head_dim) = (2, 1, 4, 2)

K缓存数据

# Block 0
[[[5.1, 5.2], [5.3, 5.4], [5.5, 5.6], [5.7, 5.8]]]

# Block 1
[[[6.1, 6.2], [6.3, 6.4], [6.5, 6.6], [6.7, 6.8]]]

V缓存数据

# Block 0
[[[3.3, 3.4], [3.5, 3.6], [3.7, 3.8], [3.9, 4.0]]]

# Block 1
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]

3. 负载均衡与传输策略计算

3.1 Batch token长度

假设生成的batch token长度如下:

进程 Batch 0 Batch 1 总token数
0 6 7 13
1 3 3 6
3.2 块数计算

每个batch的块数计算公式:block_num = seqlen // PAGE_SIZE + 1

进程 Batch 0 Batch 1 总块数
0 2 (6//4+1) 2 (7//4+1) 4
1 1 (3//4+1) 1 (3//4+1) 2

注意:由于 token 是按块管理的,最后一个块可能没有填满(例如 Batch 0 的 Block 1 只有 2 个 token)。

3.3 Batch块映射

进程0的batch_blocks_list

[  # batch_blocks_list[0]
    (0, [0, 1]),  # Batch 0使用块0和块1
    (1, [2, 3])   # Batch 1使用块2和块3
]

进程1的batch_blocks_list

[  # batch_blocks_list[1]
    (0, [0]),     # Batch 0使用块0
    (1, [1])      # Batch 1使用块1
]
3.4 负载均衡计算
  • 平均token数:(13 + 6) / 2 = 9.5
  • 进程0需要传输的token数:13 - 9.5 = 3.5(取整为3)
  • 进程1需要接收的token数:9.5 - 6 = 3.5(取整为3)
3.5 传输batch选择

由于token是按块管理的,选择传输Batch 0(6个token)来尽可能接近理想负载:

transfer_tokens_list

[  # transfer_tokens_list
    (6, [0]),  # 进程0传输Batch 0的6个token
    (-1, [])   # 进程1不需要传输
]

4. 块表生成

4.1 src_block_table

进程0需要传输Batch 0对应的块ID [0, 1]:

src_block_table

[  # src_block_table
    [0, 1],  # 进程0的源块表
    []       # 进程1的源块表
]
4.2 dst_block_table

进程1当前使用了2个块(0和1),因此目标块ID从2开始:

dst_block_table

[  # dst_block_table
    [2, 3],  # 进程0的目标块表(传输到进程1的块2和3)
    []       # 进程1的目标块表
]
4.3 配对关系

pair_list

[  # pair_list
    [1, 0],  # 进程0与进程1配对,角色为发送方(0)
    [0, 1]   # 进程1与进程0配对,角色为接收方(1)
]

5. KVShuffle数据变换

5.1 数据传输过程
源进程 源块ID 目标进程 目标块ID 传输的数据
0 0 1 2 进程0的K块0、V块0
0 1 1 3 进程0的K块1、V块1
5.2 关于源进程块的清理说明

为什么进程0的块没有被清理?

  • KVShuffle算子默认执行的是数据复制而非数据移动
  • 这是因为在分布式训练场景中,源进程可能仍然需要这些数据用于后续的计算或其他batch处理
  • 如果应用层确实需要清理源进程的数据,可以在KVShuffle操作完成后,手动释放或标记这些块为可用
  • 清理操作通常由应用层根据具体业务逻辑决定,而不是由KVShuffle算子自动执行
5.3 变换后数据状态
进程0的最终数据(不变)

K缓存

# Block 0
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]

# Block 1
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]

# Block 2
[[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]]

# Block 3
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]

V缓存

# Block 0
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]

# Block 1
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]

# Block 2
[[[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]]]

# Block 3
[[[2.5, 2.6], [2.7, 2.8], [2.9, 3.0], [3.1, 3.2]]]
进程1的最终数据(新增块2和3)

K缓存

# Block 0(原有)
[[[5.1, 5.2], [5.3, 5.4], [5.5, 5.6], [5.7, 5.8]]]

# Block 1(原有)
[[[6.1, 6.2], [6.3, 6.4], [6.5, 6.6], [6.7, 6.8]]]

# Block 2(新增,来自进程0的块0)
[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]]

# Block 3(新增,来自进程0的块1)
[[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]]

V缓存

# Block 0(原有)
[[[3.3, 3.4], [3.5, 3.6], [3.7, 3.8], [3.9, 4.0]]]

# Block 1(原有)
[[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]

# Block 2(新增,来自进程0的块0)
[[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]]

# Block 3(新增,来自进程0的块1)
[[[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]]

6. 数据验证

6.1 传输前后数据一致性
  • 进程1的K块2与进程0的K块0完全相同
  • 进程1的K块3与进程0的K块1完全相同
  • 进程1的V块2与进程0的V块0完全相同
  • 进程1的V块3与进程0的V块1完全相同
6.2 负载均衡效果

传输前后的token分布:

进程 传输前token数 传输后token数 均衡度
0 13 13 - 6 = 7 更接近平均值9.5
1 6 6 + 6 = 12 更接近平均值9.5

7. 数据流转总结

┌─────────────────────────┐     ┌────────────────────────┐
│        进程0初始数据     │     │        进程1初始数据    │
│  K块0: [1.1, 1.2, ...]  │     │  K块0: [5.1, 5.2, ...]  │
│  K块1: [2.1, 2.2, ...]  │     │  K块1: [6.1, 6.2, ...]  │
│  K块2: [3.1, 3.2, ...]  │     │  V块0: [3.3, 3.4, ...]  │
│  K块3: [4.1, 4.2, ...]  │     │  V块1: [4.1, 4.2, ...]  │
│  V块0: [0.1, 0.2, ...]  │     └─────────────────────────┘
│  V块1: [0.9, 1.0, ...]  │               ▲
│  V块2: [1.7, 1.8, ...]  │               │
│  V块3: [2.5, 2.6, ...]  │               │
└────────────┬────────────┘               │
             │                            │
             │ 传输Batch 0的块0和块1       │
             │                            │
             ▼                            │
┌─────────────────────────┐     ┌─────────┴─────────┐
│      生成块表和策略      │     │       数据传输    │
│  src_block_table: [0, 1]│───▶│ K块0 → 进程1的K块2 │
│  dst_block_table: [2, 3]│     │ K块1 → 进程1的K块3 │
│  pair_list: [1, 0]      │     │ V块0 → 进程1的V块2 │
└─────────────────────────┘     │ V块1 → 进程1的V块3 │
                                └───────────────────┘
                                          │
                                          ▼
┌─────────────────────────┐     ┌────────────────────────┐
│      进程0最终数据       │     │      进程1最终数据      │
│  K块0: [1.1, 1.2, ...]  │     │  K块0: [5.1, 5.2, ...]  │
│  K块1: [2.1, 2.2, ...]  │     │  K块1: [6.1, 6.2, ...]  │
│  K块2: [3.1, 3.2, ...]  │     │  K块2: [1.1, 1.2, ...]  │
│  K块3: [4.1, 4.2, ...]  │     │  K块3: [2.1, 2.2, ...]  │
│  V块0: [0.1, 0.2, ...]  │     │  V块0: [3.3, 3.4, ...]  │
│  V块1: [0.9, 1.0, ...]  │     │  V块1: [4.1, 4.2, ...]  │
│  V块2: [1.7, 1.8, ...]  │     │  V块2: [0.1, 0.2, ...]  │
│  V块3: [2.5, 2.6, ...]  │     │  V块3: [0.9, 1.0, ...]  │
└─────────────────────────┘     └─────────────────────────┘

8. 关键数据结构示例

8.1 global_shuffle_tensor
[[1, 0],  # 进程0与进程1配对,角色为发送方
 [0, 1]]  # 进程1与进程0配对,角色为接收方
8.2 aclshmem_k_cache_tensor(进程0)
# 形状: (4, 1, 4, 2)
[[[[1.1, 1.2], [1.3, 1.4], [1.5, 1.6], [1.7, 1.8]]],  # Block 0
 [[[2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8]]],  # Block 1
 [[[3.1, 3.2], [3.3, 3.4], [3.5, 3.6], [3.7, 3.8]]],  # Block 2
 [[[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]]]]  # Block 3
8.3 src_block_tensor(进程0)
[0, 1]  # 要传输的源块ID
8.4 dst_block_tensor(进程0)
[2, 3]  # 传输到目标进程的块ID

9. 性能指标示例

指标 说明
传输块数 2 本次传输了2个块
传输token数 6 从进程0传输2个块(共8个token存储空间,实际有用6个token:1个满块4个token+1个半块2个token)到进程1
传输数据量 2×1×4×2×2=32字节 K和V各16字节(float16类型)
负载均衡度 从13:6变为7:12 更接近理想的9.5:9.5

10. 应用场景说明

通过这个具体的数据流示例,我们可以看到KVShuffle算子:

  1. 解决了负载不均衡问题:将数据从负载高的进程传输到负载低的进程
  2. 保持了数据完整性:传输前后的数据内容完全一致
  3. 高效利用了内存:只传输必要的块,避免了不必要的数据移动
  4. 支持动态批处理:可以根据实际batch大小动态调整传输策略

这种数据流转机制特别适合于分布式训练中的KV缓存管理,可以有效提高训练效率和资源利用率。