diff --git a/config.py b/config.py
index a13f61d..c688274 100644
--- a/config.py
+++ b/config.py
@@ -2,7 +2,7 @@ conf = {
     "WORK_PATH": "./work",
     "CUDA_VISIBLE_DEVICES": "0,1,2,3",
     "data": {
-        'dataset_path': "your_dataset_path",
+        'dataset_path': "./predata",
         'resolution': '64',
         'dataset': 'CASIA-B',
         # In CASIA-B, data of subject #5 is incomplete.
diff --git a/model/model.py b/model/model.py
index 3f0737c..3056a1b 100644
--- a/model/model.py
+++ b/model/model.py
@@ -1,20 +1,77 @@
-import math
+import math
 import os
 import os.path as osp
 import random
 import sys
 from datetime import datetime
-
+# from apex import amp
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.autograd as autograd
 import torch.optim as optim
 import torch.utils.data as tordata
+from tqdm import tqdm
+
+import argparse
+import shutil
+import warnings
+import time
+
+import torch.distributed
 
 from .network import TripletLoss, SetNet
 from .utils import TripletSampler
 
+class wrapperNet(nn.Module):
+    def __init__(self, module):
+        super(wrapperNet, self).__init__()
+        self.module = module
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f', start_count_index=10):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+        self.start_count_index = start_count_index
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        if self.count == 0:
+            self.N = n
+
+        self.val = val
+        self.count += n
+        if self.count > (self.start_count_index * self.N):
+            self.sum += val * n
+            self.avg = self.sum / (self.count - self.start_count_index * self.N)
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+    def __init__(self, n_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(n_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print(entries)
+
+    def _get_batch_fmtstr(self, n_batches):
+        n_digits = len(str(n_batches // 1))
+        fmt = '{:' + str(n_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(n_batches) + ']'
 
 class Model:
     def __init__(self,
@@ -53,18 +110,48 @@ class Model:
         self.total_iter = total_iter
 
         self.img_size = img_size
-
-        self.encoder = SetNet(self.hidden_dim).float()
-        self.encoder = nn.DataParallel(self.encoder)
-        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
-        self.triplet_loss = nn.DataParallel(self.triplet_loss)
-        self.encoder.cuda()
-        self.triplet_loss.cuda()
-
-        self.optimizer = optim.Adam([
-            {'params': self.encoder.parameters()},
-        ], lr=self.lr)
-
+        
+        use_dist = False
+        '''
+        try:
+            local_rank = torch.distributed.get_rank()
+        except AssertionError:  # Default process group is not initialized
+            use_dist = False
+        '''
+        
+        if use_dist:
+            self.encoder = SetNet(self.hidden_dim).float()
+            self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
+            self.optimizer = optim.Adam([
+                {'params': self.encoder.parameters()},
+            ], lr=self.lr)
+            
+            self.local_device = f'cpu:{local_rank}'
+            self.encoder.to(self.local_device)
+            self.triplet_loss.to(self.local_device)
+            
+            # self.encoder,self.optimizer = amp.initialize(self.encoder,self.optimizer,opt_level="O2", loss_scale=32.0)
+            
+            local_rank = torch.distributed.get_rank()
+            if torch.cpu.device_count() > 1:
+                print("Let's use",torch.cpu.device_count(),"CPUs!")
+                print('-----RANK=', local_rank)
+                self.encoder = nn.parallel.DistributedDataParallel(self.encoder, broadcast_buffers=False, device_ids=[local_rank])
+        else:
+            self.local_device = 'cpu'
+            self.encoder = SetNet(self.hidden_dim).float()
+            self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
+            self.triplet_loss = nn.DataParallel(self.triplet_loss)
+            self.encoder = self.encoder.cpu()
+            self.triplet_loss = self.triplet_loss.cpu()
+            
+            self.optimizer = optim.Adam([
+                {'params': self.encoder.parameters()},
+            ], lr=self.lr)
+            
+            # self.encoder,self.optimizer = amp.initialize(self.encoder,self.optimizer,opt_level="O2", loss_scale=64.0)
+            self.encoder = nn.DataParallel(self.encoder)
+        
         self.hard_loss_metric = []
         self.full_loss_metric = []
         self.full_loss_num = []
@@ -81,47 +168,52 @@ class Model:
         view = [batch[i][2] for i in range(batch_size)]
         seq_type = [batch[i][3] for i in range(batch_size)]
         label = [batch[i][4] for i in range(batch_size)]
+        
         batch = [seqs, view, seq_type, label, None]
-
+        
         def select_frame(index):
             sample = seqs[index]
             frame_set = frame_sets[index]
+            
             if self.sample_type == 'random':
-                frame_id_list = random.choices(frame_set, k=self.frame_num)
+                frame_list = sorted(list(frame_set))
+                
+                frame_id_list = random.choices(frame_list, k=self.frame_num)
+                
                 _ = [feature.loc[frame_id_list].values for feature in sample]
             else:
                 _ = [feature.values for feature in sample]
             return _
-
+        
         seqs = list(map(select_frame, range(len(seqs))))
 
         if self.sample_type == 'random':
             seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
         else:
-            gpu_num = min(torch.cuda.device_count(), batch_size)
-            batch_per_gpu = math.ceil(batch_size / gpu_num)
+            npu_num =  batch_size
+            batch_per_npu = math.ceil(batch_size / npu_num)
             batch_frames = [[
                                 len(frame_sets[i])
-                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
+                                for i in range(batch_per_npu * _, batch_per_npu * (_ + 1))
                                 if i < batch_size
-                                ] for _ in range(gpu_num)]
-            if len(batch_frames[-1]) != batch_per_gpu:
-                for _ in range(batch_per_gpu - len(batch_frames[-1])):
+                                ] for _ in range(npu_num)]
+            if len(batch_frames[-1]) != batch_per_npu:
+                for _ in range(batch_per_npu - len(batch_frames[-1])):
                     batch_frames[-1].append(0)
-            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
+            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(npu_num)])
             seqs = [[
                         np.concatenate([
                                            seqs[i][j]
-                                           for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
+                                           for i in range(batch_per_npu * _, batch_per_npu * (_ + 1))
                                            if i < batch_size
-                                           ], 0) for _ in range(gpu_num)]
+                                           ], 0) for _ in range(npu_num)]
                     for j in range(feature_num)]
             seqs = [np.asarray([
                                    np.pad(seqs[j][_],
                                           ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                           'constant',
                                           constant_values=0)
-                                   for _ in range(gpu_num)])
+                                   for _ in range(npu_num)])
                     for j in range(feature_num)]
             batch[4] = np.asarray(batch_frames)
 
@@ -129,6 +221,14 @@ class Model:
         return batch
 
     def fit(self):
+        is_8p = torch.cpu.device_count() > 1
+        
+        batch_time = AverageMeter('Time', ':6.3f')
+        data_time = AverageMeter('Data', ':6.3f')
+        hard_loss_mean = AverageMeter('Hard_Loss', ':.6e', start_count_index=0)
+        full_loss_mean = AverageMeter('Full_Loss', ':.6e', start_count_index=0)
+        p_full_loss_num = AverageMeter('Full_Loss_Num', ':6.3e', start_count_index=0)
+        
         if self.restore_iter != 0:
             self.load(self.restore_iter)
 
@@ -136,89 +236,113 @@ class Model:
         self.sample_type = 'random'
         for param_group in self.optimizer.param_groups:
             param_group['lr'] = self.lr
+        
+        local_rank = 'npu'
+        try:
+            local_rank = torch.distributed.get_rank()
+        except AssertionError:
+            pass
+        
         triplet_sampler = TripletSampler(self.train_source, self.batch_size)
+        
         train_loader = tordata.DataLoader(
             dataset=self.train_source,
+            # shuffle=False,
+            # batch_size=self.batch_size,
+            # pin_memory=False,
             batch_sampler=triplet_sampler,
             collate_fn=self.collate_fn,
             num_workers=self.num_workers)
-
         train_label_set = list(self.train_source.label_set)
         train_label_set.sort()
-
         _time1 = datetime.now()
-        for seq, view, seq_type, label, batch_frame in train_loader:
+        
+        progress = ProgressMeter(
+            len(train_loader),
+            [batch_time, data_time, hard_loss_mean, full_loss_mean, p_full_loss_num],
+            prefix="Iter[{}]".format(self.restore_iter))
+        start_time = time.time()
+        
+        for iter_i, _t_data in enumerate(train_loader):
+            data_time.update(time.time() - start_time)
+            
+            seq, view, seq_type, label, batch_frame = _t_data
+            # triplet_sampler.set_epoch(self.restore_iter)
+            
             self.restore_iter += 1
             self.optimizer.zero_grad()
-
+            
             for i in range(len(seq)):
                 seq[i] = self.np2var(seq[i]).float()
             if batch_frame is not None:
                 batch_frame = self.np2var(batch_frame).int()
-
-            feature, label_prob = self.encoder(*seq, batch_frame)
-
+            
+            feature = self.encoder(*seq, batch_frame)
+            
             target_label = [train_label_set.index(l) for l in label]
             target_label = self.np2var(np.array(target_label)).long()
 
             triplet_feature = feature.permute(1, 0, 2).contiguous()
-            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
+            triplet_label = target_label.unsqueeze(0).cpu().repeat(triplet_feature.size(0), 1)
+            
+            triplet_feature = triplet_feature.cpu()
+            triplet_label = triplet_label.cpu()
+            
             (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
              ) = self.triplet_loss(triplet_feature, triplet_label)
             if self.hard_or_full_trip == 'hard':
                 loss = hard_loss_metric.mean()
             elif self.hard_or_full_trip == 'full':
                 loss = full_loss_metric.mean()
-
+            
             self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
             self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
             self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
             self.dist_list.append(mean_dist.mean().data.cpu().numpy())
-
+            
             if loss > 1e-9:
+                # with amp.scale_loss(loss,self.optimizer) as scaled_loss:
+                #     scaled_loss.backward()
                 loss.backward()
                 self.optimizer.step()
 
             if self.restore_iter % 1000 == 0:
-                print(datetime.now() - _time1)
+                print(f"[{local_rank}]:", datetime.now() - _time1)
                 _time1 = datetime.now()
 
-            if self.restore_iter % 100 == 0:
-                self.save()
-                print('iter {}:'.format(self.restore_iter), end='')
-                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
-                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
-                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
+            if self.restore_iter % 10 == 0:
+                print(f"[{local_rank}]: ", 'iter {}:'.format(self.restore_iter), end='')
+                
                 self.mean_dist = np.mean(self.dist_list)
-                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
-                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
-                print(', hard or full=%r' % self.hard_or_full_trip)
+                print('mean_dist={0:.8f}'.format(self.mean_dist))
+                
+                hard_loss_mean.update(np.mean(self.hard_loss_metric), self.P * self.M)
+                full_loss_mean.update(np.mean(self.full_loss_metric), self.P * self.M)
+                p_full_loss_num.update(np.mean(self.full_loss_num), self.P * self.M)
+                progress.display(self.restore_iter)
+                
                 sys.stdout.flush()
                 self.hard_loss_metric = []
                 self.full_loss_metric = []
                 self.full_loss_num = []
                 self.dist_list = []
-
-            # Visualization using t-SNE
-            # if self.restore_iter % 500 == 0:
-            #     pca = TSNE(2)
-            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
-            #     for i in range(self.P):
-            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
-            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
-            #
-            #     plt.show()
+            
+            batch_time.update(time.time() - start_time)
+            start_time = time.time()
+            
+            if self.restore_iter % 200 == 0:
+                self.save()
 
             if self.restore_iter == self.total_iter:
                 break
-
+    
     def ts2var(self, x):
-        return autograd.Variable(x).cuda()
-
+        return autograd.Variable(x).to(self.local_device, non_blocking=False)
+    
     def np2var(self, x):
         return self.ts2var(torch.from_numpy(x))
-
-    def transform(self, flag, batch_size=1):
+    
+    def transform(self, flag, bin_file_path=None, batch_size=1, output_path=None, pre_process=False, post_process=False):
         self.encoder.eval()
         source = self.test_source if flag == 'test' else self.train_source
         self.sample_type = 'all'
@@ -233,26 +357,65 @@ class Model:
         view_list = list()
         seq_type_list = list()
         label_list = list()
-
-        for i, x in enumerate(data_loader):
+        
+        test_len = len(data_loader)
+        
+        for i, x in tqdm(enumerate(data_loader), total = len(data_loader)):
+            import time
+            cvt_time = time.time()
+            
             seq, view, seq_type, label, batch_frame = x
             for j in range(len(seq)):
                 seq[j] = self.np2var(seq[j]).float()
             if batch_frame is not None:
                 batch_frame = self.np2var(batch_frame).int()
-            # print(batch_frame, np.sum(batch_frame))
-
-            feature, _ = self.encoder(*seq, batch_frame)
+            
+            if pre_process:
+                bin_img_path = os.path.abspath(bin_file_path + '/'+ f'{i:0>4d}.bin')
+                
+                align_size = 100
+                
+                # new pre-process align by repeat itself
+                cat_seq = None
+                seq[0] = seq[0].detach().cpu().float()
+                org_size = seq[0].shape[1]
+                if org_size < align_size:
+                    pad_shape = list(seq[0].shape)
+                    pad_shape[1] = align_size - org_size
+                    pad_zeros = torch.zeros(pad_shape).float()
+                    cat_seq = torch.cat([pad_zeros.float(), seq[0].float()], dim=1)
+                else:
+                    cat_seq = seq[0].float()
+                    while cat_seq.shape[1] < align_size:
+                        cat_seq = torch.cat([cat_seq, seq[0].float()], dim=1)
+                    cat_seq = cat_seq[:, :align_size, :, :]
+                
+                cat_seq.numpy().tofile(bin_img_path)
+                
+                continue  # pre-processing, skip model calculation
+            
+            # add post_process
+            feature = None
+            if post_process == False:
+                feature = self.encoder(*seq, batch_frame)
+            else:
+                feat = np.fromfile(output_path+ '/'+ f'{i:0>4d}_0.bin', dtype=np.float32)
+                feature = torch.Tensor(feat).float().cpu().view(1, -1, 256)
+            
             n, num_bin, _ = feature.size()
             feature_list.append(feature.view(n, -1).data.cpu().numpy())
             view_list += view
             seq_type_list += seq_type
             label_list += label
-
+        if pre_process:
+            return None
         return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
 
     def save(self):
         os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
+        local_rank = torch.distributed.get_rank()
+        if local_rank != 0:
+            return
         torch.save(self.encoder.state_dict(),
                    osp.join('checkpoint', self.model_name,
                             '{}-{:0>5}-encoder.ptm'.format(
@@ -262,11 +425,18 @@ class Model:
                             '{}-{:0>5}-optimizer.ptm'.format(
                                 self.save_name, self.restore_iter)))
 
-    # restore_iter: iteration index of the checkpoint to load
+    # restore_iter, iteration index of the checkpoint to load
     def load(self, restore_iter):
-        self.encoder.load_state_dict(torch.load(osp.join(
-            'checkpoint', self.model_name,
-            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
         self.optimizer.load_state_dict(torch.load(osp.join(
             'checkpoint', self.model_name,
-            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
+            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter)), map_location=torch.device('cpu')))
+        try:
+            self.encoder.load_state_dict(torch.load(osp.join(
+                'checkpoint', self.model_name,
+                '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter)), map_location=torch.device('cpu')))
+        except RuntimeError:
+            wrapped = wrapperNet(self.encoder)
+            wrapped.load_state_dict(torch.load(osp.join(
+                'checkpoint', self.model_name,
+                '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter)), map_location=torch.device('cpu')))
+            self.encoder = wrapped.module
diff --git a/model/network/gaitset.py b/model/network/gaitset.py
index 45c2e34..e1b7375 100644
--- a/model/network/gaitset.py
+++ b/model/network/gaitset.py
@@ -117,4 +117,4 @@ class SetNet(nn.Module):
         feature = feature.matmul(self.fc_bin[0])
         feature = feature.permute(1, 0, 2).contiguous()
 
-        return feature, None
+        return feature
diff --git a/model/network/triplet.py b/model/network/triplet.py
index ebb96d5..35a1780 100644
--- a/model/network/triplet.py
+++ b/model/network/triplet.py
@@ -14,20 +14,21 @@ class TripletLoss(nn.Module):
         n, m, d = feature.size()
         hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1)
         hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1)
-
+        
         dist = self.batch_dist(feature)
         mean_dist = dist.mean(1).mean(1)
         dist = dist.view(-1)
+        
         # hard
-        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
-        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
+        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask.bool()).view(n, m, -1), 2)[0]
+        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask.bool()).view(n, m, -1), 2)[0]
         hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
 
         hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)
 
         # non-zero full
-        full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
-        full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
+        full_hp_dist = torch.masked_select(dist, hp_mask.bool()).view(n, m, -1, 1)
+        full_hn_dist = torch.masked_select(dist, hn_mask.bool()).view(n, m, 1, -1)
         full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
 
         full_loss_metric_sum = full_loss_metric.sum(1)
diff --git a/model/utils/data_loader.py b/model/utils/data_loader.py
index aa02ccc..53af26d 100644
--- a/model/utils/data_loader.py
+++ b/model/utils/data_loader.py
@@ -39,7 +39,7 @@ def load_data(dataset_path, resolution, dataset, pid_num, pid_shuffle, cache=Tru
         os.makedirs('partition', exist_ok=True)
         np.save(pid_fname, pid_list)
 
-    pid_list = np.load(pid_fname)
+    pid_list = np.load(pid_fname, allow_pickle=True)
     train_list = pid_list[0]
     test_list = pid_list[1]
     train_source = DataSet(
diff --git a/model/utils/data_set.py b/model/utils/data_set.py
index 2048fbc..8d7415a 100644
--- a/model/utils/data_set.py
+++ b/model/utils/data_set.py
@@ -15,11 +15,11 @@ class DataSet(tordata.Dataset):
         self.label = label
         self.cache = cache
         self.resolution = int(resolution)
-        self.cut_padding = int(float(resolution)/64*10)
+        self.cut_padding = int(float(resolution) / 64 * 10)
         self.data_size = len(self.label)
         self.data = [None] * self.data_size
         self.frame_set = [None] * self.data_size
-
+        
         self.label_set = set(self.label)
         self.seq_type_set = set(self.seq_type)
         self.view_set = set(self.view)
@@ -57,16 +57,16 @@ class DataSet(tordata.Dataset):
         if not self.cache:
             data = [self.__loader__(_path) for _path in self.seq_dir[index]]
             frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
-            frame_set = list(set.intersection(*frame_set))
+            frame_set = sorted(list(set.intersection(*frame_set)))
         elif self.data[index] is None:
             data = [self.__loader__(_path) for _path in self.seq_dir[index]]
             frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
-            frame_set = list(set.intersection(*frame_set))
+            frame_set = sorted(list(set.intersection(*frame_set)))
             self.data[index] = data
             self.frame_set[index] = frame_set
         else:
             data = self.data[index]
-            frame_set = self.frame_set[index]
+            frame_set = sorted(list(self.frame_set[index]))
 
         return data, frame_set, self.view[
             index], self.seq_type[index], self.label[index],
@@ -79,6 +79,7 @@ class DataSet(tordata.Dataset):
                       for _img_path in imgs
                       if osp.isfile(osp.join(flie_path, _img_path))]
         num_list = list(range(len(frame_list)))
+        
         data_dict = xr.DataArray(
             frame_list,
             coords={'frame': num_list},
diff --git a/model/utils/evaluator.py b/model/utils/evaluator.py
index 9fd0b3a..626fbfb 100644
--- a/model/utils/evaluator.py
+++ b/model/utils/evaluator.py
@@ -4,8 +4,8 @@ import numpy as np
 
 
 def cuda_dist(x, y):
-    x = torch.from_numpy(x).cuda()
-    y = torch.from_numpy(y).cuda()
+    x = torch.from_numpy(x)
+    y = torch.from_numpy(y)
     dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
         1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
     dist = torch.sqrt(F.relu(dist))