多卡并行
MindIE SD 提供多种并行策略来解决单卡显存不足和推理速度瓶颈的问题,不同策略从不同维度对计算和显存进行拆分:
- 张量并行(Tensor Parallel):沿权重矩阵的行或列切分,将矩阵计算分布到多卡,适合隐藏层维度较大的模型。
- 环状序列并行(Ring Sequence Parallel):沿序列维度切分 Q,以环状通信在设备间传递 KV,通过计算掩盖通信开销。
- Ulysses 序列并行(Ulysses Sequence Parallel):沿序列维度切分输入,通过 AlltoAll 在注意力头维度重组,各卡并行计算不同注意力头。
- CFG 并行(CFG Parallel):将正样本和负样本推理分发到不同设备并行执行,适合使用 Classifier-Free Guidance 的扩散模型。
各策略可以独立使用,也可以组合叠加,具体支持情况请参见 supported_matrix.md。
推荐方案:
- 张量并行(TP):可以有效降低显存,但通信开销较大,不推荐优先使用。
- Ulysses 序列并行(USP):通信开销小,推荐优先使用。约束:Ulysses 的并行度需要能被 FA 的 head num 整除。
- 环状序列并行(RSP):可以配合 Ulysses 使用,补充 Ulysses 无法被 head num 整除的部分。
- CFG 并行:通信开销小,当模型的 CFG 大于 1 时,推荐使用。
Tensor Parallel
随着模型规模的扩大,单卡显存容量无法满足大模型的需求。张量并行会将模型的张量计算(如矩阵乘法、卷积等)分散到多个设备上并行执行 ,从而降低单个设备的内存和计算负载。本章节以一次矩阵乘法为例,介绍张量并行的原理。
假如输入数据为X,参数为W,X的维度 = (b, s, h),W的维度 = (h, h'),一次矩阵乘法如下图所示。其中:
- b:batch_size,表示批次大小。
- s:sequence_length,表示输入序列的长度。
- h:hidden_size,表示每个token向量的维度。
- h':参数W的hidden_size。

优化方法分为以下两种:
-
按行切分:按照权重W的行切分,以N=2为例,将矩阵按照虚线切分。

下图展示了切分后的结果,从一个矩阵乘转换为两个矩阵乘,分别在不同的NPU上运算,通过卡间通信将各个结果进行加法运算得到完整结果。

-
按列切分:按照权重W的列切分,以N=2为例,将矩阵按照虚线切分。

下图展示了切分后的结果,从一个矩阵乘转换为两个矩阵乘,分别在不同的NPU上运算,通过卡间通信将各个结果进行拼接得到完整结果。

代码示例
以下示例展示了分布式初始化及张量并行的基本用法:
import os
import torch
import torch.distributed as dist
import torch_npu
# 1. 初始化分布式环境
dist.init_process_group(backend="hccl")
torch.npu.set_device(f"npu:{os.environ['LOCAL_RANK']}")
# 2. 定义原始线性层
linear = torch.nn.Linear(4096, 4096).npu()
x = torch.randn(1, 256, 4096, device="npu")
# 3. 按列切分:每个 rank 持有 W 的一半列
# 前向后通过 all-reduce 通信合并结果
world_size = dist.get_world_size()
rank = dist.get_rank()
with torch.no_grad():
# 切分权重:每个 rank 持有 W[:, h//world_size * rank : h//world_size * (rank+1)]
w_chunk = linear.weight.data.chunk(world_size, dim=0)[rank]
# 本地矩阵乘
local_out = x @ w_chunk.T
# all-reduce 合并各 rank 结果
dist.all_reduce(local_out)
print(f"Rank {rank} output shape: {local_out.shape}")
通信方式
列切分时各设备独立计算本地矩阵乘后通过 all-reduce 合并结果;行切分时各设备计算完整结果的分片,通过 all-gather 拼接完整输出。通信量与 hidden_size 成正比,设备间带宽充足时通信开销占比随模型增大而降低。
适用场景
适合隐藏层维度(hidden_size)较大的模型,当单卡显存不足以容纳完整权重矩阵时尤为有效。TP 依赖高带宽卡间通信(如 HCCS),建议仅在单机多卡范围内使用,TP degree 不应超过单机 NPU 数量。
Ring Sequence Parallel
原理
将 Q 切分到各设备,计算时各设备计算完当前 KV 对后,将持有的 KV 对发送给下一设备,并继续接收前一设备的 KV 对,形成一个环状的通信结构。当卡间通信时间 ≤ 计算时间时,通信开销可被计算掩盖。

通信方式
采用 P2P(点对点)通信。设备 i 完成当前步的注意力计算后,将自己的 KV 发送给设备 i+1,同时从设备 i-1 接收新的 KV。经过 N 轮通信后,所有设备完成全部序列位置的注意力计算。当计算耗时大于通信耗时(即序列较长、head_dim 较大)时,通信开销可被计算完全掩盖。
适用场景
适用于序列长度远大于 head_dim 的长序列场景。当设备间 P2P 带宽充裕(如同机 NPU)时效果最佳。不适用于短序列场景,此时通信开销占比过高。
使用示例
import torch
import torch.distributed as dist
dist.init_process_group(backend="hccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
batch, seqlen, head, dim = 1, 4096, 8, 128
seqlen_chunk = seqlen // world_size
# 各设备持有自己的 Q/K/V 分片
q_chunk = torch.randn(batch, seqlen_chunk, head, dim).npu()
k_chunk = torch.randn(batch, seqlen_chunk, head, dim).npu()
v_chunk = torch.randn(batch, seqlen_chunk, head, dim).npu()
def local_attn(q, k, v):
score = (q @ k.transpose(-2, -1)) / (dim ** 0.5)
return score.softmax(dim=-1) @ v
# 第一轮:计算自身的 KV
out = local_attn(q_chunk, k_chunk, v_chunk)
# 后续轮次:环形传递 KV
for step in range(1, world_size):
send_rank = (rank + 1) % world_size
recv_rank = (rank - 1 + world_size) % world_size
k_recv = torch.empty_like(k_chunk)
v_recv = torch.empty_like(v_chunk)
dist.send_recv(k_chunk, k_recv, send=send_rank, recv=recv_rank)
dist.send_recv(v_chunk, v_recv, send=send_rank, recv=recv_rank)
k_chunk, v_chunk = k_recv, v_recv
out += local_attn(q_chunk, k_chunk, v_chunk)
Ulysses Sequence Parallel
原理
把每个样本在序列维度上进行分割,分配给不同设备。在进行注意力计算之前,对分割后的Q、K和V进行AlltoAll。各设备和其他所有设备交换信息,每个设备都能收到注意力头的非重叠子集。各设备并行计算不同注意力头,计算后再通过AlltoAll收集计算结果。

通信方式
核心采用 AlltoAll 集体通信。每个设备在注意力计算前将自己的序列分块发送给所有其他设备,同时接收其他设备的序列分块,在注意力头维度完成数据重组。计算完成后再次通过 AlltoAll 将结果按序列维度收集回来。当序列长度和设备数同比例增加时,单设备通信量保持恒定(理论分析见 DeepSpeed Ulysses 论文)。
适用场景
适合注意力头数较多、AlltoAll 带宽充裕的场景。相比 RSP,Ulysses 在短序列多头场景下效率更高,特别适用于序列长度与 hidden_size 均较大的情况。
-
未使用Ulysses Sequence Parallel样例:
import torch import torch_npu from mindiesd import attention_forward torch.npu.set_device(0) batch, seqlen, hiddensize = 1, 4096, 512 head = 8 x = torch.randn(batch, seqlen, hiddensize, dtype=torch.float16).npu() x = x.reshape(batch, seqlen, head, -1) out = attention_forward(x, x, x, opt_mode="manual", op_type="prompt_flash_attn", layout="BSND") x = out.reshape(batch, seqlen, hiddensize) -
使用Ulysses Sequence Parallel样例:
import os import torch import torch.distributed as dist import torch_npu from mindiesd import attention_forward batch, seqlen, hiddensize = 1, 4096, 512 head = 8 x = torch.randn(batch, seqlen, hiddensize, dtype=torch.float16).npu() def init_distributed( world_size: int = -1, rank: int = -1, distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "hccl" ): dist.init_process_group( backend=backend, init_method=distributed_init_method, world_size=world_size, rank=rank, ) torch.npu.set_device(f"npu:{os.environ['LOCAL_RANK']}") # 1、初始化分布式环境 world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["LOCAL_RANK"]) init_distributed(world_size, rank) # 2、对seqlen维度按照world_size进行切分 x = torch.chunk(x, world_size, dim=1)[rank] # 序列切分 seqlen_chunk = x.shape[1] x = x.reshape(batch, seqlen_chunk, head, -1) # 3、调用all_to_all使能ulysess并行 in_list = [t.contiguous() for t in torch.tensor_split(x, world_size, 2)] output_list = [torch.empty_like(in_list[0]) for _ in range(world_size)] dist.all_to_all(output_list, in_list) x = torch.cat(output_list, dim=1).contiguous() att_out = attention_forward(x, x, x, opt_mode="manual", op_type="prompt_flash_attn", layout="BSND") in_list = [t.contiguous() for t in torch.tensor_split(att_out, world_size, 1)] output_list = [torch.empty_like(in_list[0]) for _ in range(world_size)] dist.all_to_all(output_list, in_list) x = torch.cat(output_list, dim=2).contiguous() x = x.reshape(batch, seqlen_chunk, hiddensize) # 4、对seqlen维度进行all_gather操作 output_list = [torch.empty_like(x) for _ in range(world_size)] dist.all_gather(output_list, x) x = torch.cat(output_list, dim=1)
FA_Power_Cap 技术
FA_Power_Cap 技术可以通过 --comm_type 0/1/2 在 baseline、插入通信和块级注意力三种路径之间切换。详细手动接入步骤请参见 FA_Power_Cap 技术。
CFG Parallel
原理
对于一个带噪声的图像和文本提示词,模型需要执行两次推理,分别计算正样本和负样本,该计算过程为串行过程,导致每个去噪步骤都需要两次前向传播,增加了推理时间。CFG 并行可以将正样本和负样本分别在不同的设备上计算,将两次串行计算合并为一次并行计算,显著提升推理速度。

通信方式
正负样本计算完全独立,各设备无需中间通信。计算完成后通过 all-gather 收集两个结果,或者在各自设备上直接使用自身计算结果。通信量极小,近似为零开销并行。
适用场景
适用于使用 CFG(guidance_scale > 1)的扩散模型推理场景,且至少拥有 2 卡富余设备。设备越多,加速越接近 2×。如果设备紧张,优先将资源分配给 TP 或序列并行。
使用示例
import os
import torch
import torch.distributed as dist
dist.init_process_group(backend="hccl")
torch.npu.set_device(f"npu:{os.environ['LOCAL_RANK']}")
rank = dist.get_rank()
guidance_scale = 7.5
# rank 0 算负样本(unconditioned),rank 1 算正样本(conditioned)
if rank == 0:
noise_pred_uncond = model(latent, timestep, uncond_embed)
output = noise_pred_uncond
elif rank == 1:
noise_pred_cond = model(latent, timestep, cond_embed)
output = noise_pred_cond
# all-gather 交换结果
output_list = [torch.empty_like(output) for _ in range(world_size)]
dist.all_gather(output_list, output)
# CFG 融合
noise_pred = output_list[0] + guidance_scale * (output_list[1] - output_list[0])
补充内容 —— CFG 融合
CFG 融合是另一种优化思路:不在设备间并行,而是在单设备内将正样本和负样本在 batch 维度拼接后送入模型,使一次前向计算同时产出两个结果,算子调用次数减半。
与 CFG 并行相比,CFG 融合不消耗额外的设备资源,适合设备数有限但希望降低单次推理延迟的场景。两者可根据硬件条件选择使用。
