非对齐Ring长序列并行
背景与挑战
随着生成式AI和科研模型领域的发展,长序列训练变得越来越重要。然而,传统的Ring序列并行设计要求序列长度(sequence length)必须能够被长序列并行大小(Context Parallel size, CP size)整除。这在处理动态或不规则输入时带来了限制,特别是在多模态应用中,输入数据的序列长度可能无法预测且经常变化。因此,需要一种机制来支持这些非对齐情况下的操作,以适应更广泛的应用场景。
解决方案
为了解决传统Ring序列并行设计在处理非对齐序列长度时的局限性,“非对齐 Ring”机制通过建立形状协商协议,通信前先交换有效长度信息,并通过get_unaligned_cp_shapes隔离变化,用以获取当前rank长度和目标rank长度,在分块计算和通信时传递非均匀的序列,实现非均匀切分的RingAttention计算。
使用场景
“非对齐 Ring”功能适用于以下几种典型场景:
- 多模态学习:当处理图像、视频、文本等多种类型的数据时,由于不同类型数据的序列长度差异较大,难以统一到固定的CP size。
- 实时数据分析:在处理流数据时,数据到达的时间不确定,导致每次处理的序列长度也可能不同。
- 个性化推荐系统:用户行为数据的序列长度通常各不相同,这种情况下也需要支持非对齐的操作。
使用方法
为了使用“非对齐 Ring”功能,用户需要在调用ringattn_context_parallel接口时,如下面示例代码所示,传入shapes参数。
# 示例代码
output = ringattn_context_parallel(q, k, v, head_num, cp_para, softmax_scale, attn_mask, dropout_p, shapes=shapes)
get_unaligned_cp_shapes是非对齐Ring序列并行中的重要函数,shapes参数的主要作用是获取序列的非对齐切分后的子序列长度信息,函数中通过 shapes[block_id] 和 shapes[next_block_id] 的方式访问元素,所以shapes参数可以是任何支持通过 [] 操作符进行索引访问的数据结构,包括但不限于:
- 列表(list):如 [100, 100, 20]
- 元组(tuple):如 (100, 100, 20)
- 字典(dict):如 {0: 100, 1: 100, 2: 20}
get_unaligned_cp_shapes函数最终会返回一个包含两个元素的列表:[shapes[block_id], shapes[next_block_id]],这两个元素分别对应 block_id 和 next_block_id 索引处的值。
def get_unaligned_cp_shapes(shapes, block_id, next_block_id):
if shapes is None:
return None
unaligned_cp_shapes = [shapes[block_id], shapes[next_block_id]]
return unaligned_cp_shapes
使用效果
通过引入“非对齐 Ring”,系统提升了对不同输入长度的适应能力。这不仅解决了传统Ring序列并行在处理动态或不规则输入序列时遇到的问题,而且保持了良好的扩展能力。
注意事项
- 非对齐Ring长序列并行当前仅支持--attention-mask-type为general的场景。