05360171创建于 2022年3月18日历史提交
diff --git a/lib/config/default.py b/lib/config/default.py
index de7527b..de72c6b 100644
--- a/lib/config/default.py
+++ b/lib/config/default.py
@@ -12,6 +12,10 @@ from __future__ import print_function
 import os
 
 from yacs.config import CfgNode as CN
+import os
+NPU_CALCULATE_DEVICE = 0
+if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')):
+    NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE'))
 
 
 _C = CN()
@@ -143,27 +147,6 @@ def update_config(cfg, args):
     cfg.merge_from_file(args.cfg)
     cfg.merge_from_list(args.opts)
 
-    if args.modelDir:
-        cfg.OUTPUT_DIR = args.modelDir
-
-    if args.logDir:
-        cfg.LOG_DIR = args.logDir
-
-    if args.dataDir:
-        cfg.DATA_DIR = args.dataDir
-
-    cfg.DATASET.ROOT = os.path.join(
-        cfg.DATA_DIR, cfg.DATASET.ROOT
-    )
-
-    cfg.MODEL.PRETRAINED = os.path.join(
-        cfg.DATA_DIR, cfg.MODEL.PRETRAINED
-    )
-
-    if cfg.TEST.MODEL_FILE:
-        cfg.TEST.MODEL_FILE = os.path.join(
-            cfg.DATA_DIR, cfg.TEST.MODEL_FILE
-        )
 
     cfg.freeze()
 
diff --git a/lib/config/models.py b/lib/config/models.py
index 8e04c4f..1fcc06e 100644
--- a/lib/config/models.py
+++ b/lib/config/models.py
@@ -9,6 +9,11 @@ from __future__ import division
 from __future__ import print_function
 
 from yacs.config import CfgNode as CN
+import os
+
+NPU_CALCULATE_DEVICE = 0
+if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')):
+    NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE'))
 
 
 # pose_resnet related params
diff --git a/lib/core/evaluate.py b/lib/core/evaluate.py
index cf72285..3dceef3 100644
--- a/lib/core/evaluate.py
+++ b/lib/core/evaluate.py
@@ -12,7 +12,6 @@ import numpy as np
 
 from core.inference import get_max_preds
 
-
 def calc_dists(preds, target, normalize):
     preds = preds.astype(np.float32)
     target = target.astype(np.float32)
diff --git a/lib/core/function.py b/lib/core/function.py
index 46f9da0..33ad6c3 100644
--- a/lib/core/function.py
+++ b/lib/core/function.py
@@ -15,6 +15,7 @@ import os
 
 import numpy as np
 import torch
+from apex import amp
 
 from core.evaluate import accuracy
 from core.inference import get_final_preds
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
 
 
 def train(config, train_loader, model, criterion, optimizer, epoch,
-          output_dir, tb_log_dir, writer_dict):
+          output_dir, tb_log_dir, writer_dict, is_amp=False, device=None):
     batch_time = AverageMeter()
     data_time = AverageMeter()
     losses = AverageMeter()
@@ -35,17 +36,22 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
     # switch to train mode
     model.train()
 
+
     end = time.time()
     for i, (input, target, target_weight, meta) in enumerate(train_loader):
         # measure data loading time
+        # if i == 1:
+        #     import sys
+        #     sys.exit()
         data_time.update(time.time() - end)
-
+        # input.npu()
         # compute output
-        outputs = model(input)
-
-        target = target.cuda(non_blocking=True)
-        target_weight = target_weight.cuda(non_blocking=True)
+        input = input.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True)
+        target_weight = target_weight.to(device, non_blocking=True)
 
+        # with torch.autograd.profiler.profile(use_npu=True) as prof:
+        outputs = model(input)
         if isinstance(outputs, list):
             loss = criterion(outputs[0], target, target_weight)
             for output in outputs[1:]:
@@ -55,12 +61,17 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
             loss = criterion(output, target, target_weight)
 
         # loss = criterion(output, target, target_weight)
-
         # compute gradient and do update step
         optimizer.zero_grad()
-        loss.backward()
+        if is_amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
         optimizer.step()
 
+        # exit()
+
         # measure accuracy and record loss
         losses.update(loss.item(), input.size(0))
 
@@ -72,7 +83,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
         batch_time.update(time.time() - end)
         end = time.time()
 
-        if i % config.PRINT_FREQ == 0:
+        if i % config.PRINT_FREQ == 0 and device == 'npu:0':
             msg = 'Epoch: [{0}][{1}/{2}]\t' \
                   'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                   'Speed {speed:.1f} samples/s\t' \
@@ -96,7 +107,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
 
 
 def validate(config, val_loader, val_dataset, model, criterion, output_dir,
-             tb_log_dir, writer_dict=None):
+             tb_log_dir, writer_dict=None,device=None):
     batch_time = AverageMeter()
     losses = AverageMeter()
     acc = AverageMeter()
@@ -118,6 +129,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
         end = time.time()
         for i, (input, target, target_weight, meta) in enumerate(val_loader):
             # compute output
+            input = input.to(device, non_blocking=True)
             outputs = model(input)
             if isinstance(outputs, list):
                 output = outputs[-1]
@@ -128,7 +140,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                 # this part is ugly, because pytorch has not supported negative index
                 # input_flipped = model(input[:, :, :, ::-1])
                 input_flipped = np.flip(input.cpu().numpy(), 3).copy()
-                input_flipped = torch.from_numpy(input_flipped).cuda()
+                input_flipped = torch.from_numpy(input_flipped).npu()
                 outputs_flipped = model(input_flipped)
 
                 if isinstance(outputs_flipped, list):
@@ -142,8 +154,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
 
                 output = (output + output_flipped) * 0.5
 
-            target = target.cuda(non_blocking=True)
-            target_weight = target_weight.cuda(non_blocking=True)
+            target = target.npu(device)
+            target_weight = target_weight.npu(device)
 
             loss = criterion(output, target, target_weight)
 
diff --git a/lib/core/inference.py b/lib/core/inference.py
index e5cbd25..e63f186 100644
--- a/lib/core/inference.py
+++ b/lib/core/inference.py
@@ -14,7 +14,7 @@ import math
 import numpy as np
 import cv2
 
-from utils.transforms import transform_preds
+from lib.utils.transforms import transform_preds
 
 
 def get_max_preds(batch_heatmaps):
diff --git a/lib/core/loss.py b/lib/core/loss.py
index b879495..cad0d9e 100644
--- a/lib/core/loss.py
+++ b/lib/core/loss.py
@@ -12,6 +12,7 @@ import torch
 import torch.nn as nn
 
 
+
 class JointsMSELoss(nn.Module):
     def __init__(self, use_target_weight):
         super(JointsMSELoss, self).__init__()
diff --git a/lib/dataset/JointsDataset.py b/lib/dataset/JointsDataset.py
index 5367463..7179c6a 100644
--- a/lib/dataset/JointsDataset.py
+++ b/lib/dataset/JointsDataset.py
@@ -1,10 +1,3 @@
-# ------------------------------------------------------------------------------
-# Copyright (c) Microsoft
-# Licensed under the MIT License.
-# Written by Bin Xiao (Bin.Xiao@microsoft.com)
-# Modified by Hanbin Dai (daihanbin.ac@gmail.com)
-# ------------------------------------------------------------------------------
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -18,10 +11,9 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from utils.transforms import get_affine_transform
-from utils.transforms import affine_transform
-from utils.transforms import fliplr_joints
-
+from lib.utils.transforms import get_affine_transform
+from lib.utils.transforms import affine_transform
+from lib.utils.transforms import fliplr_joints
 
 logger = logging.getLogger(__name__)
 
@@ -108,7 +100,7 @@ class JointsDataset(Dataset):
 
         return center, scale
 
-    def __len__(self,):
+    def __len__(self, ):
         return len(self.db)
 
     def __getitem__(self, idx):
@@ -119,7 +111,7 @@ class JointsDataset(Dataset):
         imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
 
         if self.data_format == 'zip':
-            from utils import zipreader
+            from lib.utils import zipreader
             data_numpy = zipreader.imread(
                 image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
             )
@@ -145,7 +137,7 @@ class JointsDataset(Dataset):
 
         if self.is_train:
             if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body
-                and np.random.rand() < self.prob_half_body):
+                    and np.random.rand() < self.prob_half_body):
                 c_half_body, s_half_body = self.half_body_transform(
                     joints, joints_vis
                 )
@@ -155,8 +147,8 @@ class JointsDataset(Dataset):
 
             sf = self.scale_factor
             rf = self.rotation_factor
-            s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
-            r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
+            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
+            r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
                 if random.random() <= 0.6 else 0
 
             if self.flip and random.random() <= 0.5:
@@ -164,7 +156,7 @@ class JointsDataset(Dataset):
                 joints, joints_vis = fliplr_joints(
                     joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
                 c[0] = data_numpy.shape[1] - c[0] - 1
-                
+
         joints_heatmap = joints.copy()
         trans = get_affine_transform(c, s, r, self.image_size)
         trans_heatmap = get_affine_transform(c, s, r, self.heatmap_size)
@@ -221,11 +213,11 @@ class JointsDataset(Dataset):
 
             joints_x, joints_y = joints_x / num_vis, joints_y / num_vis
 
-            area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std**2)
+            area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std ** 2)
             joints_center = np.array([joints_x, joints_y])
             bbox_center = np.array(rec['center'])
-            diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
-            ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))
+            diff_norm2 = np.linalg.norm((joints_center - bbox_center), 2)
+            ks = np.exp(-1.0 * (diff_norm2 ** 2) / ((0.2) ** 2 * 2.0 * area))
 
             metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
             if ks > metric:
@@ -235,7 +227,6 @@ class JointsDataset(Dataset):
         logger.info('=> num selected db: {}'.format(len(db_selected)))
         return db_selected
 
-
     def generate_target(self, joints, joints_vis):
         '''
         :param joints:  [num_joints, 3]
@@ -259,13 +250,13 @@ class JointsDataset(Dataset):
             for joint_id in range(self.num_joints):
                 target_weight[joint_id] = \
                     self.adjust_target_weight(joints[joint_id], target_weight[joint_id], tmp_size)
-                
+
                 if target_weight[joint_id] == 0:
                     continue
 
                 mu_x = joints[joint_id][0]
                 mu_y = joints[joint_id][1]
-                
+
                 x = np.arange(0, self.heatmap_size[0], 1, np.float32)
                 y = np.arange(0, self.heatmap_size[1], 1, np.float32)
                 y = y[:, np.newaxis]
@@ -279,7 +270,6 @@ class JointsDataset(Dataset):
 
         return target, target_weight
 
-
     def adjust_target_weight(self, joint, target_weight, tmp_size):
         # feat_stride = self.image_size / self.heatmap_size
         mu_x = joint[0]
@@ -292,4 +282,4 @@ class JointsDataset(Dataset):
             # If not, just return the image as is
             target_weight = 0
 
-        return target_weight
+        return target_weight
\ No newline at end of file
diff --git a/lib/dataset/__init__.py b/lib/dataset/__init__.py
index 16e2413..a28ede6 100644
--- a/lib/dataset/__init__.py
+++ b/lib/dataset/__init__.py
@@ -8,5 +8,5 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from .mpii import MPIIDataset as mpii
+# from .mpii import MPIIDataset as mpii
 from .coco import COCODataset as coco
diff --git a/lib/dataset/coco.py b/lib/dataset/coco.py
index 667d283..04c0d48 100644
--- a/lib/dataset/coco.py
+++ b/lib/dataset/coco.py
@@ -1,10 +1,3 @@
-# ------------------------------------------------------------------------------
-# Copyright (c) Microsoft
-# Licensed under the MIT License.
-# Written by Bin Xiao (Bin.Xiao@microsoft.com)
-# Modified by Hanbin Dai (daihanbin.ac@gmail.com) and Feng Zhang (zhangfengwcy@gmail.com)
-# ------------------------------------------------------------------------------
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -19,9 +12,9 @@ from pycocotools.cocoeval import COCOeval
 import json_tricks as json
 import numpy as np
 
-from dataset.JointsDataset import JointsDataset
-from nms.nms import oks_nms
-from nms.nms import soft_oks_nms
+from lib.dataset.JointsDataset import JointsDataset
+from lib.nms.nms import oks_nms
+from lib.nms.nms import soft_oks_nms
 
 
 logger = logging.getLogger(__name__)
@@ -443,4 +436,4 @@ class COCODataset(JointsDataset):
         for ind, name in enumerate(stats_names):
             info_str.append((name, coco_eval.stats[ind]))
 
-        return info_str
+        return info_str
\ No newline at end of file
diff --git a/lib/models/transpose_h.py b/lib/models/transpose_h.py
index f52b0b7..878f2af 100644
--- a/lib/models/transpose_h.py
+++ b/lib/models/transpose_h.py
@@ -20,6 +20,13 @@ from collections import OrderedDict
 
 import copy
 from typing import Optional, List
+import torch.npu
+import os
+NPU_CALCULATE_DEVICE = 0
+if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')):
+    NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE'))
+if torch.npu.current_device() != NPU_CALCULATE_DEVICE:
+    torch.npu.set_device(f'npu:{NPU_CALCULATE_DEVICE}')
 
 BN_MOMENTUM = 0.1
 logger = logging.getLogger(__name__)
@@ -685,6 +692,9 @@ class TransPoseH(nn.Module):
 
 def get_pose_net(cfg, is_train, **kwargs):
     model = TransPoseH(cfg, **kwargs)
+    model = model.to(f'npu:{NPU_CALCULATE_DEVICE}')
+    if not isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[NPU_CALCULATE_DEVICE], broadcast_buffers=False)
 
     if is_train and cfg['MODEL']['INIT_WEIGHTS']:
         model.init_weights(cfg['MODEL']['PRETRAINED'])
diff --git a/lib/models/transpose_r.py b/lib/models/transpose_r.py
index b7b6274..c77f4c3 100644
--- a/lib/models/transpose_r.py
+++ b/lib/models/transpose_r.py
@@ -18,6 +18,7 @@ from collections import OrderedDict
 
 import copy
 from typing import Optional, List
+import os
 
 BN_MOMENTUM = 0.1
 logger = logging.getLogger(__name__)
@@ -392,13 +393,27 @@ class TransPoseR(nn.Module):
                     stride=2,
                     padding=padding,
                     output_padding=output_padding,
-                    bias=self.deconv_with_bias))
+                    bias=self.deconv_with_bias
+                    ))
             layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
             layers.append(nn.ReLU(inplace=True))
+            # layers.append(nn.Sequential(
+            #     nn.ConvTranspose2d(
+            #                 in_channels=self.inplanes,
+            #                 out_channels=planes,
+            #                 kernel_size=kernel,
+            #                 stride=2,
+            #                 padding=padding,
+            #                 output_padding=output_padding,
+            #                 bias=self.deconv_with_bias
+            #     ),
+            #     nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
+            #     nn.ReLU(inplace=True)
+            # ))
             self.inplanes = planes
 
         return nn.Sequential(*layers)
-
+        # return nn.ModuleList(deconv_layers)
     def forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
diff --git a/lib/nms/nms.py b/lib/nms/nms.py
index 7f83e05..29daf57 100644
--- a/lib/nms/nms.py
+++ b/lib/nms/nms.py
@@ -10,8 +10,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from .cpu_nms import cpu_nms
-from .gpu_nms import gpu_nms
+# from .cpu_nms import cpu_nms
+# from .gpu_nms import gpu_nms
 
 
 def py_nms_wrapper(thresh):
@@ -20,16 +20,16 @@ def py_nms_wrapper(thresh):
     return _nms
 
 
-def cpu_nms_wrapper(thresh):
-    def _nms(dets):
-        return cpu_nms(dets, thresh)
-    return _nms
-
-
-def gpu_nms_wrapper(thresh, device_id):
-    def _nms(dets):
-        return gpu_nms(dets, thresh, device_id)
-    return _nms
+# def cpu_nms_wrapper(thresh):
+#     def _nms(dets):
+#         return cpu_nms(dets, thresh)
+#     return _nms
+#
+#
+# def gpu_nms_wrapper(thresh, device_id):
+#     def _nms(dets):
+#         return gpu_nms(dets, thresh, device_id)
+#     return _nms
 
 
 def nms(dets, thresh):
diff --git a/lib/nms/setup_linux.py b/lib/nms/setup_linux.py
index 9120a93..7c14f72 100644
--- a/lib/nms/setup_linux.py
+++ b/lib/nms/setup_linux.py
@@ -11,6 +11,11 @@ from setuptools import setup
 from distutils.extension import Extension
 from Cython.Distutils import build_ext
 import numpy as np
+import os
+import ascend_function
+NPU_CALCULATE_DEVICE = 0
+if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')):
+    NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE'))
 
 
 def find_in_path(name, path):
@@ -38,7 +43,7 @@ def locate_cuda():
         nvcc = pjoin(home, 'bin', 'nvcc')
     else:
         # otherwise, search the PATH for NVCC
-        default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin')
+        default_path = pjoin(os.sep, 'usr', 'local', 'npu', 'bin')
         nvcc = find_in_path('nvcc', os.environ['PATH'] + os.pathsep + default_path)
         if nvcc is None:
             raise EnvironmentError('The nvcc binary could not be '
diff --git a/lib/utils/transforms.py b/lib/utils/transforms.py
index 19581df..65ab99c 100644
--- a/lib/utils/transforms.py
+++ b/lib/utils/transforms.py
@@ -11,6 +11,7 @@ from __future__ import print_function
 
 import numpy as np
 import cv2
+import os
 
 
 def flip_back(output_flipped, matched_parts):
diff --git a/lib/utils/utils.py b/lib/utils/utils.py
index 3ed88e9..0317a51 100644
--- a/lib/utils/utils.py
+++ b/lib/utils/utils.py
@@ -17,6 +17,8 @@ from pathlib import Path
 import torch
 import torch.optim as optim
 import torch.nn as nn
+import os
+
 
 
 def create_logger(cfg, cfg_name, phase='train'):
diff --git a/lib/utils/zipreader.py b/lib/utils/zipreader.py
index dab919f..13a05bd 100644
--- a/lib/utils/zipreader.py
+++ b/lib/utils/zipreader.py
@@ -14,6 +14,11 @@ import xml.etree.ElementTree as ET
 
 import cv2
 import numpy as np
+import os
+import ascend_function
+NPU_CALCULATE_DEVICE = 0
+if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')):
+    NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE'))
 
 _im_zfile = []
 _xml_path_zip = []