---
# Message and threshold for the SyncBatchNorm check (remediation hints are
# listed under 'solutions').
# NOTE(review): '{syncbn_num}' looks like a str.format placeholder filled in
# by the consuming checker -- keep the placeholder name unchanged.
problem: >-
  Found {syncbn_num} SyncBatchNorm, which can lead to slow python task
  dispatch and frequent communication between devices, ultimately reducing
  training efficiency.
# NOTE(review): presumably the SyncBatchNorm count above which 'problem' is
# reported; confirm the comparison used by the consuming checker.
max_syncbn_num: 20
# Remediation hints. Each entry maps a solution title to its details:
# 'desc' text, plus an 'efficient_code' snippet for the second entry.
solutions:
  - enable batchnorm:
      desc: >-
        disable SyncBatchNorm by removing code like
        'torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)' if possible.
  - enable efficient SyncBatchNorm:
      desc: >-
        replace the 'forward' method of the python script
        'torch_npu/utils/syncbatchnorm.py' in your runtime environment.
      # Literal block scalar ('|'): newlines and relative indentation are
      # kept and the common leading indent is stripped, so the
      # '@staticmethod' / 'def' lines come out at column 0. Re-indent the
      # snippet to method level when pasting it into the class. Any line
      # added inside this block (including '#' lines) becomes part of the
      # snippet text, so keep comments outside it.
      # NOTE(review): dist.all_gather_togather, torch.batch_norm_reduce and
      # torch.batch_norm_gather_stats_update are not stock PyTorch APIs --
      # presumably provided by torch_npu; verify against the installed version.
      efficient_code: |
         @staticmethod
         def forward(self, input_tensor, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
             input_tensor = input_tensor.contiguous()
             input_shape = input_tensor.shape
             input_tensor_ = input_tensor.reshape(input_shape[0], input_shape[1], 1, -1)
             sum_val, sum_square_val = torch.batch_norm_reduce(input_tensor_, eps)

             count = torch.full((1,),
                                input_tensor.numel() // input_tensor.size(1),
                                dtype=sum_val.dtype,
                                device=sum_val.device)

             num_channels = input_tensor.shape[1]
             combined = torch.cat([sum_val, sum_square_val, count], dim=0)
             combined_list = torch.empty((world_size,) + combined.shape, dtype=combined.dtype, device=combined.device)
             dist.all_gather_togather(combined_list, combined, process_group, async_op=False)
             sum_all, square_sum_all, count_all = torch.split(combined_list, num_channels, dim=1)
             size = count_all.view(-1).sum()
             if size == 1:
                 raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

             mean, invstd = torch.batch_norm_gather_stats_update(input_tensor,
                                                                 sum_all,
                                                                 square_sum_all,
                                                                 running_mean,
                                                                 running_var,
                                                                 momentum,
                                                                 eps,
                                                                 count_all.view(-1))
             self.save_for_backward(input_tensor, weight, mean, invstd, count_all.to(torch.int32))
             self.process_group = process_group
             out = torch.batch_norm_elemt(input_tensor, weight, bias, mean, invstd, eps)
             return out