05360171创建于 2022年3月18日历史提交
diff --git a/experiments/siammask_sharp/custom.py b/experiments/siammask_sharp/custom.py
index 380c610..60820ee 100644
--- a/experiments/siammask_sharp/custom.py
+++ b/experiments/siammask_sharp/custom.py
@@ -1,38 +1,51 @@
 from models.siammask_sharp import SiamMask
 from models.features import MultiStageFeature
-from models.rpn import RPN, DepthCorr
+from models.rpn import RPN, DepthCorr_10, DepthCorr_20, DepthCorr_3969_feature, DepthCorr_3969_head
 from models.mask import Mask
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from utils.load_helper import load_pretrain
 from resnet import resnet50
+import torch
+import torch.nn.functional as F
+import numpy as np
+from utils.anchors import Anchors
+
+
+class ResDownS_t(nn.Module):
+    def __init__(self, inplane, outplane):
+        super(ResDownS_t, self).__init__()
+        self.downsample = nn.Sequential(
+            nn.Conv2d(1024, 256, kernel_size=1, bias=False),
+            nn.BatchNorm2d(256))
+
+    def forward(self, x):
+        x = self.downsample(x)
+        l = 4
+        r = -4
+        x = x[:, :, l:r, l:r]
+        return x
 
 
-class ResDownS(nn.Module):
+class ResDownS_s(nn.Module):
     def __init__(self, inplane, outplane):
-        super(ResDownS, self).__init__()
+        super(ResDownS_s, self).__init__()
         self.downsample = nn.Sequential(
-                nn.Conv2d(inplane, outplane, kernel_size=1, bias=False),
-                nn.BatchNorm2d(outplane))
+            nn.Conv2d(1024, 256, kernel_size=1, bias=False),
+            nn.BatchNorm2d(256))
 
     def forward(self, x):
         x = self.downsample(x)
-        if x.size(3) < 20:
-            l = 4
-            r = -4
-            x = x[:, :, l:r, l:r]
         return x
 
 
-class ResDown(MultiStageFeature):
+class ResDown_t(MultiStageFeature):
     def __init__(self, pretrain=False):
-        super(ResDown, self).__init__()
+        super(ResDown_t, self).__init__()
         self.features = resnet50(layer3=True, layer4=False)
         if pretrain:
-            load_pretrain(self.features, 'resnet.model')
+            load_pretrain(self.features, '../addByMyself/resnet.model')
 
-        self.downsample = ResDownS(1024, 256)
+        self.downsample = ResDownS_t(1024, 256)
 
         self.layers = [self.downsample, self.features.layer2, self.features.layer3]
         self.train_nums = [1, 3]
@@ -40,27 +53,28 @@ class ResDown(MultiStageFeature):
 
         self.unfix(0.0)
 
-    def param_groups(self, start_lr, feature_mult=1):
-        lr = start_lr * feature_mult
-
-        def _params(module, mult=1):
-            params = list(filter(lambda x:x.requires_grad, module.parameters()))
-            if len(params):
-                return [{'params': params, 'lr': lr * mult}]
-            else:
-                return []
-
-        groups = []
-        groups += _params(self.downsample)
-        groups += _params(self.features, 0.1)
-        return groups
-
     def forward(self, x):
         output = self.features(x)
         p3 = self.downsample(output[-1])
         return p3
 
-    def forward_all(self, x):
+
+class ResDown_s(MultiStageFeature):
+    def __init__(self, pretrain=False):
+        super(ResDown_s, self).__init__()
+        self.features = resnet50(layer3=True, layer4=False)
+        if pretrain:
+            load_pretrain(self.features, '../addByMyself/resnet.model')
+
+        self.downsample = ResDownS_s(1024, 256)
+
+        self.layers = [self.downsample, self.features.layer2, self.features.layer3]
+        self.train_nums = [1, 3]
+        self.change_point = [0, 0.5]
+
+        self.unfix(0.0)
+
+    def forward(self, x):
         output = self.features(x)
         p3 = self.downsample(output[-1])
         return output, p3
@@ -70,15 +84,8 @@ class UP(RPN):
     def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
         super(UP, self).__init__()
 
-        self.anchor_num = anchor_num
-        self.feature_in = feature_in
-        self.feature_out = feature_out
-
-        self.cls_output = 2 * self.anchor_num
-        self.loc_output = 4 * self.anchor_num
-
-        self.cls = DepthCorr(feature_in, feature_out, self.cls_output)
-        self.loc = DepthCorr(feature_in, feature_out, self.loc_output)
+        self.cls = DepthCorr_10(256, 256, 10)
+        self.loc = DepthCorr_20(256, 256, 20)
 
     def forward(self, z_f, x_f):
         cls = self.cls(z_f, x_f)
@@ -86,106 +93,122 @@ class UP(RPN):
         return cls, loc
 
 
-class MaskCorr(Mask):
+class MaskCorr_feature(Mask):
     def __init__(self, oSz=63):
-        super(MaskCorr, self).__init__()
+        super(MaskCorr_feature, self).__init__()
         self.oSz = oSz
-        self.mask = DepthCorr(256, 256, self.oSz**2)
+        self.mask = DepthCorr_3969_feature(256, 256, 3969)
 
     def forward(self, z, x):
-        return self.mask(z, x)
+        return self.mask.forward_corr(z, x)
+
+
+class MaskCorr_head(Mask):
+    def __init__(self, oSz=63):
+        super(MaskCorr_head, self).__init__()
+        self.oSz = oSz
+        self.mask = DepthCorr_3969_head(256, 256, 3969)
+
+    def forward(self, x):
+        return self.mask.head(x)
 
 
 class Refine(nn.Module):
     def __init__(self):
         super(Refine, self).__init__()
         self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(16, 4, 3, padding=1),nn.ReLU())
+                                nn.Conv2d(16, 4, 3, padding=1), nn.ReLU())
 
         self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())
+                                nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())
 
         self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())
+                                nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())
 
         self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())
+                                nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())
 
         self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
+                                nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
 
         self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(),
-                           nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
+                                nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
 
         self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)
 
         self.post0 = nn.Conv2d(32, 16, 3, padding=1)
         self.post1 = nn.Conv2d(16, 4, 3, padding=1)
         self.post2 = nn.Conv2d(4, 1, 3, padding=1)
-        
-        for modules in [self.v0, self.v1, self.v2, self.h2, self.h1, self.h0, self.deconv, self.post0, self.post1, self.post2,]:
+
+        for modules in [self.v0, self.v1, self.v2, self.h2, self.h1, self.h0, self.deconv, self.post0, self.post1,
+                        self.post2, ]:
             for l in modules.modules():
                 if isinstance(l, nn.Conv2d):
                     nn.init.kaiming_uniform_(l.weight, a=1)
 
-    def forward(self, f, corr_feature, pos=None, test=False):
-        if test:
-            p0 = torch.nn.functional.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61]
-            p1 = torch.nn.functional.pad(f[1], [8, 8, 8, 8])[:, :, 2 * pos[0]:2 * pos[0] + 31, 2 * pos[1]:2 * pos[1] + 31]
-            p2 = torch.nn.functional.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0] + 15, pos[1]:pos[1] + 15]
-        else:
-            p0 = F.unfold(f[0], (61, 61), padding=0, stride=4).permute(0, 2, 1).contiguous().view(-1, 64, 61, 61)
-            if not (pos is None): p0 = torch.index_select(p0, 0, pos)
-            p1 = F.unfold(f[1], (31, 31), padding=0, stride=2).permute(0, 2, 1).contiguous().view(-1, 256, 31, 31)
-            if not (pos is None): p1 = torch.index_select(p1, 0, pos)
-            p2 = F.unfold(f[2], (15, 15), padding=0, stride=1).permute(0, 2, 1).contiguous().view(-1, 512, 15, 15)
-            if not (pos is None): p2 = torch.index_select(p2, 0, pos)
-
-        if not(pos is None):
-            p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1)
-        else:
-            p3 = corr_feature.permute(0, 2, 3, 1).contiguous().view(-1, 256, 1, 1)
+    def forward(self, p0, p1, p2, p3):
 
         out = self.deconv(p3)
         out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31)))
         out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61)))
         out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127)))
-        out = out.view(-1, 127*127)
-        return out
+        out = out.view(-1, 127 * 127)
 
-    def param_groups(self, start_lr, feature_mult=1):
-        params = filter(lambda x:x.requires_grad, self.parameters())
-        params = [{'params': params, 'lr': start_lr * feature_mult}]
-        return params
+        return out
 
 
 class Custom(SiamMask):
     def __init__(self, pretrain=False, **kwargs):
         super(Custom, self).__init__(**kwargs)
-        self.features = ResDown(pretrain=pretrain)
-        self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
-        self.mask_model = MaskCorr()
+        self.features_t = ResDown_t(pretrain=pretrain)
+        self.features_s = ResDown_s(pretrain=pretrain)
+        self.rpn_model = UP(anchor_num=5, feature_in=256, feature_out=256)
+        self.mask_model_feature = MaskCorr_feature()
+        self.mask_model_head = MaskCorr_head()
         self.refine_model = Refine()
+        window = np.outer(np.hanning(25), np.hanning(25))
+        self.window = np.tile(window.flatten(), 5)
+        self.anchor = generate_anchor({'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8], 'round_dight': 0},
+                                      25)
 
-    def refine(self, f, pos=None):
-        return self.refine_model(f, pos)
-
-    def template(self, template):
-        self.zf = self.features(template)
-
-    def track(self, search):
-        search = self.features(search)
+    def track_mask(self, search):
+        feature, search = self.features_s(search)
         rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
-        return rpn_pred_cls, rpn_pred_loc
+        corr_feature = self.mask_model_feature(self.zf, search)
+        pred_mask = self.mask_model_head(corr_feature)
+        return rpn_pred_cls, rpn_pred_loc, pred_mask, feature, corr_feature
 
-    def track_mask(self, search):
-        self.feature, self.search = self.features.forward_all(search)
-        rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, self.search)
-        self.corr_feature = self.mask_model.mask.forward_corr(self.zf, self.search)
-        pred_mask = self.mask_model.mask.head(self.corr_feature)
-        return rpn_pred_cls, rpn_pred_loc, pred_mask
-
-    def track_refine(self, pos):
-        pred_mask = self.refine_model(self.feature, self.corr_feature, pos=pos, test=True)
+    def track_refine(self, p0, p1, p2, p3):
+        pred_mask = self.refine_model(p0, p1, p2, p3)
         return pred_mask
 
+    def forward(self, template, search):
+        self.zf = self.features_t(template)
+        score, delta, mask, feature, corr_feature = self.track_mask(search)
+        '''
+        model.anchors = {'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8], 'round_dight': 0}
+        print(p.anchor.shape) :(3125, 4)
+        print(window.shape):(3125,)
+        '''
+        f0,f1,f2= feature[0],feature[1],feature[2]
+        return score, delta, mask, f0,f1,f2, corr_feature
+
+
+
+def generate_anchor(cfg, score_size):
+    anchors = Anchors(cfg)
+    anchor = anchors.anchors
+    x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3]
+    anchor = np.stack([(x1 + x2) * 0.5, (y1 + y2) * 0.5, x2 - x1, y2 - y1], 1)
+
+    total_stride = anchors.stride
+    anchor_num = anchor.shape[0]
+
+    anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
+    ori = - (score_size // 2) * total_stride
+    xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
+                         [ori + total_stride * dy for dy in range(score_size)])
+    xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
+             np.tile(yy.flatten(), (anchor_num, 1)).flatten()
+    anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
+    return anchor
diff --git a/models/mask.py b/models/mask.py
index b32f9ea..15890fd 100644
--- a/models/mask.py
+++ b/models/mask.py
@@ -19,7 +19,7 @@ class Mask(nn.Module):
     def track(self, search):
         raise NotImplementedError
 
-    def param_groups(self, start_lr, feature_mult=1):
-        params = filter(lambda x:x.requires_grad, self.parameters())
-        params = [{'params': params, 'lr': start_lr * feature_mult}]
-        return params
+    # def param_groups(self, start_lr, feature_mult=1):
+    #     params = filter(lambda x:x.requires_grad, self.parameters())
+    #     params = [{'params': params, 'lr': start_lr * feature_mult}]
+    #     return params
diff --git a/models/rpn.py b/models/rpn.py
index e6158ac..c1ef6a2 100644
--- a/models/rpn.py
+++ b/models/rpn.py
@@ -3,6 +3,7 @@
 # Licensed under The MIT License
 # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
 # --------------------------------------------------------
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -20,49 +21,51 @@ class RPN(nn.Module):
     def track(self, search):
         raise NotImplementedError
 
-    def param_groups(self, start_lr, feature_mult=1, key=None):
-        if key is None:
-            params = filter(lambda x:x.requires_grad, self.parameters())
-        else:
-            params = [v for k, v in self.named_parameters() if (key in k) and v.requires_grad]
-        params = [{'params': params, 'lr': start_lr * feature_mult}]
-        return params
+    # def param_groups(self, start_lr, feature_mult=1, key=None):
+    #     if key is None:
+    #         params = filter(lambda x:x.requires_grad, self.parameters())
+    #     else:
+    #         params = [v for k, v in self.named_parameters() if (key in k) and v.requires_grad]
+    #     params = [{'params': params, 'lr': start_lr * feature_mult}]
+    #     return params
+
 
 
 def conv2d_dw_group(x, kernel):
-    batch, channel = kernel.shape[:2]
-    x = x.view(1, batch*channel, x.size(2), x.size(3))  # 1 * (b*c) * k * k
-    kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))  # (b*c) * 1 * H * W
-    out = F.conv2d(x, kernel, groups=batch*channel)
-    out = out.view(batch, channel, out.size(2), out.size(3))
+    batch, channel = 1, 256
+    x = x.view(1, 256, 29, 29)  # 1 * (b*c) * k * k
+    kernel = kernel.view(256, 1, 5, 5)  # (b*c) * 1 * H * W
+    out = F.conv2d(x, kernel, groups=256)
+    out = out.view(1, 256, 25, 25)
     return out
 
 
-class DepthCorr(nn.Module):
-    def __init__(self, in_channels, hidden, out_channels, kernel_size=3):
-        super(DepthCorr, self).__init__()
+class DepthCorr_20(nn.Module):
+    def __init__(self, in_channels=256, hidden=256, out_channels=10, kernel_size=3):
+        super(DepthCorr_20, self).__init__()
         # adjust layer for asymmetrical features
         self.conv_kernel = nn.Sequential(
-                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
-                nn.BatchNorm2d(hidden),
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
                 nn.ReLU(inplace=True),
                 )
         self.conv_search = nn.Sequential(
-                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
-                nn.BatchNorm2d(hidden),
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
                 nn.ReLU(inplace=True),
                 )
 
         self.head = nn.Sequential(
-                nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
-                nn.BatchNorm2d(hidden),
+                nn.Conv2d(256, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(hidden, out_channels, kernel_size=1)
+                nn.Conv2d(256, 20, kernel_size=1)
                 )
 
     def forward_corr(self, kernel, input):
         kernel = self.conv_kernel(kernel)
         input = self.conv_search(input)
+
         feature = conv2d_dw_group(input, kernel)
         return feature
 
@@ -70,3 +73,112 @@ class DepthCorr(nn.Module):
         feature = self.forward_corr(kernel, search)
         out = self.head(feature)
         return out
+
+class DepthCorr_10(nn.Module):
+    def __init__(self, in_channels=256, hidden=256, out_channels=10, kernel_size=3):
+        super(DepthCorr_10, self).__init__()
+        # adjust layer for asymmetrical features
+        self.conv_kernel = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+        self.conv_search = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+
+        self.head = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(256, 10, kernel_size=1)
+                )
+
+    def forward_corr(self, kernel, input):
+        kernel = self.conv_kernel(kernel)
+        input = self.conv_search(input)
+
+        feature = conv2d_dw_group(input, kernel)
+        return feature
+
+    def forward(self, kernel, search):
+        feature = self.forward_corr(kernel, search)
+        out = self.head(feature)
+        return out
+
+class DepthCorr_3969_feature(nn.Module):
+    def __init__(self, in_channels=256, hidden=256, out_channels=10, kernel_size=3):
+        super(DepthCorr_3969_feature, self).__init__()
+        # adjust layer for asymmetrical features
+        self.conv_kernel = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+        self.conv_search = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+
+        self.head = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(256, 3969, kernel_size=1)
+                )
+
+    def forward_corr(self, kernel, input):
+        '''
+        k 1,256,7,7
+        i 1,256,31,31
+        '''
+        kernel = self.conv_kernel(kernel)
+        input = self.conv_search(input)
+        feature = conv2d_dw_group(input, kernel)
+        return feature
+
+    def forward(self, kernel, search):
+        feature = self.forward_corr(kernel, search)
+        out = self.head(feature)
+        return out
+
+class DepthCorr_3969_head(nn.Module):
+    def __init__(self, in_channels=256, hidden=256, out_channels=10, kernel_size=3):
+        super(DepthCorr_3969_head, self).__init__()
+        # adjust layer for asymmetrical features
+        self.conv_kernel = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+        self.conv_search = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=3, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                )
+
+        self.head = nn.Sequential(
+                nn.Conv2d(256, 256, kernel_size=1, bias=False),
+                nn.BatchNorm2d(256),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(256, 3969, kernel_size=1)
+                )
+
+    def forward_corr(self, kernel, input):
+        '''
+        k 1,256,7,7
+        i 1,256,31,31
+        '''
+        kernel = self.conv_kernel(kernel)
+        input = self.conv_search(input)
+        feature = conv2d_dw_group(input, kernel)
+        return feature
+
+    def forward(self, kernel, search):
+        feature = self.forward_corr(kernel, search)
+        out = self.head(feature)
+        return out
+
diff --git a/models/siammask_sharp.py b/models/siammask_sharp.py
index bccdf80..b9b0b38 100644
--- a/models/siammask_sharp.py
+++ b/models/siammask_sharp.py
@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from utils.anchors import Anchors
+import numpy as np
 
 
 class SiamMask(nn.Module):
@@ -54,18 +55,12 @@ class SiamMask(nn.Module):
 
         return rpn_loss_cls, rpn_loss_loc, rpn_loss_mask, iou_m, iou_5, iou_7
 
-    def run(self, template, search, softmax=False):
-        """
-        run network
-        """
+    def run(self, template, search):
         template_feature = self.feature_extractor(template)
         feature, search_feature = self.features.forward_all(search)
         rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature)
         corr_feature = self.mask_model.mask.forward_corr(template_feature, search_feature)  # (b, 256, w, h)
         rpn_pred_mask = self.refine_model(feature, corr_feature)
-
-        if softmax:
-            rpn_pred_cls = self.softmax(rpn_pred_cls)
         return rpn_pred_cls, rpn_pred_loc, rpn_pred_mask, template_feature, search_feature
 
     def softmax(self, cls):
@@ -75,39 +70,13 @@ class SiamMask(nn.Module):
         cls = F.log_softmax(cls, dim=4)
         return cls
 
-    def forward(self, input):
-        """
-        :param input: dict of input with keys of:
-                'template': [b, 3, h1, w1], input template image.
-                'search': [b, 3, h2, w2], input search image.
-                'label_cls':[b, max_num_gts, 5] or None(self.training==False),
-                                     each gt contains x1,y1,x2,y2,class.
-        :return: dict of loss, predict, accuracy
-        """
-        template = input['template']
-        search = input['search']
-        if self.training:
-            label_cls = input['label_cls']
-            label_loc = input['label_loc']
-            lable_loc_weight = input['label_loc_weight']
-            label_mask = input['label_mask']
-            label_mask_weight = input['label_mask_weight']
-
-        rpn_pred_cls, rpn_pred_loc, rpn_pred_mask, template_feature, search_feature = \
-            self.run(template, search, softmax=self.training)
-
-        outputs = dict()
-
-        outputs['predict'] = [rpn_pred_loc, rpn_pred_cls, rpn_pred_mask, template_feature, search_feature]
-
-        if self.training:
-            rpn_loss_cls, rpn_loss_loc, rpn_loss_mask, iou_acc_mean, iou_acc_5, iou_acc_7 = \
-                self._add_rpn_loss(label_cls, label_loc, lable_loc_weight, label_mask, label_mask_weight,
-                                   rpn_pred_cls, rpn_pred_loc, rpn_pred_mask)
-            outputs['losses'] = [rpn_loss_cls, rpn_loss_loc, rpn_loss_mask]
-            outputs['accuracy'] = [iou_acc_mean, iou_acc_5, iou_acc_7]
-
-        return outputs
+    def forward(self, template,search):
+
+        rpn_pred_cls, rpn_pred_loc, rpn_pred_mask, template_feature, search_feature = self.run(template, search)
+
+
+        return rpn_pred_cls, rpn_pred_loc, rpn_pred_mask,template_feature,search_feature
+
 
     def template(self, z):
         self.zf = self.feature_extractor(z)
@@ -122,6 +91,7 @@ class SiamMask(nn.Module):
         return rpn_pred_cls, rpn_pred_loc
 
 
+
 def get_cls_loss(pred, label, select):
     if select.nelement() == 0: return pred.sum()*0.
     pred = torch.index_select(pred, 0, select)
@@ -185,7 +155,7 @@ def iou_measure(pred, label):
     union = torch.sum(mask_sum > 0, dim=1).float()
     iou = intxn/union
     return torch.mean(iou), (torch.sum(iou > 0.5).float()/iou.shape[0]), (torch.sum(iou > 0.7).float()/iou.shape[0])
-    
+
 
 if __name__ == "__main__":
     p_m = torch.randn(4, 63*63, 25, 25)