@@ -2,7 +2,7 @@ conf = {
"WORK_PATH": "./work",
"CUDA_VISIBLE_DEVICES": "0,1,2,3",
"data": {
- 'dataset_path': "your_dataset_path",
+ 'dataset_path': "./predata",
'resolution': '64',
'dataset': 'CASIA-B',
# In CASIA-B, data of subject #5 is incomplete.
@@ -1,20 +1,77 @@
-import math
+import math
import os
import os.path as osp
import random
import sys
from datetime import datetime
-
+# from apex import amp
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as tordata
+from tqdm import tqdm
+
+import argparse
+import shutil
+import warnings
+import time
+
+import torch.distributed
from .network import TripletLoss, SetNet
from .utils import TripletSampler
+class wrapperNet(nn.Module):
+    """Thin container that stores a network under the ``module`` attribute.
+
+    Because the wrapped network is registered as the ``module`` submodule,
+    the wrapper's ``state_dict`` keys carry one extra ``module.`` prefix.
+    No ``forward`` is defined here.
+    """
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+class AverageMeter(object):
+    """Keep the latest value of a metric together with its running mean.
+
+    Updates only start contributing to the mean once the accumulated sample
+    count exceeds ``start_count_index * N``, where ``N`` is the ``n`` passed
+    to the update that found the counter at zero. Earlier updates still
+    advance the counter and refresh ``val``.
+
+    Args:
+        name: Label printed in front of the values.
+        fmt: Format spec (e.g. ``':6.3f'``) applied to ``val`` and ``avg``.
+        start_count_index: Warm-up length, in multiples of ``N``.
+    """
+
+    def __init__(self, name, fmt=':f', start_count_index=10):
+        self.name = name
+        self.fmt = fmt
+        self.start_count_index = start_count_index
+        self.reset()
+
+    def reset(self):
+        """Zero the latest value, the mean, the running sum and the count."""
+        self.val = self.avg = self.sum = self.count = 0
+
+    def update(self, val, n=1):
+        """Record ``val`` as the latest value, weighted by ``n`` samples."""
+        if not self.count:
+            # The warm-up threshold is expressed in multiples of this weight.
+            self.N = n
+
+        self.val = val
+        self.count += n
+
+        warmup = self.start_count_index * self.N
+        if self.count <= warmup:
+            # Still warming up: leave sum and avg untouched.
+            return
+        self.sum += val * n
+        self.avg = self.sum / (self.count - warmup)
+
+    def __str__(self):
+        # Expands to e.g. '{name} {val:6.3f} ({avg:6.3f})' before the
+        # instance attributes are substituted.
+        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
+        return template.format(**self.__dict__)
+
+class ProgressMeter(object):
+    """Print a batch counter followed by the string form of several meters.
+
+    Args:
+        n_batches: Total number of batches; sets the counter's field width
+            and the number shown after the slash.
+        meters: Objects whose ``str()`` is printed after the counter.
+        prefix: Text placed in front of the counter.
+    """
+
+    def __init__(self, n_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(n_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        """Print one list holding the prefixed counter and every meter."""
+        header = self.prefix + self.batch_fmtstr.format(batch)
+        print([header, *map(str, self.meters)])
+
+    def _get_batch_fmtstr(self, n_batches):
+        """Return ``'[{:Wd}/<n_batches>]'``, W being the digit count of ``n_batches``."""
+        width = len(str(n_batches // 1))
+        slot = '{:%dd}' % width
+        return '[{}/{}]'.format(slot, slot.format(n_batches))
class Model:
def __init__(self,
@@ -53,18 +110,48 @@ class Model:
self.total_iter = total_iter
self.img_size = img_size
-
- self.encoder = SetNet(self.hidden_dim).float()
- self.encoder = nn.DataParallel(self.encoder)
- self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
- self.triplet_loss = nn.DataParallel(self.triplet_loss)
- self.encoder.cuda()
- self.triplet_loss.cuda()
-
- self.optimizer = optim.Adam([
- {'params': self.encoder.parameters()},
- ], lr=self.lr)
-
+
+ use_dist = False
+ '''
+ try:
+ local_rank = torch.distributed.get_rank()
+ except AssertionError: # Default process group is not initialized
+ use_dist = False
+ '''
+
+ if use_dist:
+ self.encoder = SetNet(self.hidden_dim).float()
+ self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
+ self.optimizer = optim.Adam([
+ {'params': self.encoder.parameters()},
+ ], lr=self.lr)
+
+ self.local_device = f'cpu:{local_rank}'
+ self.encoder.to(self.local_device)
+ self.triplet_loss.to(self.local_device)
+
+ # self.encoder,self.optimizer = amp.initialize(self.encoder,self.optimizer,opt_level="O2", loss_scale=32.0)
+
+ local_rank = torch.distributed.get_rank()
+ if torch.cpu.device_count() > 1:
+ print("Let's use",torch.cpu.device_count(),"CPUs!")
+ print('-----RANK=', local_rank)
+ self.encoder = nn.parallel.DistributedDataParallel(self.encoder, broadcast_buffers=False, device_ids=[local_rank])
+ else:
+ self.local_device = 'cpu'
+ self.encoder = SetNet(self.hidden_dim).float()
+ self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
+ self.triplet_loss = nn.DataParallel(self.triplet_loss)
+ self.encoder = self.encoder.cpu()
+ self.triplet_loss = self.triplet_loss.cpu()
+
+ self.optimizer = optim.Adam([
+ {'params': self.encoder.parameters()},
+ ], lr=self.lr)
+
+ # self.encoder,self.optimizer = amp.initialize(self.encoder,self.optimizer,opt_level="O2", loss_scale=64.0)
+ self.encoder = nn.DataParallel(self.encoder)
+
self.hard_loss_metric = []
self.full_loss_metric = []
self.full_loss_num = []
@@ -81,47 +168,52 @@ class Model:
view = [batch[i][2] for i in range(batch_size)]
seq_type = [batch[i][3] for i in range(batch_size)]
label = [batch[i][4] for i in range(batch_size)]
+
batch = [seqs, view, seq_type, label, None]
-
+
def select_frame(index):
sample = seqs[index]
frame_set = frame_sets[index]
+
if self.sample_type == 'random':
- frame_id_list = random.choices(frame_set, k=self.frame_num)
+ frame_list = sorted(list(frame_set))
+
+ frame_id_list = random.choices(frame_list, k=self.frame_num)
+
_ = [feature.loc[frame_id_list].values for feature in sample]
else:
_ = [feature.values for feature in sample]
return _
-
+
seqs = list(map(select_frame, range(len(seqs))))
if self.sample_type == 'random':
seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
else:
- gpu_num = min(torch.cuda.device_count(), batch_size)
- batch_per_gpu = math.ceil(batch_size / gpu_num)
+ npu_num = batch_size
+ batch_per_npu = math.ceil(batch_size / npu_num)
batch_frames = [[
len(frame_sets[i])
- for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
+ for i in range(batch_per_npu * _, batch_per_npu * (_ + 1))
if i < batch_size
- ] for _ in range(gpu_num)]
- if len(batch_frames[-1]) != batch_per_gpu:
- for _ in range(batch_per_gpu - len(batch_frames[-1])):
+ ] for _ in range(npu_num)]
+ if len(batch_frames[-1]) != batch_per_npu:
+ for _ in range(batch_per_npu - len(batch_frames[-1])):
batch_frames[-1].append(0)
- max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
+ max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(npu_num)])
seqs = [[
np.concatenate([
seqs[i][j]
- for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
+ for i in range(batch_per_npu * _, batch_per_npu * (_ + 1))
if i < batch_size
- ], 0) for _ in range(gpu_num)]
+ ], 0) for _ in range(npu_num)]
for j in range(feature_num)]
seqs = [np.asarray([
np.pad(seqs[j][_],
((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
'constant',
constant_values=0)
- for _ in range(gpu_num)])
+ for _ in range(npu_num)])
for j in range(feature_num)]
batch[4] = np.asarray(batch_frames)
@@ -129,6 +221,14 @@ class Model:
return batch
def fit(self):
+ is_8p = torch.cpu.device_count() > 1
+
+ batch_time = AverageMeter('Time', ':6.3f')
+ data_time = AverageMeter('Data', ':6.3f')
+ hard_loss_mean = AverageMeter('Hard_Loss', ':.6e', start_count_index=0)
+ full_loss_mean = AverageMeter('Full_Loss', ':.6e', start_count_index=0)
+ p_full_loss_num = AverageMeter('Full_Loss_Num', ':6.3e', start_count_index=0)
+
if self.restore_iter != 0:
self.load(self.restore_iter)
@@ -136,89 +236,113 @@ class Model:
self.sample_type = 'random'
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
+
+ local_rank = 'npu'
+ try:
+ local_rank = torch.distributed.get_rank()
+ except AssertionError:
+ pass
+
triplet_sampler = TripletSampler(self.train_source, self.batch_size)
+
train_loader = tordata.DataLoader(
dataset=self.train_source,
+ # shuffle=False,
+ # batch_size=self.batch_size,
+ # pin_memory=False,
batch_sampler=triplet_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers)
-
train_label_set = list(self.train_source.label_set)
train_label_set.sort()
-
_time1 = datetime.now()
- for seq, view, seq_type, label, batch_frame in train_loader:
+
+ progress = ProgressMeter(
+ len(train_loader),
+ [batch_time, data_time, hard_loss_mean, full_loss_mean, p_full_loss_num],
+ prefix="Iter[{}]".format(self.restore_iter))
+ start_time = time.time()
+
+ for iter_i, _t_data in enumerate(train_loader):
+ data_time.update(time.time() - start_time)
+
+ seq, view, seq_type, label, batch_frame = _t_data
+ # triplet_sampler.set_epoch(self.restore_iter)
+
self.restore_iter += 1
self.optimizer.zero_grad()
-
+
for i in range(len(seq)):
seq[i] = self.np2var(seq[i]).float()
if batch_frame is not None:
batch_frame = self.np2var(batch_frame).int()
-
- feature, label_prob = self.encoder(*seq, batch_frame)
-
+
+ feature = self.encoder(*seq, batch_frame)
+
target_label = [train_label_set.index(l) for l in label]
target_label = self.np2var(np.array(target_label)).long()
triplet_feature = feature.permute(1, 0, 2).contiguous()
- triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
+ triplet_label = target_label.unsqueeze(0).cpu().repeat(triplet_feature.size(0), 1)
+
+ triplet_feature = triplet_feature.cpu()
+ triplet_label = triplet_label.cpu()
+
(full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
) = self.triplet_loss(triplet_feature, triplet_label)
if self.hard_or_full_trip == 'hard':
loss = hard_loss_metric.mean()
elif self.hard_or_full_trip == 'full':
loss = full_loss_metric.mean()
-
+
self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
self.dist_list.append(mean_dist.mean().data.cpu().numpy())
-
+
if loss > 1e-9:
+ # with amp.scale_loss(loss,self.optimizer) as scaled_loss:
+ # scaled_loss.backward()
loss.backward()
self.optimizer.step()
if self.restore_iter % 1000 == 0:
- print(datetime.now() - _time1)
+ print(f"[{local_rank}]:", datetime.now() - _time1)
_time1 = datetime.now()
- if self.restore_iter % 100 == 0:
- self.save()
- print('iter {}:'.format(self.restore_iter), end='')
- print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
- print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
- print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
+ if self.restore_iter % 10 == 0:
+ print(f"[{local_rank}]: ", 'iter {}:'.format(self.restore_iter), end='')
+
self.mean_dist = np.mean(self.dist_list)
- print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
- print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
- print(', hard or full=%r' % self.hard_or_full_trip)
+ print('mean_dist={0:.8f}'.format(self.mean_dist))
+
+ hard_loss_mean.update(np.mean(self.hard_loss_metric), self.P * self.M)
+ full_loss_mean.update(np.mean(self.full_loss_metric), self.P * self.M)
+ p_full_loss_num.update(np.mean(self.full_loss_num), self.P * self.M)
+ progress.display(self.restore_iter)
+
sys.stdout.flush()
self.hard_loss_metric = []
self.full_loss_metric = []
self.full_loss_num = []
self.dist_list = []
-
- # Visualization using t-SNE
- # if self.restore_iter % 500 == 0:
- # pca = TSNE(2)
- # pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
- # for i in range(self.P):
- # plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
- # pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
- #
- # plt.show()
+
+ batch_time.update(time.time() - start_time)
+ start_time = time.time()
+
+ if self.restore_iter % 200 == 0:
+ self.save()
if self.restore_iter == self.total_iter:
break
-
+
def ts2var(self, x):
- return autograd.Variable(x).cuda()
-
+ # Wrap ``x`` with ``autograd.Variable`` and move the result to
+ # ``self.local_device`` (explicitly passing ``non_blocking=False``) instead
+ # of the ``.cuda()`` call used by the removed line.
+ return autograd.Variable(x).to(self.local_device, non_blocking=False)
+
def np2var(self, x):
+ # Turn the NumPy array ``x`` into a tensor with ``torch.from_numpy`` and
+ # pass it through ``ts2var``.
return self.ts2var(torch.from_numpy(x))
-
- def transform(self, flag, batch_size=1):
+
+ def transform(self, flag, bin_file_path=None, batch_size=1, output_path=None, pre_process=False, post_process=False):
self.encoder.eval()
source = self.test_source if flag == 'test' else self.train_source
self.sample_type = 'all'
@@ -233,26 +357,65 @@ class Model:
view_list = list()
seq_type_list = list()
label_list = list()
-
- for i, x in enumerate(data_loader):
+
+ test_len = len(data_loader)
+
+ for i, x in tqdm(enumerate(data_loader), total = len(data_loader)):
+ import time
+ cvt_time = time.time()
+
seq, view, seq_type, label, batch_frame = x
for j in range(len(seq)):
seq[j] = self.np2var(seq[j]).float()
if batch_frame is not None:
batch_frame = self.np2var(batch_frame).int()
- # print(batch_frame, np.sum(batch_frame))
-
- feature, _ = self.encoder(*seq, batch_frame)
+
+ if pre_process:
+ bin_img_path = os.path.abspath(bin_file_path + '/'+ f'{i:0>4d}.bin')
+
+ align_size = 100
+
+ # new pre-process align by repeat itself
+ cat_seq = None
+ seq[0] = seq[0].detach().cpu().float()
+ org_size = seq[0].shape[1]
+ if org_size < align_size:
+ pad_shape = list(seq[0].shape)
+ pad_shape[1] = align_size - org_size
+ pad_zeros = torch.zeros(pad_shape).float()
+ cat_seq = torch.cat([pad_zeros.float(), seq[0].float()], dim=1)
+ else:
+ cat_seq = seq[0].float()
+ while cat_seq.shape[1] < align_size:
+ cat_seq = torch.cat([cat_seq, seq[0].float()], dim=1)
+ cat_seq = cat_seq[:, :align_size, :, :]
+
+ cat_seq.numpy().tofile(bin_img_path)
+
+ continue # pre-processing, skip model calculation
+
+ # add post_process
+ feature = None
+ if post_process == False:
+ feature = self.encoder(*seq, batch_frame)
+ else:
+ feat = np.fromfile(output_path+ '/'+ f'{i:0>4d}_0.bin', dtype=np.float32)
+ feature = torch.Tensor(feat).float().cpu().view(1, -1, 256)
+
n, num_bin, _ = feature.size()
feature_list.append(feature.view(n, -1).data.cpu().numpy())
view_list += view
seq_type_list += seq_type
label_list += label
-
+ if pre_process:
+ return None
return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
def save(self):
os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
+ local_rank = torch.distributed.get_rank()
+ if local_rank != 0:
+ return
torch.save(self.encoder.state_dict(),
osp.join('checkpoint', self.model_name,
'{}-{:0>5}-encoder.ptm'.format(
@@ -262,11 +425,18 @@ class Model:
'{}-{:0>5}-optimizer.ptm'.format(
self.save_name, self.restore_iter)))
- # restore_iter: iteration index of the checkpoint to load
+ # restore_iter, iteration index of the checkpoint to load
def load(self, restore_iter):
+        """Restore optimizer and encoder state from a saved iteration.
+
+        Both files are read from ``checkpoint/<model_name>/`` and mapped to
+        the CPU. The optimizer state is loaded first, then the encoder.
+
+        Args:
+            restore_iter: Iteration index of the checkpoint to load.
+
+        Raises:
+            RuntimeError: If the encoder state matches neither the encoder
+                itself nor the encoder wrapped in ``wrapperNet``.
+        """
-        self.encoder.load_state_dict(torch.load(osp.join(
-            'checkpoint', self.model_name,
-            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
         self.optimizer.load_state_dict(torch.load(osp.join(
             'checkpoint', self.model_name,
-            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
+            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter)), map_location=torch.device('cpu')))
+        # Read the encoder checkpoint once; the fallback below reuses the
+        # same state dict instead of loading the file a second time.
+        encoder_state = torch.load(osp.join(
+            'checkpoint', self.model_name,
+            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter)), map_location=torch.device('cpu'))
+        try:
+            self.encoder.load_state_dict(encoder_state)
+        except RuntimeError:
+            # Retry through wrapperNet, which expects every key to carry one
+            # additional leading ``module.`` prefix.
+            wrapped = wrapperNet(self.encoder)
+            wrapped.load_state_dict(encoder_state)
+            self.encoder = wrapped.module
@@ -117,4 +117,4 @@ class SetNet(nn.Module):
feature = feature.matmul(self.fc_bin[0])
feature = feature.permute(1, 0, 2).contiguous()
- return feature, None
+ return feature
@@ -14,20 +14,21 @@ class TripletLoss(nn.Module):
n, m, d = feature.size()
hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1)
hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1)
-
+
dist = self.batch_dist(feature)
mean_dist = dist.mean(1).mean(1)
dist = dist.view(-1)
+
# hard
- hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
- hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
+ hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask.bool()).view(n, m, -1), 2)[0]
+ hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask.bool()).view(n, m, -1), 2)[0]
hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)
# non-zero full
- full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
- full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
+ full_hp_dist = torch.masked_select(dist, hp_mask.bool()).view(n, m, -1, 1)
+ full_hn_dist = torch.masked_select(dist, hn_mask.bool()).view(n, m, 1, -1)
full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
full_loss_metric_sum = full_loss_metric.sum(1)
@@ -39,7 +39,7 @@ def load_data(dataset_path, resolution, dataset, pid_num, pid_shuffle, cache=Tru
os.makedirs('partition', exist_ok=True)
np.save(pid_fname, pid_list)
- pid_list = np.load(pid_fname)
+ pid_list = np.load(pid_fname, allow_pickle=True)
train_list = pid_list[0]
test_list = pid_list[1]
train_source = DataSet(
@@ -15,11 +15,11 @@ class DataSet(tordata.Dataset):
self.label = label
self.cache = cache
self.resolution = int(resolution)
- self.cut_padding = int(float(resolution)/64*10)
+ self.cut_padding = int(float(resolution) / 64 * 10)
self.data_size = len(self.label)
self.data = [None] * self.data_size
self.frame_set = [None] * self.data_size
-
+
self.label_set = set(self.label)
self.seq_type_set = set(self.seq_type)
self.view_set = set(self.view)
@@ -57,16 +57,16 @@ class DataSet(tordata.Dataset):
if not self.cache:
data = [self.__loader__(_path) for _path in self.seq_dir[index]]
frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
- frame_set = list(set.intersection(*frame_set))
+ frame_set = sorted(list(set.intersection(*frame_set)))
elif self.data[index] is None:
data = [self.__loader__(_path) for _path in self.seq_dir[index]]
frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
- frame_set = list(set.intersection(*frame_set))
+ frame_set = sorted(list(set.intersection(*frame_set)))
self.data[index] = data
self.frame_set[index] = frame_set
else:
data = self.data[index]
- frame_set = self.frame_set[index]
+ frame_set = sorted(list(self.frame_set[index]))
return data, frame_set, self.view[
index], self.seq_type[index], self.label[index],
@@ -79,6 +79,7 @@ class DataSet(tordata.Dataset):
for _img_path in imgs
if osp.isfile(osp.join(flie_path, _img_path))]
num_list = list(range(len(frame_list)))
+
data_dict = xr.DataArray(
frame_list,
coords={'frame': num_list},
@@ -4,8 +4,8 @@ import numpy as np
def cuda_dist(x, y):
- x = torch.from_numpy(x).cuda()
- y = torch.from_numpy(y).cuda()
+ x = torch.from_numpy(x)
+ y = torch.from_numpy(y)
dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
dist = torch.sqrt(F.relu(dist))