Quick Start of msProbe in the PyTorch Scenario
Overview
This document describes how to quickly get started with the precision debugging tool MindStudio Probe (msProbe) during the training development process in the PyTorch scenario.
For foundation models developed on Ascend or migrated from GPU to Ascend NPU, training issues such as precision overflow/underflow, loss divergence, or non-convergent loss may arise. Because aggregate metrics such as the training loss cannot pinpoint the faulty module, msProbe is recommended for rapid fault demarcation.
Usage Process
When using msProbe for model precision debugging, perform the following operations:
-
Configuration check before training
Identify the configuration differences that affect the precision in two environments.
-
Training status monitoring
Monitor exceptions in computation, communication, and optimizer operations during training.
-
Precision data collection
Collect the forward and backward input and output data at the API or module level during training.
-
Precision pre-check
Scan API data to identify APIs with precision issues.
-
Precision comparison
Compare the API data on NPU with that in the benchmark environment to quickly locate precision issues.
This Quick Start guide focuses on rapid onboarding for precision data collection and precision comparison. For usage details of other tool functions, please refer to the relevant documentation.
Environment Setup
-
Prepare a training server equipped with Ascend NPUs (such as Atlas A2 training servers) and install the NPU driver and firmware.
-
Install the CANN Toolkit and OPS (operator package) of the matching version and configure CANN environment variables. The following uses CANN 8.5.0 as an example. For details, see CANN Software Installation Guide.
-
Install the framework.
The following example uses PyTorch 2.9.0, Python 3.12, torchvision 0.24.0, and an AArch64-based system in the PyTorch training scenario. For details, see "Installing PyTorch > Method 1: Installation via a Binary Package" in Ascend Extension for PyTorch Installation Guide.
-
Install msProbe by referring to msProbe Installation Guide.
pip install mindstudio-probe --pre
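Before collecting data, you can confirm that the installed packages match the versions used in this guide. The following sketch uses only the standard library (`importlib.metadata`). The distribution names `torch-npu` and `mindstudio-probe` and the expected-version table are assumptions based on the usual pip package names and the example versions above, so adjust them for your environment.

```python
"""Report installed versions of the packages used in this quick start (a sketch)."""
import sys
from importlib import metadata

# Expected versions taken from the example environment in this guide; adjust as needed.
EXPECTED = {"torch": "2.9.0", "torchvision": "0.24.0"}
# Distribution names are assumptions based on the usual pip package names.
PACKAGES = ["torch", "torch-npu", "torchvision", "mindstudio-probe"]


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def report(versions, expected=EXPECTED):
    """Print one line per package and flag versions that differ from the expected prefix."""
    print(f"python: {sys.version.split()[0]} (example environment uses 3.12)")
    for name, version in versions.items():
        want = expected.get(name)
        # Local build tags such as "+cpu" are tolerated by comparing only the prefix.
        ok = version is not None and (want is None or version.startswith(want))
        print(f"{name}: {version or 'not installed'} [{'OK' if ok else 'CHECK'}]")


if __name__ == "__main__":
    report(installed_versions(PACKAGES))
```

Run it in both the GPU and the Ascend NPU environments. `torch-npu` is expected to be missing on the GPU side.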
Precision Data Collection
In this example, the ResNet-50 model is trained on dummy data: the --dummy option makes the script generate synthetic images instead of reading a dataset.
Prerequisites
- Complete procedures listed in Environment Setup.
Data Collection
-
Prepare a training script.
pytorch_main.py is used as an example. In both the GPU and Ascend NPU environments, you can directly copy the sample code from PyTorch Precision Data Collection Code Sample. Lines 24 and 25 of the sample, shown below, are required only for training in the Ascend NPU environment. Delete them when training in the GPU environment.
24 import torch_npu
25 from torch_npu.contrib import transfer_to_npu
-
Create a configuration file.
Create a config.json file in the directory where the training script is located, with the following content:
{
    "task": "statistics",
    "dump_path": "/home/dump",
    "rank": [],
    "step": [0, 1],
    "level": "L1",
    "async_dump": false,
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": [],
        "data_mode": ["all"],
        "summary_mode": "statistics"
    }
}
-
Add the tool to the training script (pytorch_main.py) in the GPU and Ascend NPU environments.
NOTE
The sample code in PyTorch Precision Data Collection Code Sample already includes the tool interface. The following excerpt shows where the interface is added in the script.
26 # Import the data collection interface of the tool. Execute seed_all and instantiate PrecisionDebugger after the package import statements in the iterative training script.
27 from msprobe.pytorch import PrecisionDebugger, seed_all
28 seed_all(seed=1234, mode=True)  # Fix the random seed and enable deterministic computing so that model execution data is consistent across runs.
...
314 def train(train_loader, model, criterion, optimizer, epoch, device, args):
...
331     end = time.time()
332     debugger = PrecisionDebugger(config_path="./config.json")  # Instantiate PrecisionDebugger and load the dump configuration file.
        # Dataset iteration typically marks the start of model training.
333     for i, (images, target) in enumerate(train_loader):
334         debugger.start()  # Enable data dump.
...
356
357         # measure elapsed time
358         batch_time.update(time.time() - end)
359         end = time.time()
360
361         debugger.stop()  # Disable data dump. If data dump is enabled again, the collected data is recorded in the same step.
362         debugger.step()  # End the current step. If data dump is enabled again, the collected data is recorded in the next step.
NOTE
Precision data can occupy a large amount of drive space, and the server may become unavailable if the drive fills up. The space required depends on the model parameters, the collection configuration, and the number of collected iterations. Ensure that the directory to which precision data is written has sufficient free space.
-
Run the training script. The tool collects precision data during model training.
python pytorch_main.py -a resnet50 -b 32 --gpu 1 --dummy
If the following information is displayed in the log, you can stop model training manually to save time and view the collected data.
****************************************************************************
*                       msprobe ends successfully.                         *
****************************************************************************
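The config.json from step 2 can also be generated from Python, which is convenient when the GPU and NPU runs need different dump_path values. The following minimal sketch uses only the standard json module. The field values mirror the configuration shown above, and the function name write_dump_config is hypothetical.

```python
"""Generate the msProbe dump configuration used in this guide (a sketch)."""
import json
from pathlib import Path


def write_dump_config(config_path, dump_path="/home/dump", level="L1", steps=(0, 1)):
    """Write a statistics-task config.json with the same fields as the example above."""
    config = {
        "task": "statistics",
        "dump_path": dump_path,
        "rank": [],
        "step": list(steps),
        "level": level,
        "async_dump": False,
        "statistics": {
            "scope": [],
            "list": [],
            "tensor_list": [],
            "data_mode": ["all"],
            "summary_mode": "statistics",
        },
    }
    path = Path(config_path)
    # indent=4 keeps the file readable for manual edits later.
    path.write_text(json.dumps(config, indent=4), encoding="utf-8")
    return path
```

For example, `write_dump_config("config.json", dump_path="/home/dump/dump_data_npu")` on the NPU server and the same call with `dump_data_gpu` on the GPU server produce the directory names used in Precision Comparison.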
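The note in step 3 warns that a full drive can make the server unavailable, so checking free space before a dump is a cheap safeguard. The sketch below uses `shutil.disk_usage` from the standard library. The 50 GiB threshold is a hypothetical placeholder, because the real requirement depends on the model, the configuration, and the number of collected steps.

```python
"""Check free drive space for the dump directory before training (a sketch)."""
import shutil
from pathlib import Path

GIB = 1024 ** 3


def nearest_existing_dir(path):
    """Walk up from path until an existing directory is found (dump_path may not exist yet)."""
    current = Path(path).absolute()
    while not current.exists():
        current = current.parent
    return current


def has_enough_space(dump_path, required_bytes):
    """Return (ok, free_bytes) for the file system that will hold dump_path."""
    free = shutil.disk_usage(nearest_existing_dir(dump_path)).free
    return free >= required_bytes, free


if __name__ == "__main__":
    # 50 GiB is a hypothetical threshold; size it from a short trial dump of your own model.
    ok, free = has_enough_space("/home/dump", 50 * GIB)
    print(f"free: {free / GIB:.1f} GiB -> {'OK' if ok else 'insufficient, free up space first'}")
```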
Checking the Result
The following directory structure is generated in the path specified by dump_path. Select the data to analyze as required.
/home/dump/
├── step0
│   └── proc3209296          # Named proc{pid} in single-rank training, where the training process has no rank information, or rank{id} in multi-rank training.
│       ├── construct.json   # Hierarchical relationship information of modules. This file is empty in the current scenario because level is L1 (API level).
│       ├── dump.json        # Input and output statistics and overflow/underflow information of forward and backward APIs.
│       └── stack.json       # API call stack information.
├── step1
...
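To see which steps and process directories were written, and whether each one contains the three files listed in the tree above, you can walk dump_path with a short script. This sketch relies only on the layout shown in the tree. The step* and per-process directory names are matched by pattern, not obtained from an msProbe API.

```python
"""Summarize an msProbe dump directory laid out as step*/<rank or proc dir>/ (a sketch)."""
from pathlib import Path

EXPECTED_FILES = ("construct.json", "dump.json", "stack.json")


def summarize_dump(dump_path):
    """Return {step dir: {process dir: {file name: size in bytes, or None if missing}}}."""
    summary = {}
    for step_dir in sorted(Path(dump_path).glob("step*")):
        if not step_dir.is_dir():
            continue
        procs = {}
        for proc_dir in sorted(p for p in step_dir.iterdir() if p.is_dir()):
            procs[proc_dir.name] = {
                name: (proc_dir / name).stat().st_size if (proc_dir / name).is_file() else None
                for name in EXPECTED_FILES
            }
        summary[step_dir.name] = procs
    return summary


if __name__ == "__main__":
    for step, procs in summarize_dump("/home/dump").items():
        for proc, files in procs.items():
            print(step, proc, files)
```

A construct.json of size 0 is expected for this L1 configuration. A missing dump.json usually means the step was never reached.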
The collected data is then used for precision comparison. For details, see Precision Comparison.
Precision Comparison
Precision Comparison via compare Command
Prerequisites
- Complete procedures listed in Environment Setup.
- Complete procedures listed in Precision Data Collection to obtain the precision data in the GPU and Ascend NPU environments.
Performing Comparison
-
Prepare data.
After dumping data in the GPU and Ascend NPU environments, copy the precision data dumped in the GPU environment to the Ascend NPU environment. Pay attention to the directory names specified by dump_path. In this example, they are dump_data_npu and dump_data_gpu.
The path of dump.json in the dump_data_gpu directory is /home/dump/dump_data_gpu/step0/rank/dump.json.
The path of dump.json in the dump_data_npu directory is /home/dump/dump_data_npu/step0/rank/dump.json.
-
Perform the comparison.
The command is as follows:
msprobe compare -tp /home/dump/dump_data_npu/step0/rank/dump.json -gp /home/dump/dump_data_gpu/step0/rank/dump.json -o /home/accuracy_compare
In this command, -tp specifies the dump.json file of the NPU data to be compared, -gp specifies the dump.json file of the GPU benchmark data, and -o specifies the output directory.
If the following information is displayed, the comparison is successful:
...
The result excel file path is: /home/accuracy_compare/compare_result_{timestamp}.xlsx
************************************************************************************
*                        msprobe compare ends successfully.                        *
************************************************************************************
-
Analyze the comparison result file.
compare generates the following file in /home/accuracy_compare:
compare_result_{timestamp}.xlsx: lists all APIs involved in the precision comparison and their comparison results. You can locate suspicious operators based on the comparison result (Result) and the error message (Err_Message). Each metric has its own evaluation criteria, so interpret the results based on the actual situation.
Examples:
Figure 1 compare_result_1

Figure 2 compare_result_2

Figure 3 compare_result_3

For more information about the comparison result analysis, see Precision Comparison Result Analysis.
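If the per-process directory name differs between environments (for example, proc{pid} directories carry different process IDs), it is convenient to locate each dump.json and assemble the compare command in a script. The sketch below follows the convention of the graph_visualize command in this guide, where -tp points to the NPU data and -gp points to the GPU benchmark data. The helper names are hypothetical. The command is only built, not executed, unless you pass it to `subprocess.run` yourself.

```python
"""Locate dump.json files and build an msprobe compare command (a sketch)."""
import shlex
from pathlib import Path


def find_dump_json(step_dir):
    """Return the single dump.json under step_dir/<rank or proc dir>/, or raise if ambiguous."""
    matches = sorted(Path(step_dir).glob("*/dump.json"))
    if len(matches) != 1:
        raise FileNotFoundError(
            f"expected exactly one dump.json under {step_dir}, found {len(matches)}")
    return matches[0]


def build_compare_cmd(npu_step_dir, gpu_step_dir, output_dir):
    """Build the argument list: -tp NPU (target) data, -gp GPU (benchmark) data, -o output dir."""
    return [
        "msprobe", "compare",
        "-tp", str(find_dump_json(npu_step_dir)),
        "-gp", str(find_dump_json(gpu_step_dir)),
        "-o", str(output_dir),
    ]


if __name__ == "__main__":
    try:
        cmd = build_compare_cmd("/home/dump/dump_data_npu/step0",
                                "/home/dump/dump_data_gpu/step0",
                                "/home/accuracy_compare")
        print(shlex.join(cmd))  # run with subprocess.run(cmd, check=True) once the paths look right
    except FileNotFoundError as err:
        print(err)
```

The helper refuses to guess when a step directory holds several rank directories. In multi-rank training, compare each rank pair explicitly instead.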
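To shortlist suspicious APIs without scrolling through the whole spreadsheet, you can filter rows on the Result and Err_Message columns described above. The sketch assumes two things, so verify both against your msProbe version: that the sheet has been exported to CSV (or read with `pandas.read_excel` and converted to a list of dicts), and that the column headers are exactly `Result` and `Err_Message`. It does not assume which Result values mean "pass". Instead, it counts each value and returns the rows that carry an error message.

```python
"""Shortlist compare-result rows that carry an error message (a sketch)."""
import csv
from collections import Counter


def triage(rows, result_col="Result", message_col="Err_Message"):
    """Return (counts of each Result value, rows whose error message is non-empty)."""
    rows = list(rows)
    counts = Counter((row.get(result_col) or "").strip() for row in rows)
    flagged = [row for row in rows if (row.get(message_col) or "").strip()]
    return counts, flagged


def triage_csv(csv_path, **kwargs):
    """Read a CSV export of compare_result_{timestamp}.xlsx and triage it."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        return triage(csv.DictReader(f), **kwargs)
```

For example, `counts, flagged = triage_csv("compare_result.csv")` (hypothetical file name) gives an overview of the Result values and a list of rows to inspect first.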
Graph Comparison in Hierarchical Visualization Mode
Prerequisites
-
Complete procedures listed in Environment Setup.
-
Complete procedures listed in Precision Data Collection to obtain the precision data in the GPU and Ascend NPU environments.
For hierarchical visualization, the level parameter in the config.json file must be set to L0 or mix during the data dump. In this example, mix is used to re-collect the precision data.
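Only L0 and mix support hierarchical visualization, so checking the level before re-collecting avoids a wasted run. The following sketch edits an existing config.json in place with the standard json module. The helper name set_dump_level is hypothetical, and the allowed-level set reflects only the requirement stated above.

```python
"""Switch an existing msProbe config.json to a level usable for graph comparison (a sketch)."""
import json
from pathlib import Path

# Levels that this guide states are required for hierarchical visualization.
VISUALIZATION_LEVELS = {"L0", "mix"}


def set_dump_level(config_path, level="mix", dump_path=None):
    """Set "level" (and optionally "dump_path") in config_path and return the updated dict."""
    if level not in VISUALIZATION_LEVELS:
        raise ValueError(f"level must be one of {sorted(VISUALIZATION_LEVELS)} for graph comparison")
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    config["level"] = level
    if dump_path is not None:
        # Optional: point the re-collection at a fresh directory so it is not mixed with earlier data.
        config["dump_path"] = dump_path
    path.write_text(json.dumps(config, indent=4), encoding="utf-8")
    return config
```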
Performing Comparison
-
Prepare data.
After dumping data in the GPU and Ascend NPU environments, copy the precision data dumped in the GPU environment to the Ascend NPU environment. Pay attention to the directory names specified by dump_path. In this example, they are dump_data_npu and dump_data_gpu.
The path of dump_data_gpu is /home/dump/dump_data_gpu/step1.
The path of dump_data_npu is /home/dump/dump_data_npu/step1.
-
Perform graph comparison.
msprobe graph_visualize -tp /home/dump/dump_data_npu/step1 -gp /home/dump/dump_data_gpu/step1 -o /home/dump/output
After the comparison is complete, a .vis.db file is generated in the /home/dump/output directory.
-
Start TensorBoard.
tensorboard --logdir ./output --bind_all
The path specified by --logdir is the /home/dump/output directory generated in step 2, so run this command in /home/dump.
After the preceding command is executed, the following log is displayed:
TensorBoard 2.20.0 at http://ubuntu:6006/ (Press CTRL+C to quit)
Open a browser on your Windows PC and access http://ubuntu:6006/, replacing ubuntu with the IP address of your server, for example, http://192.168.1.10:6006/.
After the access is successful, the TensorBoard page is displayed, as shown in the following figure.
Figure 1 Graph comparison in hierarchical visualization mode

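The graph comparison can also be scripted. The sketch below builds the msprobe graph_visualize command shown in step 2 as an argument list. After you run it, the second helper lists the .vis.db files in the output directory. The helper names are hypothetical, and nothing is executed unless you call `subprocess.run` on the returned list.

```python
"""Build the graph_visualize command and list its .vis.db outputs (a sketch)."""
import shlex
from pathlib import Path


def build_graph_cmd(npu_step_dir, gpu_step_dir, output_dir):
    """Mirror the documented command: -tp NPU step dir, -gp GPU step dir, -o output dir."""
    return ["msprobe", "graph_visualize",
            "-tp", str(npu_step_dir), "-gp", str(gpu_step_dir), "-o", str(output_dir)]


def list_vis_db(output_dir):
    """Return the .vis.db files in output_dir, newest first (empty if the directory is missing)."""
    out = Path(output_dir)
    if not out.is_dir():
        return []
    files = [p for p in out.glob("*.vis.db") if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    cmd = build_graph_cmd("/home/dump/dump_data_npu/step1",
                          "/home/dump/dump_data_gpu/step1",
                          "/home/dump/output")
    print(shlex.join(cmd))  # e.g. subprocess.run(cmd, check=True)
    print(list_vis_db("/home/dump/output"))
```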
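If the browser cannot open the TensorBoard page, testing whether the port accepts TCP connections from your local machine helps separate a network or firewall problem from a TensorBoard problem. This is a generic standard-library sketch, not part of msProbe. Replace the example address 192.168.1.10 with your server's IP.

```python
"""Check whether TensorBoard's port is reachable and build the URL to open (a sketch)."""
import socket


def tensorboard_url(host, port=6006):
    """Build the browser URL, e.g. http://192.168.1.10:6006/."""
    return f"http://{host}:{port}/"


def port_is_open(host, port=6006, timeout=3.0):
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

For example, `port_is_open("192.168.1.10")` returning False from your PC while TensorBoard is running suggests checking that `--bind_all` was passed and that port 6006 is open in the server's firewall.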
Code Sample
PyTorch Precision Data Collection Code Sample
import argparse
import os
import random
import shutil
import time
import warnings
from enum import Enum

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset

import torch_npu  # Required only in the Ascend NPU environment; delete this line in the GPU environment.
from torch_npu.contrib import transfer_to_npu  # Maps CUDA API calls to the NPU; delete this line in the GPU environment.
# Import the data collection interface of the tool. Execute seed_all and instantiate PrecisionDebugger after the package import statements in the iterative training script.
from msprobe.pytorch import PrecisionDebugger, seed_all
seed_all(seed=1234, mode=True)  # Fix the random seed and enable deterministic computing so that model execution data is consistent across runs.

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                    help='path to dataset (default: imagenet)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--no-accel', action='store_true',
                    help='disables accelerator')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

best_acc1 = 0


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    use_accel = not args.no_accel and torch.accelerator.is_available()

    if use_accel:
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    if device.type == 'cuda':
        ngpus_per_node = torch.accelerator.device_count()
        if ngpus_per_node == 1 and args.dist_backend == "nccl":
            warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
    else:
        ngpus_per_node = 1

    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    use_accel = not args.no_accel and torch.accelerator.is_available()

    if use_accel:
        if args.gpu is not None:
            torch.accelerator.set_device_index(args.gpu)
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not use_accel:
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if device.type == 'cuda':
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(device)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None and device.type == 'cuda':
        # A specific GPU was chosen (for example --gpu 1), so run on that single device
        # without DataParallel. DataParallel expects the module on device_ids[0] (GPU 0)
        # and would fail if the model were placed on another GPU.
        model = model.cuda(args.gpu)
    elif device.type == 'cuda':
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.to(device)

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Sets the learning rate to the initial LR decayed by 10 every 30 epochs
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = f'{device.type}:{args.gpu}'
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    if args.dummy:
        print("=> Dummy data is used!")
        train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
        val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        scheduler.step()

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, is_best)


def train(train_loader, model, criterion, optimizer, epoch, device, args):
    use_accel = not args.no_accel and torch.accelerator.is_available()

    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    # Instantiate PrecisionDebugger and load the dump configuration file.
    debugger = PrecisionDebugger(config_path="./config.json")
    # Dataset iteration typically marks the start of model training.
    for i, (images, target) in enumerate(train_loader):
        debugger.start()  # Enable data dump.
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        debugger.stop()  # Disable data dump for this iteration.
        debugger.step()  # End the current step; data collected afterwards is recorded in the next step.

        if i % args.print_freq == 0:
            progress.display(i + 1)


def validate(val_loader, model, criterion, args):
    use_accel = not args.no_accel and torch.accelerator.is_available()

    def run_validate(loader, base_progress=0):
        if use_accel:
            device = torch.accelerator.current_accelerator()
        else:
            device = torch.device("cpu")

        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                if use_accel:
                    if args.gpu is not None and device.type == 'cuda':
                        torch.accelerator.set_device_index(args.gpu)
                        images = images.cuda(args.gpu, non_blocking=True)
                        target = target.cuda(args.gpu, non_blocking=True)
                    else:
                        images = images.to(device)
                        target = target.to(device)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.use_accel = use_accel
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        # Use the flag stored on the instance; use_accel is not defined in this scope.
        if self.use_accel:
            device = torch.accelerator.current_accelerator()
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()