import math
import os
import os.path as osp
import random
import sys
import time
import argparse
import shutil
import warnings
from datetime import datetime
from apex import amp
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as tordata
# Optional Ascend auto-profiler. When torch_npu (or its profiler module) cannot
# be imported, fall back to a do-nothing class with the same call surface so the
# training loop can call start()/end() unconditionally.
try:
    from torch_npu.utils.profiler import Profile
except Exception:
    # Best effort: any import failure disables profiling. Catch Exception rather
    # than using a bare except, so Ctrl-C / SystemExit still propagate.
    print("Profile not in torch_npu.utils.profiler now.. Auto Profile disabled.", flush=True)

    class Profile:
        """No-op stand-in for ``torch_npu.utils.profiler.Profile``."""

        def __init__(self, *args, **kwargs):
            # Accept and ignore whatever configuration the real profiler takes.
            pass

        def start(self):
            """Do nothing (profiling disabled)."""
            pass

        def end(self):
            """Do nothing (profiling disabled)."""
            pass
from .network import TripletLoss, SetNet
from .utils import TripletSampler
class wrapperNet(nn.Module):
    """Hold another module under the attribute name ``module``.

    Registering the wrapped network as ``module`` means this wrapper's
    ``state_dict`` keys carry a ``module.`` prefix. That lets a checkpoint
    saved with prefixed keys be loaded into an unwrapped network (see
    ``Model.load``).
    """

    def __init__(self, module):
        super().__init__()
        # Registered as a child module; parameters appear as "module.<name>".
        self.module = module
class NoProfiling(object):
    """Context manager that does nothing.

    ``__enter__`` yields ``None``. ``__exit__`` returns a falsy value, so any
    exception raised inside the ``with`` body propagates unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Falsy return value: never suppress exceptions.
        return None
class AverageMeter(object):
    """Track the latest value plus a running mean that skips a warm-up period.

    The step size ``n`` of the first ``update`` after construction/reset is
    remembered as ``N``. Until ``count`` exceeds ``start_count_index * N``, only
    ``val`` and ``count`` change. After that, ``sum`` accumulates ``val * n``
    and ``avg`` is ``sum`` divided by the samples counted past the warm-up.
    """

    def __init__(self, name, fmt=':f', start_count_index=10):
        self.name = name
        self.fmt = fmt
        self.start_count_index = start_count_index
        self.reset()

    def reset(self):
        """Zero the value, average, sum and count (``N`` is left as is)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        if not self.count:
            # First update since reset: its step size defines the warm-up unit.
            self.N = n
        self.val = val
        self.count += n
        warmup = self.start_count_index * self.N
        if self.count > warmup:
            self.sum += val * n
            self.avg = self.sum / (self.count - warmup)

    def __str__(self):
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Print a one-line, tab-separated progress summary for a set of meters.

    Args:
        n_batches: total number of batches (an int); sets the width of the
            ``[  k/N]`` counter.
        meters: objects whose ``str()`` is appended to the line, in order.
        prefix: text placed before the batch counter.
    """

    def __init__(self, n_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(n_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/n_batches]`` followed by every meter, tab-separated."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries.extend(str(meter) for meter in self.meters)
        # Join into one readable line instead of printing the list's repr.
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, n_batches):
        """Return e.g. ``'[{:3d}/100]'`` for ``n_batches == 100``."""
        n_digits = len(str(n_batches))
        fmt = '{:' + str(n_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(n_batches) + ']'
class Model:
    """Training and feature-extraction driver around a ``SetNet`` encoder.

    Builds the encoder, a ``TripletLoss`` criterion and an optimizer on an NPU
    device. When more than one device is configured, it wraps the encoder in
    ``DistributedDataParallel``. It exposes ``fit`` (training), ``transform``
    (feature extraction), and ``save`` / ``load`` (checkpoints under
    ``checkpoint/<model_name>/``).
    """

    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 profiling,
                 start_step,
                 stop_step,
                 img_size=64):
        """Store hyper-parameters and build the encoder, loss and optimizer.

        Args:
            hidden_dim: forwarded to ``SetNet``.
            lr: optimizer learning rate (re-applied at the start of ``fit``).
            hard_or_full_trip: ``'hard'`` or ``'full'``. Selects which triplet
                loss term ``fit`` back-propagates; also forwarded to ``TripletLoss``.
            margin: forwarded to ``TripletLoss``.
            num_workers: DataLoader worker count.
            batch_size: a ``(P, M)`` pair, unpacked into ``self.P`` / ``self.M``.
                ``TripletLoss`` is built for ``P * M`` samples.
            restore_iter: iteration to resume from (0 means a fresh start).
            total_iter: iteration at which ``fit`` stops.
            save_name, model_name: used to build checkpoint paths.
            train_pid_num: stored only; not used by this class.
            frame_num: frames sampled per sequence in ``'random'`` sampling.
            train_source, test_source: datasets used by ``fit`` / ``transform``.
            profiling, start_step, stop_step, img_size: stored only; not used
                by this class.
        """
        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size
        self.profiling = profiling
        self.start_step = start_step
        self.stop_step = stop_step
        self.restore_iter = restore_iter
        self.total_iter = total_iter
        self.img_size = img_size

        try:
            # Multi-device path: requires an initialised process group and a
            # config module listing the visible devices.
            local_rank = torch.distributed.get_rank()
            from config import conf_8p as conf
            device_str = conf['ASCEND_VISIBLE_DEVICES']
            # NOTE(review): assumes a comma-separated list of single-character
            # ids such as "0,1,2,3" (length 2k-1 for k devices) - confirm.
            self.device_count = len(device_str) // 2 + 1
        except Exception:
            # Single-device fallback (no process group or no config).
            local_rank = torch.npu.current_device()
            self.device_count = 1
        self.local_device = f'npu:{local_rank}'
        print(f'----Using device:{local_rank}----')

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder.to(self.local_device)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss.to(self.local_device)

        # ALLOW_FP32 / ALLOW_HF32 (non-empty) -> fused NPU Adam without apex amp;
        # otherwise plain Adam wrapped by apex amp O2 with a static loss scale.
        use_fp32 = bool(os.getenv('ALLOW_FP32') or os.getenv('ALLOW_HF32'))
        param_groups = [{'params': self.encoder.parameters()}]
        if use_fp32:
            import torch_npu
            self.optimizer = torch_npu.optim.NpuFusedAdam(param_groups, lr=self.lr)
        else:
            self.optimizer = optim.Adam(param_groups, lr=self.lr)
            self.encoder, self.optimizer = amp.initialize(
                self.encoder, self.optimizer, opt_level="O2", loss_scale=32.0)

        if self.device_count > 1:
            print("Let's use", self.device_count, "NPUs!")
            self.encoder = nn.parallel.DistributedDataParallel(
                self.encoder, broadcast_buffers=False, device_ids=[local_rank])

        # Per-iteration statistics collected by ``fit`` and flushed every 50 iters.
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01
        # 'random' (used by fit) or 'all' (used by transform); read by collate_fn.
        self.sample_type = 'all'

    def collate_fn(self, batch):
        """Collate dataset items into ``[seqs, view, seq_type, label, batch_frame]``.

        Each item is indexed as ``(features, frame_set, view, seq_type, label)``.
        ``features`` is a list of objects exposing ``.values`` (and ``.loc``
        when sampling randomly).

        * ``sample_type == 'random'``: each feature is resampled to
          ``self.frame_num`` frames, drawn with replacement. The result is
          stacked per feature into one array; ``batch_frame`` is ``None``.
        * otherwise: the batch is split into contiguous chunks, one per device.
          Each chunk's frames are concatenated along axis 0 and zero-padded to
          the largest chunk total, giving per feature an array of shape
          ``(n_chunks, max_sum_frame, ...)``. ``batch_frame`` is an
          ``(n_chunks, batch_per_npu)`` array of per-sequence frame counts;
          missing slots in the last chunk are 0.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [item[0] for item in batch]
        frame_sets = [item[1] for item in batch]
        view = [item[2] for item in batch]
        seq_type = [item[3] for item in batch]
        label = [item[4] for item in batch]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(sorted(frame_set), k=self.frame_num)
                return [feature.loc[frame_id_list].values for feature in sample]
            return [feature.values for feature in sample]

        selected = [select_frame(i) for i in range(batch_size)]

        if self.sample_type == 'random':
            seqs = [np.asarray([selected[i][j] for i in range(batch_size)])
                    for j in range(feature_num)]
        else:
            npu_num = min(self.device_count, batch_size)
            batch_per_npu = math.ceil(batch_size / npu_num)
            # Only create non-empty chunks. With e.g. 5 samples on 4 devices,
            # chunking by 2 yields 3 chunks; a 4th, empty chunk would make
            # np.concatenate fail. Unchanged whenever no chunk would be empty.
            npu_num = math.ceil(batch_size / batch_per_npu)
            chunks = [range(k * batch_per_npu, min((k + 1) * batch_per_npu, batch_size))
                      for k in range(npu_num)]
            batch_frames = [[len(frame_sets[i]) for i in chunk] for chunk in chunks]
            # Only the last chunk can be short; pad it so the array is rectangular.
            batch_frames[-1].extend([0] * (batch_per_npu - len(batch_frames[-1])))
            max_sum_frame = np.max([np.sum(frames) for frames in batch_frames])
            per_feature = [[np.concatenate([selected[i][j] for i in chunk], 0) for chunk in chunks]
                           for j in range(feature_num)]
            # The pad spec assumes each concatenated array is 3-D, frames first.
            seqs = [np.asarray([
                np.pad(arr,
                       ((0, max_sum_frame - arr.shape[0]), (0, 0), (0, 0)),
                       'constant',
                       constant_values=0)
                for arr in chunk_arrays])
                for chunk_arrays in per_feature]
            batch[4] = np.asarray(batch_frames)
        batch[0] = seqs
        return batch

    def fit(self):
        """Train the encoder with triplet loss until ``total_iter`` is reached.

        Resumes from checkpoint ``restore_iter`` when it is non-zero. Prints
        statistics every 50 iterations and saves a checkpoint every 1000.

        Raises:
            ValueError: if ``hard_or_full_trip`` is neither ``'hard'`` nor ``'full'``.
        """
        if self.hard_or_full_trip not in ('hard', 'full'):
            raise ValueError(
                "hard_or_full_trip must be 'hard' or 'full', got {!r}".format(self.hard_or_full_trip))
        batch_time = AverageMeter('Time', ':6.3f')
        fps = AverageMeter('FPS', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        hard_loss_mean = AverageMeter('Hard_Loss', ':.6e', start_count_index=0)
        full_loss_mean = AverageMeter('Full_Loss', ':.6e', start_count_index=0)
        p_full_loss_num = AverageMeter('Full_Loss_Num', ':6.3e', start_count_index=0)

        if self.restore_iter != 0:
            self.load(self.restore_iter)
        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        local_rank = torch.distributed.get_rank() if self.device_count > 1 else 0

        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)
        train_label_set = sorted(self.train_source.label_set)

        # Loop invariants: the FPS scale factor (configured frames relative to
        # 30) and the precision mode selected by environment variables.
        from config import conf_8p as conf
        org_frame_num = 30
        frame_aug_rate = conf['model']['frame_num'] / org_frame_num
        use_fp32 = bool(os.getenv('ALLOW_FP32') or os.getenv('ALLOW_HF32'))

        _time1 = datetime.now()
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, fps, data_time, hard_loss_mean, full_loss_mean, p_full_loss_num],
            prefix="Iter[{}]".format(self.restore_iter))
        start_time = time.time()
        profile = Profile(start_step=int(os.getenv('PROFILE_START_STEP', 10)),
                          profile_type=os.getenv('PROFILE_TYPE'))
        for iter_i, _t_data in enumerate(train_loader):
            data_time.update(time.time() - start_time)
            seq, view, seq_type, label, batch_frame = _t_data
            self.restore_iter += 1
            profile.start()
            self.optimizer.zero_grad()
            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            feature = self.encoder(*seq, batch_frame)

            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            # NOTE(review): the label repeat is done on the host and then moved
            # back to the device - presumably a device-op workaround; confirm.
            triplet_label = target_label.unsqueeze(0).cpu().repeat(triplet_feature.size(0), 1)
            triplet_feature = triplet_feature.to(self.local_device)
            triplet_label = triplet_label.to(self.local_device)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            else:
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            if loss > 1e-10:
                if use_fp32:
                    loss.backward()
                else:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                self.optimizer.step()
            else:
                print('loss very small at: ', iter_i)
            profile.end()

            if self.restore_iter % 50 == 0:
                print(f"[{local_rank}]:", datetime.now() - _time1)
                _time1 = datetime.now()
                print(f"[{local_rank}]: ", 'iter {}:'.format(self.restore_iter), end='')
                self.mean_dist = np.mean(self.dist_list)
                print('mean_dist={0:.8f}'.format(self.mean_dist))
                hard_loss_mean.update(np.mean(self.hard_loss_metric), self.P * self.M)
                full_loss_mean.update(np.mean(self.full_loss_metric), self.P * self.M)
                p_full_loss_num.update(np.mean(self.full_loss_num), self.P * self.M)
                progress.display(self.restore_iter)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            time_spend = time.time() - start_time
            batch_time.update(time_spend)
            fps.update(self.device_count * self.P * self.M * frame_aug_rate / time_spend)
            if self.restore_iter < 5:
                print('Iter_time: {}'.format((time.time() - start_time)))
            start_time = time.time()
            if self.restore_iter % 1000 == 0:
                self.save()
            # '>=' so a run resumed past total_iter still terminates.
            if self.restore_iter >= self.total_iter:
                break

    def ts2var(self, x):
        """Move tensor ``x`` to this model's device (blocking copy)."""
        return autograd.Variable(x).to(self.local_device, non_blocking=False)

    def np2var(self, x):
        """Wrap numpy array ``x`` as a tensor on this model's device."""
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1, pre_process=False):
        """Extract flattened encoder features for every sequence of a dataset.

        Args:
            flag: ``'test'`` selects ``test_source``; anything else selects
                ``train_source``.
            batch_size: DataLoader batch size (sequential order).
            pre_process: accepted for API compatibility; unused.

        Returns:
            ``(features, views, seq_types, labels)``. ``features`` is the
            row-wise concatenation of ``feature.view(n, -1)`` over all batches.
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)
        feature_list = []
        view_list = []
        seq_type_list = []
        label_list = []
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for i, (seq, view, seq_type, label, batch_frame) in enumerate(data_loader):
                for j in range(len(seq)):
                    seq[j] = self.np2var(seq[j]).float()
                if batch_frame is not None:
                    batch_frame = self.np2var(batch_frame).int()
                feature = self.encoder(*seq, batch_frame)
                n, _, _ = feature.size()
                feature_list.append(feature.view(n, -1).data.cpu().numpy())
                view_list += view
                seq_type_list += seq_type
                label_list += label
                if self.P * self.M >= 64:
                    print(f'[{i:0>2d}/{len(data_loader)}]Batch Tested')
                elif i % 20 == 0:
                    print(f'[{i:0>4d}/{len(data_loader)}]Batch Tested')
        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def _checkpoint_path(self, kind, restore_iter):
        """Path of the ``kind`` ('encoder' / 'optimizer') checkpoint for an iteration."""
        return osp.join('checkpoint', self.model_name,
                        '{}-{:0>5}-{}.ptm'.format(self.save_name, restore_iter, kind))

    def save(self):
        """Save encoder and optimizer state for ``restore_iter`` (rank 0 only)."""
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        try:
            local_rank = torch.distributed.get_rank()
        except Exception:
            # No process group: treat as the single (rank 0) process.
            local_rank = 0
        if local_rank != 0:
            return
        torch.save(self.encoder.state_dict(), self._checkpoint_path('encoder', self.restore_iter))
        torch.save(self.optimizer.state_dict(), self._checkpoint_path('optimizer', self.restore_iter))

    def load(self, restore_iter):
        """Load encoder and optimizer state saved at ``restore_iter``.

        If the encoder's keys do not match (``RuntimeError``), the state is
        loaded through ``wrapperNet``. That handles checkpoints whose keys
        carry a ``module.`` prefix.
        """
        encoder_state = torch.load(self._checkpoint_path('encoder', restore_iter),
                                   map_location=torch.device('cpu'))
        try:
            self.encoder.load_state_dict(encoder_state)
        except RuntimeError:
            wrapped = wrapperNet(self.encoder)
            wrapped.load_state_dict(encoder_state)
            self.encoder = wrapped.module
            print('Checkpoint loaded through wrapper.')
        self.optimizer.load_state_dict(torch.load(
            self._checkpoint_path('optimizer', restore_iter),
            map_location=torch.device('cpu')))