Quick Start of msProbe in the PyTorch Scenario

Overview

This document describes how to quickly get started with the precision debugging tool MindStudio Probe (msProbe) during the training development process in the PyTorch scenario.

For foundation models developed based on Ascend or migrated from GPU to Ascend NPU, training issues such as precision overflow/underflow, loss divergence, or non-convergent loss may arise. Since metrics such as the training loss cannot accurately locate the failed module, msProbe is recommended for rapid fault demarcation.

Usage Process

When using msProbe for model precision debugging, perform the following operations:

  1. Configuration check before training

    Identify the configuration differences that affect the precision in two environments.

  2. Training status monitoring

    Monitor exceptions in computation, communication, and optimizer operations during training.

  3. Precision data collection

    Collect the forward and backward input and output data at the API or module level during training.

  4. Precision pre-check

    Scan API data to identify APIs with precision issues.

  5. Precision comparison

    Compare the API data on NPU with that in the benchmark environment to quickly locate precision issues.

This Quick Start guide focuses on rapid onboarding for precision data collection and precision comparison. For usage details of other tool functions, please refer to the relevant documentation.

Environment Setup

  1. Prepare a training server equipped with Ascend NPUs (such as Atlas A2 training servers) and install the NPU driver and firmware.

  2. Install the CANN Toolkit and OPS (operator package) of the matching version and configure CANN environment variables. The following uses CANN 8.5.0 as an example. For details, see CANN Software Installation Guide.

  3. Install the framework.

    The following example uses PyTorch 2.9.0, torchvision 0.24.0, and Python 3.12 on an AArch64-based system in the PyTorch training scenario. For details, see "Installing PyTorch > Method 1: Installation via a Binary Package" in Ascend Extension for PyTorch Installation Guide.

  4. Install msProbe by referring to msProbe Installation Guide.

    pip install mindstudio-probe --pre
    

Precision Data Collection

In this example, the ResNet-50 model is trained on dummy (randomly generated) data.

Prerequisites

  • Complete procedures listed in Environment Setup.

Data Collection

  1. Prepare a training script.

    pytorch_main.py is used as an example. In the GPU and Ascend NPU environments, you can directly copy the sample code from PyTorch Precision Data Collection Code Sample. When training is performed in the Ascend NPU environment, the script must contain the following lines 24 and 25 (already present in the sample code).

     24 import torch_npu
     25 from torch_npu.contrib import transfer_to_npu
    
  2. Create a configuration file.

    Create a config.json file in the directory where the training script is located, with the following content:

    {
        "task": "statistics",
        "dump_path": "/home/dump",
        "rank": [],
        "step": [0,1],
        "level": "L1",
        "async_dump": false,
    
        "statistics": {
            "scope": [], 
            "list": [],
            "tensor_list": [],
            "data_mode": ["all"],
            "summary_mode": "statistics"
        }
    }
    
  3. Add the tool to the training script (pytorch_main.py) in the GPU and Ascend NPU environments.

    NOTE

    The sample code in PyTorch Precision Data Collection Code Sample already contains the tool interfaces. The following excerpt shows where they are added to the script.

     26 # Import the data collection interface of the tool. Execute seed_all and instantiate PrecisionDebugger after the package import statements in the iterative training script.
     27 from msprobe.pytorch import PrecisionDebugger, seed_all
     28 seed_all(seed=1234, mode=True)    # Fix the random seed and enable deterministic computing to ensure that the model execution data is consistent each time.
    ...
    314 def train(train_loader, model, criterion, optimizer, epoch, device, args):
    ...
    331     end = time.time()
    
    332     debugger = PrecisionDebugger(config_path="./config.json")    # Instantiate PrecisionDebugger and load the dump configuration file.
    
            # Dataset iteration typically marks the start of model training.
    333     for i, (images, target) in enumerate(train_loader):
    334         debugger.start()  # Enable data dump.
    ...
    356
    357         # measure elapsed time
    358         batch_time.update(time.time() - end)
    359         end = time.time()
    360
    361         debugger.stop()    # Disable data dump. If you enable data dump again, the collected data will be recorded in the same step.
    362         debugger.step()    # End data dump. If you enable data dump again, the collected data will be recorded in the next step.
    

    NOTE

    Precision data occupies drive space, and the server may become unavailable if the drive fills up. The required space depends on the model parameters, the collection configuration, and the number of iterations collected. Ensure that the directory where precision data is written has sufficient free space.

  4. Run the training script command. The tool collects precision data during model training.

    python pytorch_main.py -a resnet50 -b 32 --gpu 1 --dummy
    

    If the following information is displayed in the log, you can manually stop model training and view the collected data to save time.

    ****************************************************************************
    *                        msprobe ends successfully.                        *
    ****************************************************************************
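    The configuration file from step 2 and the drive-space note from step 3 can be combined into a small pre-flight helper. The following is a minimal sketch and not part of msProbe: the function names write_dump_config and free_space_gib are hypothetical, and the field values are simply copied from the config.json example above.

```python
import json
import shutil
from pathlib import Path


def write_dump_config(config_path, dump_path, steps=(0, 1), level="L1"):
    """Write a config.json with the same fields as the example in step 2."""
    config = {
        "task": "statistics",
        "dump_path": str(dump_path),
        "rank": [],
        "step": list(steps),
        "level": level,
        "async_dump": False,
        "statistics": {
            "scope": [],
            "list": [],
            "tensor_list": [],
            "data_mode": ["all"],
            "summary_mode": "statistics",
        },
    }
    Path(config_path).write_text(json.dumps(config, indent=4), encoding="utf-8")
    return config


def free_space_gib(path):
    """Return the free space, in GiB, of the drive that will hold `path`."""
    path = Path(path).resolve()
    # Walk up to the nearest existing ancestor so that the check also works
    # before the dump directory has been created.
    while not path.exists():
        path = path.parent
    return shutil.disk_usage(path).free / 2**30
```

    For example, call write_dump_config("config.json", "/home/dump") and then compare free_space_gib("/home/dump") against a threshold of your choosing before starting training. How much space is enough depends on the model and the number of collected steps, so no default threshold is suggested here.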
    

Checking the Result

The following directory structure is displayed in the path specified by dump_path. Select data for analysis as required.

/home/dump/
├── step0
│   └── proc3209296              # One directory per process: rank{id} in the multi-rank training scenario, or proc{pid} in the single-rank training scenario when the process has no rank information.
│       ├── construct.json       # Hierarchical relationship information of modules. This file is empty in the current scenario.
│       ├── dump.json            # Input and output statistics and overflow/underflow information of forward and backward APIs.
│       └── stack.json           # API call stack information.
├── step1
...

The collected data needs to be further analyzed for precision comparison. For details, see Precision Comparison.
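
Before moving on to the comparison, it can help to confirm which steps and processes were actually dumped. The sketch below relies only on the directory layout shown above (step{N}/<rank or proc directory>/dump.json). The function name list_dump_files is hypothetical, and the contents of dump.json are deliberately not interpreted, only checked for being valid JSON.

```python
import json
from pathlib import Path


def list_dump_files(dump_path):
    """Map each step directory to the dump.json files found beneath it.

    The result has the shape {"step0": {"proc3209296": Path(...)}, ...}.
    """
    found = {}
    for dump_file in sorted(Path(dump_path).glob("step*/*/dump.json")):
        step_dir = dump_file.parent.parent.name
        worker_dir = dump_file.parent.name
        # Fail early on truncated files, for example if the drive filled up.
        with dump_file.open(encoding="utf-8") as fh:
            json.load(fh)
        found.setdefault(step_dir, {})[worker_dir] = dump_file
    return found
```

Calling list_dump_files("/home/dump") and printing the keys of the result gives a quick overview of the collected steps and the per-process directories under each of them.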

Precision Comparison

Precision Comparison via compare Command

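Prerequisites

  • Complete procedures listed in Environment Setup.

  • Complete procedures listed in Precision Data Collection to obtain the precision data in the GPU and Ascend NPU environments.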

Performing Comparison

  1. Prepare data.

    After dumping data in the GPU and Ascend NPU environments, copy the precision data dumped from the GPU environment to the Ascend NPU environment. Pay attention to the directory names specified by dump_path. dump_data_npu and dump_data_gpu are used as examples.

    The path of dump.json in the dump_data_gpu directory is /home/dump/dump_data_gpu/step0/rank/dump.json.

    The path of dump.json in the dump_data_npu directory is /home/dump/dump_data_npu/step0/rank/dump.json.

  2. Perform the comparison.

    The command is as follows:

    msprobe compare -tp /home/dump/dump_data_npu/step0/rank/dump.json -gp /home/dump/dump_data_gpu/step0/rank/dump.json -o /home/accuracy_compare
    

    If the following information is displayed, the comparison is successful:

    ...
    The result excel file path is: /home/accuracy_compare/compare_result_{timestamp}.xlsx
    ************************************************************************************
    *                        msprobe compare ends successfully.                        *
    ************************************************************************************
    
  3. Analyze the comparison result file.

    compare generates the following file in /home/accuracy_compare:

    compare_result_{timestamp}.xlsx: This file lists the details about all APIs for precision comparison and comparison results. You can locate suspicious operators based on the comparison result (Result) and error message (Err_Message). Because each metric has its own evaluation criteria, interpret the results based on actual circumstances.

    Examples:

    Figure 1 compare_result_1

    Figure 2 compare_result_2

    Figure 3 compare_result_3

    For more information about the comparison result analysis, see Precision Comparison Result Analysis.
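
    When the comparison is run repeatedly, for example once per step, it can be convenient to assemble the command in a script. The following sketch only builds the argument list for the msprobe compare command and checks that the input files exist. The function name build_compare_command is hypothetical, and the sketch assumes that -tp takes the NPU data under test and -gp takes the GPU benchmark data, which is how the graph_visualize command in this guide uses the same options.

```python
from pathlib import Path


def build_compare_command(npu_dump, gpu_dump, output_dir):
    """Return the argument list for `msprobe compare` after checking inputs.

    Assumption: -tp is the NPU (target) dump.json and -gp is the GPU
    (benchmark) dump.json.
    """
    npu_dump, gpu_dump = Path(npu_dump), Path(gpu_dump)
    for dump_file in (npu_dump, gpu_dump):
        if not dump_file.is_file():
            raise FileNotFoundError(f"dump file not found: {dump_file}")
    return [
        "msprobe", "compare",
        "-tp", str(npu_dump),
        "-gp", str(gpu_dump),
        "-o", str(output_dir),
    ]
```

    The returned list can be passed to subprocess.run(argv, check=True). Building a list rather than a single string avoids shell quoting problems when paths contain spaces.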

Graph Comparison in Hierarchical Visualization Mode

Prerequisites

  • Complete procedures listed in Environment Setup.

  • Complete procedures listed in Precision Data Collection to obtain the precision data in the GPU and Ascend NPU environments.

    For hierarchical visualization, the level parameter in the config.json file must be set to L0 or mix during a data dump. In this example, mix is used to re-collect precision data.
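
    Switching an existing configuration to a level that supports hierarchical visualization is a one-field change. The following is a minimal sketch, assuming the config.json layout from Precision Data Collection. The function name set_dump_level is hypothetical, and the accepted values are the two named above (L0 and mix).

```python
import json
from pathlib import Path

# Levels this guide names as valid for hierarchical visualization.
VISUALIZATION_LEVELS = ("L0", "mix")


def set_dump_level(config_path, level="mix"):
    """Rewrite the "level" field of an existing config.json in place."""
    if level not in VISUALIZATION_LEVELS:
        raise ValueError(
            f"level must be one of {VISUALIZATION_LEVELS}, got {level!r}"
        )
    config_path = Path(config_path)
    config = json.loads(config_path.read_text(encoding="utf-8"))
    config["level"] = level
    config_path.write_text(json.dumps(config, indent=4), encoding="utf-8")
    return config
```

    After the level is changed, run the training script again in both environments so that the precision data is re-collected at the new level.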

Performing Comparison

  1. Prepare data.

    After dumping data in the GPU and Ascend NPU environments, copy the precision data dumped from the GPU environment to the Ascend NPU environment. Pay attention to the directory names specified by dump_path. dump_data_npu and dump_data_gpu are used as examples.

    The path of dump_data_gpu is /home/dump/dump_data_gpu/step1.

    The path of dump_data_npu is /home/dump/dump_data_npu/step1.

  2. Perform graph comparison.

    msprobe graph_visualize -tp /home/dump/dump_data_npu/step1 -gp /home/dump/dump_data_gpu/step1 -o /home/dump/output
    

    After the comparison is complete, a .vis.db file is generated in the /home/dump/output directory.

  3. Start TensorBoard.

    tensorboard --logdir ./output --bind_all
    

    The path specified by --logdir is the /home/dump/output directory from step 2.

    After the preceding command is executed, the following log is displayed:

    TensorBoard 2.20.0 at http://ubuntu:6006/ (Press CTRL+C to quit)
    

    Open a browser in the Windows environment and access http://ubuntu:6006/, where ubuntu should be replaced with the IP address of your server, for example, http://192.168.1.10:6006/.

    After the access is successful, the TensorBoard page is displayed, as shown in the following figure.

    Figure 1 Graph comparison in hierarchical visualization mode

Code Sample

PyTorch Precision Data Collection Code Sample

import argparse
import os
import random
import shutil
import time
import warnings
from enum import Enum

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset

import torch_npu
from torch_npu.contrib import transfer_to_npu

from msprobe.pytorch import PrecisionDebugger, seed_all
seed_all(seed=1234, mode=True)  # Fix the random seed and enable deterministic computing to ensure that the model execution data is consistent each time.

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                    help='path to dataset (default: imagenet)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--no-accel', action='store_true',
                    help='disables accelerator')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

best_acc1 = 0


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    use_accel = not args.no_accel and torch.accelerator.is_available()

    if use_accel:
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    if device.type =='cuda':
        ngpus_per_node = torch.accelerator.device_count()
        if ngpus_per_node == 1 and args.dist_backend == "nccl":
            warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
    else:
        ngpus_per_node = 1

    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    use_accel = not args.no_accel and torch.accelerator.is_available()

    if use_accel:
        if args.gpu is not None:
            torch.accelerator.set_device_index(args.gpu)
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not use_accel:
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if device.type == 'cuda':
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(device)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
    elif device.type == 'cuda':
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.to(device)


    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = f'{device.type}:{args.gpu}'
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    # Data loading code
    if args.dummy:
        print("=> Dummy data is used!")
        train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
        val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        scheduler.step()

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best)

def train(train_loader, model, criterion, optimizer, epoch, device, args):

    use_accel = not args.no_accel and torch.accelerator.is_available()

    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    debugger = PrecisionDebugger(config_path="./config.json")
    for i, (images, target) in enumerate(train_loader):
        debugger.start()
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        debugger.stop()
        debugger.step()

        if i % args.print_freq == 0:
            progress.display(i + 1)


def validate(val_loader, model, criterion, args):

    use_accel = not args.no_accel and torch.accelerator.is_available()

    def run_validate(loader, base_progress=0):

        if use_accel:
            device = torch.accelerator.current_accelerator()
        else:
            device = torch.device("cpu")

        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                if use_accel:
                    if args.gpu is not None and device.type=='cuda':
                        torch.accelerator.set_device_index(args.gpu)
                        images = images.cuda(args.gpu, non_blocking=True)
                        target = target.cuda(args.gpu, non_blocking=True)
                    else:
                        images = images.to(device)
                        target = target.to(device)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.use_accel = use_accel
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        if self.use_accel:
            device = torch.accelerator.current_accelerator()
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()