05360171创建于 2022年3月18日历史提交
diff --git a/Learner.py b/Learner.py
index 1c5e38f..a562947 100644
--- a/Learner.py
+++ b/Learner.py
@@ -1,98 +1,174 @@
-from data.data_pipe import de_preprocess, get_train_loader, get_val_data
-from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
-from verifacation import evaluate
-import torch
-from torch import optim
+import time
+import math
+import sys
+
+import bcolz
 import numpy as np
+from PIL import Image
 from tqdm import tqdm
-from tensorboardX import SummaryWriter
 from matplotlib import pyplot as plt
-plt.switch_backend('agg')
-from utils import get_time, gen_plot, hflip_batch, separate_bn_paras
-from PIL import Image
+import torch
+from torch import optim
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torchvision import transforms as trans
-import math
-import bcolz
+from tensorboardX import SummaryWriter
+import apex
+from apex import amp
+
+from data.data_pipe import de_preprocess, get_train_loader, get_val_data, get_val_pair
+from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
+from verifacation import evaluate
+from utils import get_time, gen_plot, hflip_batch, separate_bn_paras
+
+plt.switch_backend('agg')
+
+
+# nohup时不会延迟打印
+def flush_print(func):
+    def new_print(*args, **kwargs):
+        func(*args, **kwargs)
+        sys.stdout.flush()
+
+    return new_print
+
+
+print = flush_print(print)
+
 
 class face_learner(object):
     def __init__(self, conf, inference=False):
-        print(conf)
+        # 分布式环境下只在主节点下输出日志,单卡则正常输出
+        if conf.is_master_node:
+            print(conf)
+        self.start_epoch = conf.start_epoch
+        self.inference = inference
+
         if conf.use_mobilfacenet:
             self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
             print('MobileFaceNet model generated')
         else:
             self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
             print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
-        
+
+        # 去除分布式得到模型权重带来的影响
+        self.model_without_ddp = self.model
+
         if not inference:
             self.milestones = conf.milestones
-            self.loader, self.class_num = get_train_loader(conf)        
+            self.loader, self.class_num = get_train_loader(conf)
 
             self.writer = SummaryWriter(conf.log_path)
-            self.step = 0
+            self.step = 1
             self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num).to(conf.device)
 
-            print('two model heads generated')
+            # 分布式环境下只在主节点下输出日志,单卡则正常输出
+            if conf.is_master_node:
+                print('two model heads generated')
 
             paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
-            
+
             if conf.use_mobilfacenet:
                 self.optimizer = optim.SGD([
-                                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
-                                    {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
-                                    {'params': paras_only_bn}
-                                ], lr = conf.lr, momentum = conf.momentum)
+                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
+                    {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
+                    {'params': paras_only_bn}
+                ], lr=conf.lr, momentum=conf.momentum)
             else:
                 self.optimizer = optim.SGD([
-                                    {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
-                                    {'params': paras_only_bn}
-                                ], lr = conf.lr, momentum = conf.momentum)
-            print(self.optimizer)
-#             self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
-
-            print('optimizers generated')    
-            self.board_loss_every = len(self.loader)//100
-            self.evaluate_every = len(self.loader)//10
-            self.save_every = len(self.loader)//5
-            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(self.loader.dataset.root.parent)
+                    {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
+                    {'params': paras_only_bn}
+                ], lr=conf.lr, momentum=conf.momentum)
+
+            # 分布式环境下只在主节点下输出日志,单卡则正常输出
+            if conf.is_master_node:
+                print(self.optimizer)
+                print('optimizers generated')
+
+            #  print step 和 eval step
+            self.board_loss_every = min(500, conf.max_iter // 10) if conf.max_iter != -1 else 500
+            self.evaluate_every = min(5000, conf.max_iter) if conf.max_iter != -1 else len(self.loader)
+
+            # 只需要在lfw做验证即可,故修改此处,减少内存占用
+            self.lfw, self.lfw_issame = get_val_pair(self.loader.dataset.root.parent, 'lfw')
+
+            # apex amp 半精度优化
+            if conf.use_amp:
+                if conf.device_type == 'npu':
+                    self.optimizer = apex.optimizers.NpuFusedSGD([
+                        {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
+                        {'params': paras_only_bn}
+                    ], lr=conf.lr, momentum=conf.momentum)
+                    self.model, self.optimizer = amp.initialize(self.model,
+                                                                self.optimizer,
+                                                                opt_level=conf.opt_level,
+                                                                loss_scale=conf.loss_scale,
+                                                                combine_grad=True)
+                else:
+                    self.model, self.optimizer = amp.initialize(self.model,
+                                                                self.optimizer,
+                                                                opt_level=conf.opt_level,
+                                                                loss_scale=conf.loss_scale)
+
+            # 分布式训练
+            if conf.distributed:
+                if conf.use_amp and conf.device_type == 'gpu':
+                    from apex.parallel import DistributedDataParallel as DDP
+                    self.model = DDP(self.model)
+                    self.model_without_ddp = self.model.module
+                else:
+                    self.model = DistributedDataParallel(self.model, device_ids=[conf.device])
+                    self.model_without_ddp = self.model.module
+
         else:
             self.threshold = conf.threshold
-    
-    def save_state(self, conf, accuracy, to_save_folder=False, extra=None, model_only=False):
+
+    def save_state(self, conf, accuracy, to_save_folder=False, extra=None, epoch=0):
         if to_save_folder:
             save_path = conf.save_path
         else:
             save_path = conf.model_path
-        torch.save(
-            self.model.state_dict(), save_path /
-            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))
-        if not model_only:
-            torch.save(
-                self.head.state_dict(), save_path /
-                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))
+        if conf.is_master_node:
+            ckpt = {
+                'model': self.model_without_ddp.state_dict(),
+                'head': self.head.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+                'epoch': epoch,
+                'config': conf,
+            }
             torch.save(
-                self.optimizer.state_dict(), save_path /
-                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)))
-    
+                ckpt,
+                save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra))
+            )
+
+    # 之前版本所需的加载模型权重接口
     def load_state(self, conf, fixed_str, from_save_folder=False, model_only=False):
         if from_save_folder:
             save_path = conf.save_path
         else:
-            save_path = conf.model_path            
-        self.model.load_state_dict(torch.load(save_path/'model_{}'.format(fixed_str)))
+            save_path = conf.model_path
+        self.model.load_state_dict(torch.load(save_path / 'model_{}'.format(fixed_str), map_location='cpu'), )
         if not model_only:
-            self.head.load_state_dict(torch.load(save_path/'head_{}'.format(fixed_str)))
-            self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))
-        
+            self.head.load_state_dict(torch.load(save_path / 'head_{}'.format(fixed_str), map_location='cpu'))
+            self.optimizer.load_state_dict(torch.load(save_path / 'optimizer_{}'.format(fixed_str), map_location='cpu'))
+
+    def load_state_dict(self, weights_path, is_finetune=False):
+        ckpt = torch.load(weights_path, map_location='cpu')
+        self.model_without_ddp.load_state_dict(ckpt['model'])
+        if not self.inference and not is_finetune:
+            self.head.load_state_dict(ckpt['head'])
+            self.optimizer.load_state_dict(ckpt['optimizer'])
+            self.start_epoch = ckpt['start_epoch'] + 1
+
     def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
         self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
         self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
         self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
-#         self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
-#         self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
-#         self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)
-        
-    def evaluate(self, conf, carray, issame, nrof_folds = 5, tta = False):
+        print('========================')
+        print(f'{db_name}_accuracy: \t{accuracy}')
+        print(f'{db_name}_best_threshold: \t{best_threshold}')
+        print('========================')
+
+    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
         self.model.eval()
         idx = 0
         embeddings = np.zeros([len(carray), conf.embedding_size])
@@ -102,132 +178,110 @@ class face_learner(object):
                 if tta:
                     fliped = hflip_batch(batch)
                     emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
-                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
+                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).detach().cpu().numpy()
                 else:
-                    embeddings[idx:idx + conf.batch_size] = self.model(batch.to(conf.device)).cpu()
+                    embeddings[idx:idx + conf.batch_size] = self.model(batch.to(conf.device)).detach().cpu().numpy()
                 idx += conf.batch_size
             if idx < len(carray):
-                batch = torch.tensor(carray[idx:])            
+                batch = torch.tensor(carray[idx:])
                 if tta:
                     fliped = hflip_batch(batch)
                     emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
-                    embeddings[idx:] = l2_norm(emb_batch)
+                    embeddings[idx:] = l2_norm(emb_batch).detach().cpu().numpy()
                 else:
-                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
+                    embeddings[idx:] = self.model(batch.to(conf.device)).detach().cpu().numpy()
         tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
         buf = gen_plot(fpr, tpr)
         roc_curve = Image.open(buf)
         roc_curve_tensor = trans.ToTensor()(roc_curve)
         return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
-    
-    def find_lr(self,
-                conf,
-                init_value=1e-8,
-                final_value=10.,
-                beta=0.98,
-                bloding_scale=3.,
-                num=None):
-        if not num:
-            num = len(self.loader)
-        mult = (final_value / init_value)**(1 / num)
-        lr = init_value
-        for params in self.optimizer.param_groups:
-            params['lr'] = lr
-        self.model.train()
-        avg_loss = 0.
-        best_loss = 0.
-        batch_num = 0
-        losses = []
-        log_lrs = []
-        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
-
-            imgs = imgs.to(conf.device)
-            labels = labels.to(conf.device)
-            batch_num += 1          
-
-            self.optimizer.zero_grad()
-
-            embeddings = self.model(imgs)
-            thetas = self.head(embeddings, labels)
-            loss = conf.ce_loss(thetas, labels)          
-          
-            #Compute the smoothed loss
-            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
-            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
-            smoothed_loss = avg_loss / (1 - beta**batch_num)
-            self.writer.add_scalar('smoothed_loss', smoothed_loss,batch_num)
-            #Stop if the loss is exploding
-            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
-                print('exited with best_loss at {}'.format(best_loss))
-                plt.plot(log_lrs[10:-5], losses[10:-5])
-                return log_lrs, losses
-            #Record the best loss
-            if smoothed_loss < best_loss or batch_num == 1:
-                best_loss = smoothed_loss
-            #Store the values
-            losses.append(smoothed_loss)
-            log_lrs.append(math.log10(lr))
-            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
-            #Do the SGD step
-            #Update the lr for the next step
-
-            loss.backward()
-            self.optimizer.step()
-
-            lr *= mult
-            for params in self.optimizer.param_groups:
-                params['lr'] = lr
-            if batch_num > num:
-                plt.plot(log_lrs[10:-5], losses[10:-5])
-                return log_lrs, losses    
 
     def train(self, conf, epochs):
         self.model.train()
-        running_loss = 0.            
-        for e in range(epochs):
+        running_loss = 0.
+        for e in range(self.start_epoch, epochs):
+            self.step = 1
             print('epoch {} started'.format(e))
             if e == self.milestones[0]:
                 self.schedule_lr()
             if e == self.milestones[1]:
-                self.schedule_lr()      
+                self.schedule_lr()
             if e == self.milestones[2]:
-                self.schedule_lr()                                 
-            for imgs, labels in tqdm(iter(self.loader)):
+                self.schedule_lr()
+            res = []
+
+            # 分布式训练下,保证每块device所获得的数据不同
+            if conf.distributed:
+                self.loader.set_epoch(e)
+
+            # 开始训练
+            for imgs, labels in iter(self.loader):
+                # 同步流
+                if conf.device_type == 'gpu':
+                    torch.cuda.synchronize()
+                elif conf.device_type == 'npu':
+                    stream = torch.npu.current_stream()
+                    stream.synchronize()
+
+                start = time.time()
                 imgs = imgs.to(conf.device)
                 labels = labels.to(conf.device)
                 self.optimizer.zero_grad()
                 embeddings = self.model(imgs)
                 thetas = self.head(embeddings, labels)
+                end = time.time()
+                res.append(end - start)
                 loss = conf.ce_loss(thetas, labels)
-                loss.backward()
-                running_loss += loss.item()
+
+                # 是否使用apex半精度优化进行反向传播
+                if conf.use_amp:
+                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                    running_loss += scaled_loss.item()
+                else:
+                    loss.backward()
+                    running_loss += loss.item()
+
                 self.optimizer.step()
-                
-                if self.step % self.board_loss_every == 0 and self.step != 0:
+
+                # 分布式环境下只在主节点下输出日志,单卡则正常输出
+                if conf.is_master_node and self.step % self.board_loss_every == 0 and self.step != 0:
                     loss_board = running_loss / self.board_loss_every
+                    now_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
+                    print(f'time:[{now_time}]\tstep: [{self.step}]\ttrain_loss: {loss_board}')
                     self.writer.add_scalar('train_loss', loss_board, self.step)
                     running_loss = 0.
-                
+
                 if self.step % self.evaluate_every == 0 and self.step != 0:
-                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.agedb_30, self.agedb_30_issame)
-                    self.board_val('agedb_30', accuracy, best_threshold, roc_curve_tensor)
                     accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.lfw, self.lfw_issame)
                     self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
-                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.cfp_fp, self.cfp_fp_issame)
-                    self.board_val('cfp_fp', accuracy, best_threshold, roc_curve_tensor)
+
                     self.model.train()
-                if self.step % self.save_every == 0 and self.step != 0:
-                    self.save_state(conf, accuracy)
-                    
+
+                    # 分布式环境下只在主节点下保存权重,单卡则正常输出
+                    if conf.is_master_node:
+                        self.save_state(conf, accuracy, epoch=e)
+
+                if self.step == conf.max_iter:
+                    break
                 self.step += 1
-                
-        self.save_state(conf, accuracy, to_save_folder=True, extra='final')
+
+            time_sum = sum(res)
+            # 分布式环境下只在主节点下输出日志, 单卡则正常保存
+            if conf.is_master_node:
+                print('***************************************')
+                print("fps: %f" % ((len(res) * self.loader.batch_size) / time_sum))
+                print('***************************************')
+
+        # 分布式环境下只在主节点下保存权重, 单卡则正常保存
+        if conf.is_master_node:
+            self.save_state(conf, accuracy, to_save_folder=True, extra='final', epoch=epochs)
 
     def schedule_lr(self):
-        for params in self.optimizer.param_groups:                 
+        for params in self.optimizer.param_groups:
             params['lr'] /= 10
         print(self.optimizer)
-    
+
     def infer(self, conf, faces, target_embs, tta=False):
         '''
         faces : list of PIL Image
@@ -242,12 +296,12 @@ class face_learner(object):
                 emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                 emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                 embs.append(l2_norm(emb + emb_mirror))
-            else:                        
+            else:
                 embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
         source_embs = torch.cat(embs)
-        
-        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1,0).unsqueeze(0)
+
+        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
         dist = torch.sum(torch.pow(diff, 2), dim=1)
         minimum, min_idx = torch.min(dist, dim=1)
-        min_idx[minimum > self.threshold] = -1 # if no match, set idx to -1
-        return min_idx, minimum               
\ No newline at end of file
+        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
+        return min_idx, minimum
diff --git a/config.py b/config.py
index df514eb..401c0d1 100644
--- a/config.py
+++ b/config.py
@@ -48,4 +48,8 @@ def get_config(training = True):
         #when inference, at maximum detect 10 faces in one image, my laptop is slow
         conf.min_face_size = 30 
         # the larger this value, the faster deduction, comes with tradeoff in small faces
+    # distributed config
+    conf.distributed = 0
+    # apex config
+    conf.use_amp = 0
     return conf
\ No newline at end of file
diff --git a/data/data_pipe.py b/data/data_pipe.py
index bd67f02..a4a74f0 100644
--- a/data/data_pipe.py
+++ b/data/data_pipe.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from torch.utils.data import Dataset, ConcatDataset, DataLoader
+from torch.utils.data import Dataset, ConcatDataset, DataLoader, DistributedSampler
 from torchvision import transforms as trans
 from torchvision.datasets import ImageFolder
 from PIL import Image, ImageFile
@@ -9,12 +9,11 @@ import cv2
 import bcolz
 import pickle
 import torch
-import mxnet as mx
 from tqdm import tqdm
 
 def de_preprocess(tensor):
     return tensor*0.5 + 0.5
-    
+
 def get_train_dataset(imgs_folder):
     train_transform = trans.Compose([
         trans.RandomHorizontalFlip(),
@@ -31,7 +30,7 @@ def get_train_loader(conf):
         print('ms1m loader generated')
     if conf.data_mode in ['vgg', 'concat']:
         vgg_ds, vgg_class_num = get_train_dataset(conf.vgg_folder/'imgs')
-        print('vgg loader generated')        
+        print('vgg loader generated')
     if conf.data_mode == 'vgg':
         ds = vgg_ds
         class_num = vgg_class_num
@@ -45,9 +44,24 @@ def get_train_loader(conf):
         class_num = vgg_class_num + ms1m_class_num
     elif conf.data_mode == 'emore':
         ds, class_num = get_train_dataset(conf.emore_folder/'imgs')
-    loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=True, pin_memory=conf.pin_memory, num_workers=conf.num_workers)
-    return loader, class_num 
-    
+    if conf.distributed:
+        ds_sample = DistributedSampler(ds, num_replicas=conf.world_size, rank=conf.rank)
+        loader = DataLoader(ds,
+                            batch_size=conf.batch_size,
+                            shuffle=False,
+                            pin_memory=False,
+                            num_workers=conf.num_workers,
+                            sampler=ds_sample,
+                            drop_last=True)
+    else:
+        loader = DataLoader(ds,
+                            batch_size=conf.batch_size,
+                            shuffle=True,
+                            pin_memory=conf.pin_memory,
+                            num_workers=conf.num_workers,
+                            drop_last=True)
+    return loader, class_num
+
 def load_bin(path, rootdir, transform, image_size=[112,112]):
     if not rootdir.exists():
         rootdir.mkdir()
@@ -55,8 +69,8 @@ def load_bin(path, rootdir, transform, image_size=[112,112]):
     data = bcolz.fill([len(bins), 3, image_size[0], image_size[1]], dtype=np.float32, rootdir=rootdir, mode='w')
     for i in range(len(bins)):
         _bin = bins[i]
-        img = mx.image.imdecode(_bin).asnumpy()
-        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+        img_np_arr = np.frombuffer(_bin, np.uint8)
+        img = cv2.imdecode(img_np_arr, cv2.IMREAD_COLOR)
         img = Image.fromarray(img.astype(np.uint8))
         data[i, ...] = transform(img)
         i += 1
@@ -109,10 +123,10 @@ def load_mx_rec(rec_path):
 #                 trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 #             ])
 #         self.class_num = self.labels[-1] + 1
-        
+
 #     def __len__(self):
 #         return self.length
-    
+
 #     def __getitem__(self, index):
 #         img = torch.tensor(self.imgs[index+1], dtype=torch.float)
 #         label = torch.tensor(self.labels[index+1], dtype=torch.long)
diff --git a/model.py b/model.py
index 10c6f67..d1944bd 100644
--- a/model.py
+++ b/model.py
@@ -269,9 +269,10 @@ class Arcface(Module):
         #      0<=theta+m<=pi
         #     -m<=theta<=pi-m
         cond_v = cos_theta - self.threshold
-        cond_mask = cond_v <= 0
         keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
-        cos_theta_m[cond_mask] = keep_val[cond_mask]
+        cos_theta_m = torch.where(cond_v <= 0, keep_val, cos_theta_m)
+        # cond_mask = cond_v <= 0
+        # cos_theta_m[cond_mask] = keep_val[cond_mask]
         output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
         idx_ = torch.arange(0, nB, dtype=torch.long)
         output[idx_, label] = cos_theta_m[idx_, label]
diff --git a/mtcnn.py b/mtcnn.py
index d8b94d1..af446a6 100644
--- a/mtcnn.py
+++ b/mtcnn.py
@@ -6,14 +6,19 @@ from mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
 from mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
 from mtcnn_pytorch.src.first_stage import run_first_stage
 from mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# device = 'cpu'
+from config import get_config
 
 class MTCNN():
-    def __init__(self):
-        self.pnet = PNet().to(device)
-        self.rnet = RNet().to(device)
-        self.onet = ONet().to(device)
+    def __init__(self, config=None):
+        if not config:
+            conf = get_config()
+            self.device = conf.device
+        else:
+            conf = config
+            self.device = conf.device
+        self.pnet = PNet().to(self.device)
+        self.rnet = RNet().to(self.device)
+        self.onet = ONet().to(self.device)
         self.pnet.eval()
         self.rnet.eval()
         self.onet.eval()
@@ -82,7 +87,7 @@ class MTCNN():
         with torch.no_grad():
             # run P-Net on different scales
             for s in scales:
-                boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
+                boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0], device=self.device)
                 bounding_boxes.append(boxes)
 
             # collect boxes (and offsets, and scores) from different scales
@@ -102,7 +107,7 @@ class MTCNN():
             # STAGE 2
 
             img_boxes = get_image_boxes(bounding_boxes, image, size=24)
-            img_boxes = torch.FloatTensor(img_boxes).to(device)
+            img_boxes = torch.FloatTensor(img_boxes).to(self.device)
 
             output = self.rnet(img_boxes)
             offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
@@ -124,7 +129,7 @@ class MTCNN():
             img_boxes = get_image_boxes(bounding_boxes, image, size=48)
             if len(img_boxes) == 0: 
                 return [], []
-            img_boxes = torch.FloatTensor(img_boxes).to(device)
+            img_boxes = torch.FloatTensor(img_boxes).to(self.device)
             output = self.onet(img_boxes)
             landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
             offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
diff --git a/mtcnn_pytorch/src/first_stage.py b/mtcnn_pytorch/src/first_stage.py
index 55ed04a..e396355 100644
--- a/mtcnn_pytorch/src/first_stage.py
+++ b/mtcnn_pytorch/src/first_stage.py
@@ -4,10 +4,9 @@ import math
 from PIL import Image
 import numpy as np
 from .box_utils import nms, _preprocess
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# device = 'cpu'
 
-def run_first_stage(image, net, scale, threshold):
+
+def run_first_stage(image, net, scale, threshold, device=None):
     """Run P-Net, generate bounding boxes, and do NMS.
 
     Arguments:
@@ -23,7 +22,8 @@ def run_first_stage(image, net, scale, threshold):
         a float numpy array of shape [n_boxes, 9],
             bounding boxes with scores and offsets (4 + 1 + 4).
     """
-
+    if not device:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     # scale the image and convert it to a float array
     width, height = image.size
     sw, sh = math.ceil(width*scale), math.ceil(height*scale)
diff --git a/mtcnn_pytorch/src/get_nets.py b/mtcnn_pytorch/src/get_nets.py
index ebf7748..0dc0e14 100644
--- a/mtcnn_pytorch/src/get_nets.py
+++ b/mtcnn_pytorch/src/get_nets.py
@@ -52,7 +52,7 @@ class PNet(nn.Module):
         self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
         self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
 
-        weights = np.load('mtcnn_pytorch/src/weights/pnet.npy')[()]
+        weights = np.load('mtcnn_pytorch/src/weights/pnet.npy', allow_pickle=True)[()]
         for n, p in self.named_parameters():
             p.data = torch.FloatTensor(weights[n])
 
@@ -97,7 +97,7 @@ class RNet(nn.Module):
         self.conv5_1 = nn.Linear(128, 2)
         self.conv5_2 = nn.Linear(128, 4)
 
-        weights = np.load('mtcnn_pytorch/src/weights/rnet.npy')[()]
+        weights = np.load('mtcnn_pytorch/src/weights/rnet.npy', allow_pickle=True)[()]
         for n, p in self.named_parameters():
             p.data = torch.FloatTensor(weights[n])
 
@@ -148,7 +148,7 @@ class ONet(nn.Module):
         self.conv6_2 = nn.Linear(256, 4)
         self.conv6_3 = nn.Linear(256, 10)
 
-        weights = np.load('mtcnn_pytorch/src/weights/onet.npy')[()]
+        weights = np.load('mtcnn_pytorch/src/weights/onet.npy', allow_pickle=True)[()]
         for n, p in self.named_parameters():
             p.data = torch.FloatTensor(weights[n])
 
diff --git a/requirements.txt b/requirements.txt
index a12d700..48b48b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,17 @@
-torch==0.4.0
-numpy==1.14.5
-matplotlib==2.1.2
-tqdm==4.23.4
-mxnet_cu90==1.2.1
-scipy==1.0.0
-bcolz==1.2.1
-easydict==1.7
-opencv_python==3.4.0.12
-Pillow==5.2.0
-mxnet==1.2.1.post1
-scikit_learn==0.19.2
-tensorboardX==1.2
-torchvision==0.2.1
+pytorch==1.5.0
+torchvision==0.6.0
+opencv-python==4.5.3.56
+easydict==1.9
+cython==0.29.24
+packaging>=21.0
+setuptools>=52.0.0
+bcolz
+mxnet==1.4.1
+tqdm==4.62.2
+scikit-learn==0.24.2
+tensorboardX==2.4
+matplotlib==3.4.3
+numpy==1.21.2
+onnx==1.10.1
+onnx-simplifier==0.3.6
+apex
\ No newline at end of file
diff --git a/train.py b/train.py
index 4a4aba2..b91e62b 100644
--- a/train.py
+++ b/train.py
@@ -1,32 +1,122 @@
+import os
+import random
+import argparse
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
 from config import get_config
 from Learner import face_learner
-import argparse
 
-# python train.py -net mobilefacenet -b 200 -w 4
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='for face verification')
-    parser.add_argument("-e", "--epochs", help="training epochs", default=20, type=int)
-    parser.add_argument("-net", "--net_mode", help="which network, [ir, ir_se, mobilefacenet]",default='ir_se', type=str)
-    parser.add_argument("-depth", "--net_depth", help="how many layers [50,100,152]", default=50, type=int)
-    parser.add_argument('-lr','--lr',help='learning rate',default=1e-3, type=float)
-    parser.add_argument("-b", "--batch_size", help="batch_size", default=96, type=int)
-    parser.add_argument("-w", "--num_workers", help="workers number", default=3, type=int)
-    parser.add_argument("-d", "--data_mode", help="use which database, [vgg, ms1m, emore, concat]",default='emore', type=str)
-    args = parser.parse_args()
+    parser.add_argument("-net_mode", help="which network, [ir, ir_se, mobilefacenet]",default='ir_se', type=str)
+    parser.add_argument("-net_depth", help="how many layers [50,100,152]", default=100, type=int)
+    parser.add_argument("-data_mode", help="use which database, [vgg, ms1m, emore, concat]",default='emore', type=str)
+    parser.add_argument("-data_path", help="data dir", default='./data/faces_emore', type=str)
+    parser.add_argument("-max_iter", help="max_iter", default=1000, type=int)
+    parser.add_argument("-start_epoch", help="train epoch", default=0, type=int)
+    parser.add_argument("-epochs", help="training epochs", default=20, type=int)
+    parser.add_argument("-batch_size", help="batch_size", default=96, type=int)
+    parser.add_argument('-lr', help='learning rate', default=1e-3, type=float)
+    parser.add_argument("-num_workers", help="workers number", default=3, type=int)
+    parser.add_argument("-seed", help="train seed", default=2021, type=int)
+
+    # finetune
+    parser.add_argument("-is_finetune", help="if finetune", default=0, type=int)
+    parser.add_argument("-resume", help="reload pretrained weights path", default="", type=str)
+    parser.add_argument("-num_classes", help="if you want finetune, set num_classes", default=0, type=int)
+
+    # distributed
+    parser.add_argument("-device_id", help="device_id", default=0)
+    parser.add_argument("-distributed", help="if distributed", default=0, type=int)
+    parser.add_argument("-backend", help="", default='nccl', type=str)
+    parser.add_argument("-dist_url", help="", default='tcp://127.0.0.1:41111', type=str)
+    parser.add_argument("-nodes", help="all of nodes", default=1, type=int, metavar='N')
+    parser.add_argument("-gpus", help="number of gpus per node", default=1, type=int)
+    parser.add_argument("-nr", help="ranking within the nodes", default=0, type=int)
+
+    # apex amp
+    parser.add_argument("-use_amp", help="if use amp", default=0, type=int)
+    parser.add_argument("-opt_level", help="apex amp level, [O1, O2]", default='O2', type=str)
+    parser.add_argument("-loss_scale", help="apex amp loss scale, [128.0, None]", default=128.0, type=float)
 
+    # device type gpu or npu
+    parser.add_argument("-device_type", help="device type, [gpu, npu, cpu]", default="gpu", type=str)
+
+    # init config
+    args = parser.parse_args()
     conf = get_config()
-    
+    conf.is_master_node = not args.distributed or args.device_id == 0
+    # set seed
+    set_seed(args.seed)
+
+    # model config
     if args.net_mode == 'mobilefacenet':
         conf.use_mobilfacenet = True
     else:
         conf.net_mode = args.net_mode
-        conf.net_depth = args.net_depth    
-    
+        conf.net_depth = args.net_depth
+
+    # train config
     conf.lr = args.lr
     conf.batch_size = args.batch_size
     conf.num_workers = args.num_workers
     conf.data_mode = args.data_mode
+    conf.emore_folder = Path(args.data_path)
+    conf.max_iter = args.max_iter
+    conf.start_epoch = args.start_epoch
+
+    # distributed or only one device
+    conf.distributed = args.distributed
+    conf.device_type = args.device_type
+    if args.device_type == 'gpu':
+        conf.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu")
+    elif args.device_type == 'npu':
+        conf.device = torch.device(f"npu:{args.device_id}" if torch.npu.is_available() else "cpu")
+        torch.npu.set_device(conf.device)
+    elif args.device_type == 'cpu':
+        # Untrustworthy
+        conf.device = torch.device('cpu')
+    else:
+        raise ValueError('device type error,please choice in ["gpu","npu","cpu"]')
+
+    # distributed config
+    if conf.distributed:
+        args.world_size = args.gpus * args.nodes  #
+        rank = args.nr * args.gpus + int(args.device_id)
+        dist.init_process_group(
+            backend=args.backend,
+            init_method=args.dist_url,
+            world_size=args.world_size,
+            rank=rank
+        )
+        conf.world_size = args.world_size
+        conf.rank = rank
+
+    # apex amp config
+    conf.use_amp = args.use_amp
+    conf.opt_level = args.opt_level
+    conf.loss_scale = args.loss_scale
+
     learner = face_learner(conf)
+    # 加载预训练模型
+    if args.resume:
+        learner.load_state_dict(args.resume, is_finetune=args.is_finetune)
+
+    # TODO if you want finetune, you just need to create train data loader
+    # if args.is_finetune:
+    #     learner.loader, learner.class_num = get_train_loader()
+    #     assert learner.class_num == args.num_classes
 
-    learner.train(conf, args.epochs)
\ No newline at end of file
+    learner.train(conf, args.epochs)
diff --git a/utils.py b/utils.py
index bb3f100..f34c929 100644
--- a/utils.py
+++ b/utils.py
@@ -67,7 +67,7 @@ def prepare_facebank(conf, model, mtcnn, tta = True):
     return embeddings, names
 
 def load_facebank(conf):
-    embeddings = torch.load(conf.facebank_path/'facebank.pth')
+    embeddings = torch.load(conf.facebank_path / 'facebank.pth', map_location='cpu')
     names = np.load(conf.facebank_path/'names.npy')
     return embeddings, names
 
diff --git a/verifacation.py b/verifacation.py
index ff7d5ed..556592a 100644
--- a/verifacation.py
+++ b/verifacation.py
@@ -29,7 +29,6 @@ from sklearn.decomposition import PCA
 import sklearn
 from scipy import interpolate
 import datetime
-import mxnet as mx
 
 def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca=0):
     assert (embeddings1.shape[0] == embeddings2.shape[0])