import os
import time
import socket
import logging
import torch
from . import metric
from . import callback
import pdb
from apex import amp
"""
Static Model
"""
class static_model(object):
    """Wrap a network with checkpoint save/load and a single-loss forward pass.

    Args:
        net: the network used for forward passes (possibly DDP-wrapped).
        criterion: optional loss callable ``criterion(output, target)``.
        model_prefix: path prefix used to build checkpoint file names.
        single_checkpoint: when True, only rank 0 writes checkpoints in a
            distributed run (a shared filesystem is then assumed).
        **kwargs: unknown options; they are logged and ignored.
    """

    def __init__(self,
                 net,
                 criterion=None,
                 model_prefix='',
                 single_checkpoint=False,
                 **kwargs):
        if kwargs:
            logging.warning("Unknown kwargs: {}".format(kwargs))
        self.net = net
        self.model_prefix = model_prefix
        self.criterion = criterion
        self.single_checkpoint = single_checkpoint
        if self.single_checkpoint and self._dist_initialized():
            logging.warning(">> only keeping the checkpoint for rank=0 node, making sure you are using shared filesystem")
        self.is_distributed = False
        # Bare (non-DDP) network whose state_dict is saved/loaded; callers
        # replace it via set_net_without_ddp() once DDP wraps `net`.
        self.net_without_ddp = net
        logging.debug("single_checkpoint=%s", self.single_checkpoint)

    @staticmethod
    def _dist_initialized():
        """Return True only if torch.distributed is built and a group is initialized.

        `torch.distributed.is_initialized` is a function and must be *called*;
        the bare function object is always truthy.
        """
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def set_net_without_ddp(self, net_without_ddp):
        """Set the unwrapped network used for state_dict save/load."""
        self.net_without_ddp = net_without_ddp

    def set_distributed_mode(self):
        """Mark this wrapper as running in distributed mode (enables rank-aware logging)."""
        self.is_distributed = True

    def load_state(self, state_dict, strict=False):
        """Load parameters into ``net_without_ddp``.

        Args:
            state_dict: mapping of parameter/buffer name to tensor.
            strict: if True, delegate to ``load_state_dict`` (raises on any
                mismatch). Otherwise copy every entry whose name and shape match,
                and ignore the rest.

        Returns:
            True if every entry of the network's own state was filled, else
            False (the unfilled keys are logged).
        """
        logging.info(f"==================load param (rank:{torch.distributed.get_rank() if self.is_distributed else 0})==================")
        if strict:
            logging.info("==================strictly load_state...==================")
            self.net_without_ddp.load_state_dict(state_dict=state_dict)
            return True
        logging.info("==================foreach name & param copy (customized partialy load function)...==================")
        # Tensors returned by state_dict() share storage with the live
        # parameters/buffers, so copy_ updates the network in place. Build the
        # dict once instead of once per key.
        own_state = self.net_without_ddp.state_dict()
        loaded = set()
        for name, param in state_dict.items():
            dst = own_state.get(name)
            if dst is not None and param.shape == dst.shape:
                dst.copy_(param.view(dst.shape))
                loaded.add(name)
        missing = [key for key in own_state if key not in loaded]
        if missing:
            logging.warning(">> Failed to load: {}".format(missing))
            return False
        return True

    def get_checkpoint_path(self, epoch, suffix=''):
        """Build the checkpoint path for ``epoch``.

        The path is per-rank (``_rank-N``) when distributed is initialized and
        ``single_checkpoint`` is False; otherwise it is a single shared path.
        """
        assert self.model_prefix, "model_prefix undefined!"
        if self._dist_initialized() and not self.single_checkpoint:
            pth_mark = str(torch.distributed.get_rank())
            return "{}_rank-{}_ep-{:04d}{}.pth".format(self.model_prefix, pth_mark, epoch, suffix)
        return "{}_ep-{:04d}{}.pth".format(self.model_prefix, epoch, suffix)

    def load_checkpoint(self, epoch, optimizer=None, suffix=''):
        """Restore model (and optionally optimizer) state from the ``epoch`` checkpoint.

        The optimizer state is restored only if the checkpoint contains it and
        all model parameters matched.
        """
        load_path = self.get_checkpoint_path(epoch, suffix)
        assert os.path.exists(load_path), "Failed to load: {} (file not exist)".format(load_path)
        # Load onto CPU first: with a shared (rank-0) checkpoint, loading without
        # map_location would place every rank's copy on rank 0's device. copy_()
        # and optimizer.load_state_dict() move the values to the right device.
        checkpoint = torch.load(load_path, map_location="cpu")
        all_params_matched = self.load_state(checkpoint['state_dict'], strict=False)
        if optimizer:
            if 'optimizer' in checkpoint and all_params_matched:
                optimizer.load_state_dict(checkpoint['optimizer'])
                logging.info("Model & Optimizer states are resumed from: `{}'".format(load_path))
            else:
                logging.warning(">> Failed to load optimizer state from: `{}'".format(load_path))
        else:
            logging.info("Only model state resumed from: `{}'".format(load_path))
        if 'epoch' in checkpoint and checkpoint['epoch'] != epoch:
            logging.warning(">> Epoch information inconsistant: {} vs {}".format(checkpoint['epoch'], epoch))

    def save_checkpoint(self, epoch, optimizer_state=None, suffix=''):
        """Save the model state (plus ``optimizer_state`` if given) for ``epoch``.

        In single-checkpoint distributed mode, non-zero ranks return without
        writing.
        """
        if self.is_distributed and self.single_checkpoint and self._dist_initialized() \
                and torch.distributed.get_rank() != 0:
            logging.info("Checkpoint saved by node 0 (rank=0). Now is rank: {}".format(torch.distributed.get_rank()))
            return
        save_path = self.get_checkpoint_path(epoch, suffix)
        save_folder = os.path.dirname(save_path)
        logging.info(f"===================save folder (rank:{torch.distributed.get_rank() if self.is_distributed else 0})===================")
        logging.info(f"save_folder:{save_folder}, save_path:{save_path}")
        # dirname() is '' for a bare prefix (os.makedirs('') raises). exist_ok
        # tolerates several ranks creating the folder concurrently.
        if save_folder and not os.path.exists(save_folder):
            logging.debug("mkdir {}".format(save_folder))
            os.makedirs(save_folder, exist_ok=True)
        if not optimizer_state:
            torch.save({'epoch': epoch,
                        'state_dict': self.net_without_ddp.state_dict()},
                       save_path)
            logging.info("Checkpoint (only model) saved to: {}".format(save_path))
        else:
            torch.save({'epoch': epoch,
                        'state_dict': self.net_without_ddp.state_dict(),
                        'optimizer': optimizer_state},
                       save_path)
            logging.info("Checkpoint (model & optimizer) saved to: {}".format(save_path))

    def forward(self, data, target=None):
        """Run the network on ``data`` and, if possible, compute the loss.

        Returns:
            ``([output], [loss])``. ``loss`` is None when there is no criterion
            or no ``target``.
        """
        input_var = data.float().npu()
        output = self.net(input_var)
        if getattr(self, 'criterion', None) is not None and target is not None:
            # Only move the target to the device when there is one.
            loss = self.criterion(output, target.npu())
        else:
            loss = None
        return [output], [loss]
"""
Typical model that is able to update itself
"""
class model(static_model):
    """Trainable model: adds the optimization loop, callbacks and evaluation.

    Args:
        net, criterion, model_prefix: forwarded to ``static_model``.
        step_callback: called every ``step_callback_freq`` batches with
            ``self.callback_kwargs``. Defaults to SpeedMonitor + MetricPrinter.
        step_callback_freq: batch interval between step callbacks.
        epoch_callback: called from ``epoch_end_callback``. Defaults to a no-op.
        save_checkpoint_freq: epoch interval between checkpoints.
        opt_batch_size: batch size, used only for the FPS log line.
        distributed: when True, callbacks run only on rank 0.
        args: run configuration. ``fit`` reads ``gpus``, ``master_node``,
            ``prof`` and ``apex`` from it.
        **kwargs: forwarded to ``static_model`` (e.g. ``single_checkpoint``).
    """

    def __init__(self,
                 net,
                 criterion,
                 model_prefix='',
                 step_callback=None,
                 step_callback_freq=20,
                 epoch_callback=None,
                 save_checkpoint_freq=1,
                 opt_batch_size=None,
                 distributed=False,
                 args=None,
                 **kwargs):
        super(model, self).__init__(net, criterion=criterion,
                                    model_prefix=model_prefix,
                                    **kwargs)
        # Shared state handed to step/epoch callbacks as keyword arguments.
        self.callback_kwargs = {'lr': None,
                                'epoch': None,
                                'batch': None,
                                'batch_elapse': None,
                                'update_elapse': None,
                                'epoch_elapse': None,
                                'namevals': None,
                                'optimizer_dict': None,}
        if not step_callback:
            step_callback = callback.CallbackList(callback.SpeedMonitor(),
                                                  callback.MetricPrinter())
        if not epoch_callback:
            epoch_callback = (lambda **kwargs: None)
        self.step_callback = step_callback
        self.step_callback_freq = step_callback_freq
        self.epoch_callback = epoch_callback
        self.save_checkpoint_freq = save_checkpoint_freq
        self.batch_size = opt_batch_size
        self.distributed = distributed
        self.args = args

    # Override the callbacks below to customize per-step / per-epoch behaviour.

    def _is_non_master_rank(self):
        """Return True when running distributed and this process is not rank 0.

        `torch.distributed.is_initialized` must be *called*; the bare function
        object is always truthy.
        """
        return bool(self.distributed
                    and torch.distributed.is_available()
                    and torch.distributed.is_initialized()
                    and torch.distributed.get_rank() != 0)

    def step_end_callback(self):
        """Invoke the step callback (rank 0 only in distributed mode)."""
        if self._is_non_master_rank():
            return
        self.step_callback(**(self.callback_kwargs))

    def epoch_end_callback(self):
        """Invoke the epoch callback, log epoch time and save a checkpoint when due.

        A checkpoint for ``epoch + 1`` is written after epoch 0 and after every
        ``save_checkpoint_freq``-th epoch. Rank 0 only in distributed mode.
        """
        if self._is_non_master_rank():
            return
        self.epoch_callback(**(self.callback_kwargs))
        if self.callback_kwargs['epoch_elapse'] is not None:
            logging.info("Epoch [{:d}] time cost: {:.2f} sec ({:.2f} h)".format(
                self.callback_kwargs['epoch'],
                self.callback_kwargs['epoch_elapse'],
                self.callback_kwargs['epoch_elapse']/3600.))
        if self.callback_kwargs['epoch'] == 0 \
                or ((self.callback_kwargs['epoch']+1) % self.save_checkpoint_freq) == 0:
            self.save_checkpoint(epoch=self.callback_kwargs['epoch']+1,
                                 optimizer_state=self.callback_kwargs['optimizer_dict'])

    def adjust_learning_rate(self, lr, optimizer):
        """Set each param group's lr to ``lr * param_group.get('lr_mult', 1.0)``; return ``lr``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr * param_group.get('lr_mult', 1.0)
        return lr

    def _train_step(self, data, target, optimizer, lr_scheduler, i_batch,
                    epoch_train_steps, epoch_update_end):
        """Run forward + (phase-dependent) backward/update for one batch; return ``(outputs, losses)``.

        Phases (non-apex path):
          * ``i_batch < epoch_train_steps``: regular update, lr from ``lr_scheduler.update()``.
          * ``i_batch < epoch_update_end``: precise-BN update, lr from ``get_lr()``,
            ``optimizer.step(visiable=["precise.bn"])``.
          * otherwise: forward only.
        The apex path always performs a regular update.
        """
        outputs, losses = self.forward(data, target)
        if self.args.apex:
            optimizer.zero_grad()
            for loss in losses:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.update())
            optimizer.step()
        elif i_batch < epoch_train_steps:
            optimizer.zero_grad()
            for loss in losses:
                loss.backward()
            self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.update())
            optimizer.step()
        elif i_batch < epoch_update_end:
            optimizer.zero_grad()
            for loss in losses:
                loss.backward()
            self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.get_lr())
            # `visiable` is a project-specific optimizer kwarg (presumably
            # restricts the update to precise-BN params — TODO confirm).
            optimizer.step(visiable=["precise.bn"])
        return outputs, losses

    def _log_fps(self, batch_time, device_num):
        """Log the average timed batch duration and aggregate throughput."""
        if batch_time.avg > 0 and self.batch_size:
            fps = f"batch_size: {self.batch_size}, Time: {round(batch_time.avg, 3)}, FPS@all: {round( device_num * self.batch_size/batch_time.avg, 3)}"
        else:
            # AverageMeter ignores its warm-up samples, so short epochs leave
            # avg == 0; avoid dividing by zero (or by a missing batch size).
            fps = f"batch_size: {self.batch_size}, Time: {round(batch_time.avg, 3)}, FPS@all: N/A"
        logging.info(fps)

    def _evaluate(self, eval_iter, metrics, epoch_eval_steps):
        """Run forward-only evaluation and report it through ``step_end_callback``."""
        metrics.reset()
        self.net.eval()
        sum_batch_elapse = 0.
        sum_batch_inst = 0
        sum_forward_elapse = 0.
        with torch.no_grad():
            batch_start_time = time.time()
            for i_batch, (data, target) in enumerate(eval_iter):
                forward_start_time = time.time()
                outputs, losses = self.forward(data, target)
                metrics.update([output.data.cpu() for output in outputs],
                               target.cpu(),
                               [loss.data.cpu() for loss in losses])
                sum_forward_elapse += time.time() - forward_start_time
                sum_batch_elapse += time.time() - batch_start_time
                batch_start_time = time.time()
                sum_batch_inst += 1
                if i_batch >= epoch_eval_steps:
                    break
        if sum_batch_inst == 0:
            # Empty loader: nothing to report (and the averages would divide by zero).
            logging.warning(">> Evaluation skipped: eval_iter yielded no batches")
            return
        self.callback_kwargs['batch'] = sum_batch_inst
        self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
        self.callback_kwargs['update_elapse'] = sum_forward_elapse / sum_batch_inst
        self.callback_kwargs['namevals'] = metrics.get_name_value(prefix='ts-')
        self.step_end_callback()

    def fit(self, train_iter, optimizer, lr_scheduler,
            eval_iter=None,
            metrics=metric.Accuracy(topk=1),
            epoch_start=0,
            epoch_end=10000,
            precise_bn=False,
            precise_bn_steps=500,
            epoch_div_factor=1,
            train_sampler=None,
            **kwargs):
        """Train for epochs ``[epoch_start, epoch_end)``, evaluating periodically.

        Args:
            train_iter: training loader; must expose ``batch_sampler``.
            optimizer: optimizer; must accept ``step(visiable=...)`` when
                ``precise_bn`` is used.
            lr_scheduler: provides ``update()`` (returns the lr for this step)
                and ``get_lr()``.
            eval_iter: optional evaluation loader (must expose ``batch_sampler``).
            metrics: object with ``reset``/``update``/``get_name_value``.
            epoch_start, epoch_end: epoch range.
            precise_bn: if True, the last ``precise_bn_steps`` batches of each
                epoch form the precise-BatchNorm phase.
            precise_bn_steps: length of that phase (ignored when precise_bn is False).
            epoch_div_factor: only ``1/epoch_div_factor`` of each loader is used per epoch.
            train_sampler: optional sampler whose ``set_epoch`` is called every epoch.
            **kwargs: unknown options; logged and ignored.
        """
        if kwargs:
            logging.warning("Unknown kwargs: {}".format(kwargs))
        assert torch.npu.is_available(), "only support GPU version"
        device_num = len(self.args.gpus.split(","))

        # ---- iteration plan ----
        precise_bn_steps = precise_bn_steps if precise_bn else 0
        # The last 20% of the precise-BN batches are forward-only (no update).
        epoch_freeze_step = int(round(0.2*precise_bn_steps))
        num_train_batches = len(train_iter.batch_sampler)
        epoch_train_steps = int(num_train_batches / epoch_div_factor)
        epoch_eval_steps = 0 if eval_iter is None else int(len(eval_iter.batch_sampler) / epoch_div_factor)
        if (num_train_batches - epoch_train_steps) > precise_bn_steps:
            epoch_term_steps = epoch_train_steps + precise_bn_steps
        else:
            epoch_term_steps = epoch_train_steps
            epoch_train_steps = epoch_train_steps - precise_bn_steps
            if precise_bn_steps > 0:
                logging.warning(">> using the last {} iter for computing the precise batchnorm.".format(precise_bn_steps))
        epoch_update_end = epoch_term_steps - epoch_freeze_step

        # ---- main loop ----
        for i_epoch in range(epoch_start, epoch_end):
            if train_sampler is not None:
                train_sampler.set_epoch(i_epoch)
            self.callback_kwargs['epoch'] = i_epoch
            epoch_start_time = time.time()

            metrics.reset()
            self.net.train()
            sum_batch_inst = 0
            sum_batch_elapse = 0.
            sum_update_elapse = 0
            batch_start_time = time.time()
            if self.args.master_node:
                logging.info("Start epoch {:d}, iter stride {:d}, train steps {:d}, eval steps: {:d} (with {} GPUs/NPUs node: {})".format(
                    i_epoch, epoch_div_factor, epoch_train_steps, epoch_eval_steps, device_num, self.args.gpus))
            batch_time = AverageMeter('Time', ':6.3f')
            end = None
            for i_batch, (data, target) in enumerate(train_iter):
                if i_batch == 5:
                    # Batch timing starts at batch 5 (the first batches are not timed).
                    end = time.time()
                    if self.args.master_node:
                        logging.info("=============Start compute FPS=============")
                if i_batch >= epoch_term_steps:
                    break
                if precise_bn and i_batch == epoch_train_steps:
                    if self.args.master_node:
                        logging.info("Compute precise batchnorm: {} to {}.".format(epoch_train_steps, epoch_term_steps))
                    # Save, wait for the file, then reload it before the precise-BN
                    # phase (presumably to sync state across ranks — TODO confirm).
                    self.save_checkpoint(epoch=i_epoch+1, optimizer_state=optimizer.state_dict())
                    while not os.path.exists(self.get_checkpoint_path(epoch=i_epoch+1)):
                        print("sleep 1 sec...")
                        time.sleep(1)
                    time.sleep(5)
                    self.load_checkpoint(epoch=i_epoch+1)
                    metrics.reset()
                self.callback_kwargs['batch'] = i_batch
                update_start_time = time.time()
                step_args = (data, target, optimizer, lr_scheduler, i_batch,
                             epoch_train_steps, epoch_update_end)
                if self.args.prof and i_epoch == 0:
                    with torch.autograd.profiler.profile(use_npu=True) as prof:
                        outputs, losses = self._train_step(*step_args)
                    logging.info(">>>prof: "+prof.key_averages().table(sort_by="self_cpu_time_total"))
                    prof.export_chrome_trace("output.prof")
                else:
                    outputs, losses = self._train_step(*step_args)
                if i_batch >= 5:
                    batch_time.update(time.time() - end)
                    end = time.time()
                self.callback_kwargs['lr'] = lr_scheduler.get_lr()
                metrics.update([output.data.cpu() for output in outputs],
                               target.cpu(),
                               [loss.data.cpu() for loss in losses])
                sum_batch_elapse += time.time() - batch_start_time
                sum_update_elapse += time.time() - update_start_time
                sum_batch_inst += 1
                if (i_batch % self.step_callback_freq) == 0:
                    name_value_prefix = 'tr-' if i_batch < epoch_train_steps else 'bn-'
                    self.callback_kwargs['namevals'] = metrics.get_name_value(prefix=name_value_prefix)
                    metrics.reset()
                    self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
                    self.callback_kwargs['update_elapse'] = sum_update_elapse / sum_batch_inst
                    sum_update_elapse = 0
                    sum_batch_elapse = 0
                    sum_batch_inst = 0
                    self.step_end_callback()
                if (i_batch % 500) == 0 and i_batch < epoch_train_steps:
                    # Mid-epoch checkpoint/epoch-callback every 500 training batches.
                    self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
                    self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
                    self.epoch_end_callback()
                batch_start_time = time.time()
            if self.args.master_node:
                self._log_fps(batch_time, device_num)
            self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
            self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
            self.epoch_end_callback()

            # ---- evaluation (every max(1, save_checkpoint_freq // 2) epochs) ----
            if (eval_iter is not None) \
                    and ((i_epoch+1) % max(1, int(self.save_checkpoint_freq/2))) == 0:
                if self.args.master_node:
                    logging.info("========================Start evaluating epoch {:d}:========================".format(i_epoch))
                self._evaluate(eval_iter, metrics, epoch_eval_steps)
        if self.args.master_node:
            logging.info("Optimization done!")
class AverageMeter(object):
    """Compute and store the current value and a warm-up-excluding average.

    The first ``start_count_index * batchsize`` samples are excluded from
    ``sum`` and ``avg``. Here ``batchsize`` is the ``n`` passed to the first
    ``update`` after a reset. ``avg`` stays 0 until the warm-up has been
    passed.

    Args:
        name: label used by ``__str__``.
        fmt: format spec (e.g. ``':6.3f'``) applied to ``val``/``avg`` in ``__str__``.
        start_count_index: number of warm-up updates (in units of the first
            batch size) to exclude. Defaults to 10.
    """

    def __init__(self, name, fmt=':f', start_count_index=10):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.start_count_index = start_count_index

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        if self.count == 0:
            # The first update after a reset fixes the unit of the warm-up window.
            self.batchsize = n
        self.val = val
        self.count += n
        warmup = self.start_count_index * self.batchsize
        if self.count > warmup:
            self.sum += val * n
            self.avg = self.sum / (self.count - warmup)

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
if __name__ == "__main__":
    # ---- Test Static Model (disabled) ----
    '''
    import torch
    import torchvision
    logging.getLogger().setLevel(logging.DEBUG)
    resnet18 = torchvision.models.resnet18()
    net = static_model(net=resnet18, model_prefix="./exps/models/debug")
    net.save_checkpoint(epoch=10)
    net.load_checkpoint(epoch=10)
    '''

    # ---- Test Typical Model ----
    # Smoke test wiring `model.fit` with the arguments it actually requires.
    import argparse
    import sys
    import torch
    sys.path.append(os.path.join("../../"))
    import metric
    from lr_scheduler import MultiFactorScheduler
    import data.dataiter_factory as dataiter_factory
    logging.getLogger().setLevel(logging.DEBUG)
    resume = False
    pretained = False
    load_epoch = 0  # epoch to resume from when `resume` is True
    batch_size = 1
    import network
    logging.info(network.__file__)
    from network import symbol_builder
    sym_c3d, net_cfg = symbol_builder.get_symbol(name="c3d", num_classes=101)
    # forward() moves inputs to the NPU, so the network must live there too.
    # Move it before building the optimizer.
    sym_c3d = sym_c3d.npu()
    optimizer = torch.optim.SGD(sym_c3d.parameters(), lr=0.1,
                                momentum=0.9,
                                weight_decay=0.005)
    # fit() reads gpus/master_node/prof/apex from `args` and uses opt_batch_size
    # for its FPS log. save_checkpoint() needs a model_prefix.
    args = argparse.Namespace(gpus="0", master_node=True, prof=False, apex=False)
    net = model(net=sym_c3d,
                criterion=torch.nn.CrossEntropyLoss().npu(),
                model_prefix="./exps/models/debug",
                opt_batch_size=batch_size,
                args=args)
    if resume:
        net.load_checkpoint(epoch=load_epoch, optimizer=optimizer)
    elif pretained:
        # NOTE(review): GetPretrainedModel is not defined in this file — TODO supply it.
        pretrained_model_state_dic = GetPretrainedModel(name='resnet')
        net.load_state(state_dict=pretrained_model_state_dic, strict=False)
    else:
        logging.info("Train from scratch using random initialization")
    metrics = metric.MetricList(metric.Accuracy(topk=1, name="acc-top1"),
                                metric.Accuracy(topk=5, name="acc-top5"))
    lr_scheduler = MultiFactorScheduler(steps=[300, 1000], base_lr=0.1, factor=0.1)
    tr_iter, ts_iter = dataiter_factory.creat(name='ucf101',
                                              data_root='../../dataset/UCF101',
                                              batch_size=batch_size,
                                              )
    net.fit(train_iter=tr_iter, optimizer=optimizer, lr_scheduler=lr_scheduler,
            eval_iter=ts_iter, metrics=metrics,
            epoch_start=0, epoch_end=100)