# Message template reported when SyncBatchNorm layers are found. `{syncbn_num}`
# is a format placeholder, presumably replaced with the detected layer count by
# the code that loads this file — TODO confirm against the consumer.
problem: "发现{syncbn_num} 个SyncBatchNorm,这可能会导致python任务调度缓慢,设备之间通信频繁,降低训练性能。"
# NOTE(review): looks like the SyncBatchNorm count threshold used to decide
# whether the problem is reported — confirm the exact comparison in the consumer.
max_syncbn_num: 20
# Remediation options listed with the problem message. Each list item is a
# single-key mapping: the key is the solution title, and the nested `desc`
# (plus the optional `efficient_code`) belong to that solution.
solutions:
  # Option 1: drop SyncBatchNorm by removing the conversion call.
  - 使能batchnorm:
      desc: "可以通过删除像'torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)'这样的代码,禁用SyncBatchNorm。"
  # Option 2: keep SyncBatchNorm but replace its `forward` with the snippet
  # stored in `efficient_code`.
  - 使能高效的SyncBatchNorm:
      desc: "用以下代码替换运行时环境中python脚本'torch_npu/utils/syncbatchnorm.py'的'forward'方法。"
      # Literal block scalar: the Python indentation below is part of the value
      # and must be kept so the snippet stays valid when pasted into the file
      # named in `desc`.
      efficient_code: |
        @staticmethod
        def forward(self, input_tensor, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
            input_tensor = input_tensor.contiguous()
            input_shape = input_tensor.shape
            input_tensor_ = input_tensor.reshape(input_shape[0], input_shape[1], 1, -1)
            sum_val, sum_square_val = torch.batch_norm_reduce(input_tensor_, eps)
            count = torch.full((1,),
                               input_tensor.numel() // input_tensor.size(1),
                               dtype=sum_val.dtype,
                               device=sum_val.device)
            num_channels = input_tensor.shape[1]
            combined = torch.cat([sum_val, sum_square_val, count], dim=0)
            combined_list = torch.empty((world_size,) + combined.shape, dtype=combined.dtype, device=combined.device)
            dist.all_gather_togather(combined_list, combined, process_group, async_op=False)
            sum_all, square_sum_all, count_all = torch.split(combined_list, num_channels, dim=1)
            size = count_all.view(-1).sum()
            if size == 1:
                raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
            mean, invstd = torch.batch_norm_gather_stats_update(input_tensor,
                                                                sum_all,
                                                                square_sum_all,
                                                                running_mean,
                                                                running_var,
                                                                momentum,
                                                                eps,
                                                                count_all.view(-1))
            self.save_for_backward(input_tensor, weight, mean, invstd, count_all.to(torch.int32))
            self.process_group = process_group
            out = torch.batch_norm_elemt(input_tensor, weight, bias, mean, invstd, eps)
            return out