import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import numpy as np
import pandas as pd
import os
import cv2
import pickle
import lmdb
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
import apex
from apex import amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
if torch.__version__ >= '1.8':
import torch_npu
import random
import time
from .config import config
from .alexnet import SiameseAlexNet, _create_gt_mask, Criterion
from .dataset import ImagnetVIDDataset
from .custom_transforms import Normalize, ToTensor, RandomStretch, RandomCrop, CenterCrop, RandomBlur, ColorAug
def train(data_dir, workers, epochs):
CALCULATE_DEVICE = "npu:0"
torch.npu.set_device(CALCULATE_DEVICE)
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True
meta_data_path = os.path.join(data_dir, "meta_data.pkl")
meta_data = pickle.load(open(meta_data_path, 'rb'))
all_videos = [x[0] for x in meta_data]
train_videos, valid_videos = train_test_split(all_videos, test_size=1-config.train_ratio, random_state=config.seed)
random_crop_size = config.instance_size - 2 * config.total_stride
train_z_transforms = transforms.Compose([
RandomStretch(),
CenterCrop((config.exemplar_size, config.exemplar_size)),
ToTensor()
])
train_x_transforms = transforms.Compose([
RandomStretch(),
RandomCrop((random_crop_size, random_crop_size), config.max_translate),
ToTensor()
])
valid_z_transforms = transforms.Compose([
CenterCrop((config.exemplar_size, config.exemplar_size)),
ToTensor()
])
valid_x_transforms = transforms.Compose([ToTensor()])
db = lmdb.open(data_dir+'.lmdb', readonly=True, map_size=int(50e9))
train_dataset = ImagnetVIDDataset(db, train_videos, data_dir, train_z_transforms, train_x_transforms)
valid_dataset = ImagnetVIDDataset(db, valid_videos, data_dir, valid_z_transforms,
valid_x_transforms, training=False)
trainloader = DataLoader(train_dataset, batch_size=config.train_batch_size,
shuffle=True, pin_memory=False, num_workers=workers, drop_last=True)
validloader = DataLoader(valid_dataset, batch_size=config.valid_batch_size,
shuffle=False, pin_memory=False, num_workers=workers, drop_last=True)
if not os.path.exists(config.log_dir):
os.mkdir(config.log_dir)
summary_writer = SummaryWriter(config.log_dir)
model = SiameseAlexNet()
gt, weight = _create_gt_mask((config.train_response_sz, config.train_response_sz))
train_gt = torch.from_numpy(gt).npu()
train_weight = torch.from_numpy(weight).npu()
gt, weight = _create_gt_mask((config.response_sz, config.response_sz))
valid_gt = torch.from_numpy(gt).npu()
valid_weight = torch.from_numpy(weight).npu()
model.init_weights()
model = model.npu()
optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), lr=config.lr,
momentum=config.momentum, weight_decay=config.weight_decay)
criterion = Criterion().npu()
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128.0, combine_grad=True)
scheduler = StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
for epoch in range(epochs+1):
if epoch == epochs:
model.train()
for i, data in enumerate(trainloader):
exemplar_imgs, instance_imgs = data
exemplar_var, instance_var = Variable(exemplar_imgs.npu()), Variable(instance_imgs.npu())
if i < 10:
optimizer.zero_grad()
outputs = model([exemplar_var, instance_var])
loss = criterion(outputs, train_gt, train_weight)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
else:
with torch.autograd.profiler.profile(use_npu=True) as prof:
optimizer.zero_grad()
outputs = model([exemplar_var, instance_var])
loss = criterion(outputs, train_gt, train_weight)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
prof.export_chrome_trace("test/output/output.prof")
break
break
batch_time = AverageMeter('Time', ':6.3f')
train_loss = []
model.train()
end = time.time()
for i, data in enumerate(tqdm(trainloader)):
exemplar_imgs, instance_imgs = data
exemplar_var, instance_var = Variable(exemplar_imgs.npu()), Variable(instance_imgs.npu())
optimizer.zero_grad()
outputs = model([exemplar_var, instance_var])
loss = criterion(outputs, train_gt, train_weight)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
cost_time = time.time() - end
batch_time.update(cost_time)
step = epoch * len(trainloader) + i
summary_writer.add_scalar('train/loss', loss.data, step)
train_loss.append(loss.data)
end = time.time()
train_loss = torch.mean(torch.stack(train_loss))
valid_loss = []
model.eval()
for i, data in enumerate(tqdm(validloader)):
exemplar_imgs, instance_imgs = data
exemplar_var, instance_var = Variable(exemplar_imgs.npu()), Variable(instance_imgs.npu())
outputs = model((exemplar_var, instance_var))
loss = F.binary_cross_entropy_with_logits(outputs, valid_gt, valid_weight,
reduction='sum') / config.valid_batch_size
valid_loss.append(loss.data)
valid_loss = torch.mean(torch.stack(valid_loss))
print("EPOCH %d valid_loss : %.4f, train_loss : %.4f" % (epoch, valid_loss, train_loss))
summary_writer.add_scalar('valid/loss', valid_loss, (epoch + 1) * len(trainloader))
if epochs == 1:
print("Performance Test: ", "batch_size:", config.train_batch_size, 'Time: {:.3f}'.format(batch_time.avg),
'* FPS@all {:.3f}'.format(config.train_batch_size / batch_time.avg))
else:
torch.save(model.cpu().state_dict(), "./models/final/siamfc_{}.pth".format(epoch + 1))
model.npu()
scheduler.step()
def train_dist(args):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '11223'
data_dir = args.data
args.device = torch.device("npu:%d" % args.local_rank)
dist.init_process_group(backend='hccl', world_size=8, rank=args.local_rank)
args.is_master = args.local_rank == 0
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True
meta_data_path = os.path.join(data_dir, "meta_data.pkl")
meta_data = pickle.load(open(meta_data_path, 'rb'))
all_videos = [x[0] for x in meta_data]
train_videos, _ = train_test_split(all_videos, test_size=1 - config.train_ratio, random_state=config.seed)
random_crop_size = config.instance_size - 2 * config.total_stride
train_z_transforms = transforms.Compose([
RandomStretch(),
CenterCrop((config.exemplar_size, config.exemplar_size)),
ToTensor()
])
train_x_transforms = transforms.Compose([
RandomStretch(),
RandomCrop((random_crop_size, random_crop_size), config.max_translate),
ToTensor()
])
db = lmdb.open(data_dir + '.lmdb', readonly=True, map_size=int(50e9))
train_dataset = ImagnetVIDDataset(db, train_videos, data_dir, train_z_transforms, train_x_transforms)
train_sampler = DistributedSampler(train_dataset)
trainloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=config.train_batch_size,
shuffle=False, pin_memory=False, num_workers=args.workers, drop_last=True)
if args.is_master:
dist_log_dir = config.log_dir + "/dist"
if not os.path.exists(dist_log_dir):
os.mkdir(dist_log_dir)
summary_writer = SummaryWriter(dist_log_dir)
model = SiameseAlexNet()
model.init_weights()
torch.npu.set_device(args.local_rank)
model.to(args.device)
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr,
momentum=config.momentum, weight_decay=config.weight_decay)
criterion = Criterion().to(args.device)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128.0)
scheduler = StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
model = DDP(model, device_ids=[args.local_rank], broadcast_buffers=False)
FPSrecord = 0
for epoch in range(args.epoch):
batch_time = AverageMeter('Time', ':6.3f')
train_sampler.set_epoch(epoch)
train_loss = []
model.train()
torch.npu.synchronize()
end = time.time()
for i, data in enumerate(trainloader):
exemplar_imgs, instance_imgs = data
exemplar_var, instance_var = Variable(exemplar_imgs.to(args.device)), Variable(
instance_imgs.to(args.device))
optimizer.zero_grad()
outputs = model((exemplar_var, instance_var))
gt, weight = _create_gt_mask((config.train_response_sz, config.train_response_sz))
train_gt = torch.from_numpy(gt).to(args.device)
train_weight = torch.from_numpy(weight).to(args.device)
loss = criterion(outputs, train_gt, train_weight)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
torch.npu.synchronize()
cost_time = time.time() - end
batch_time.update(cost_time)
end = time.time()
step = epoch * len(trainloader) + i
if args.is_master:
summary_writer.add_scalar('train/loss', loss.data, step)
train_loss.append(loss.data)
train_loss = torch.mean(torch.stack(train_loss))
dist.all_reduce(train_loss, op=dist.reduce_op.SUM)
if args.is_master:
print("EPOCH %d train_loss : %.4f" % (epoch, train_loss / 8))
FPSrecord = FPSrecord + config.train_batch_size * 8 / batch_time.avg
print("[npu id:", args.local_rank, "]", "batch_size:", 8 * config.train_batch_size,
'Time: {:.3f}'.format(batch_time.avg), '* FPS@cur {:.3f}'.format(config.train_batch_size * 8 /
batch_time.avg))
if args.is_master:
if (epoch + 1) % 5 == 0:
torch.save(model.module.state_dict(), "./models/siamfc_{}.pth".format(epoch + 1))
scheduler.step()
if args.is_master:
print("Training Done - [npu id:", args.local_rank, "]", "batch_size:",
8 * config.train_batch_size, ' FPS@all {:.3f}'.format(FPSrecord / args.epoch))
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', start_count_index=100):
self.name = name
self.fmt = fmt
self.reset()
self.start_count_index = start_count_index
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
if self.count == 0:
self.N = n
self.val = val
self.count += n
if self.count > (self.start_count_index * self.N):
self.sum += val * n
self.avg = self.sum / (self.count - self.start_count_index * self.N)
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)