import torch
from torch import nn
from ..backbone import build_backbone
from ..rpn.vldyhead import VLDyHeadModule
class SWIN_BACKBONE(nn.Module):
    """Gradient-free wrapper around the backbone built from ``cfg``.

    The wrapped module is whatever ``build_backbone(cfg)`` returns.
    ``forward`` hands back its first five outputs as a flat tuple.
    NOTE(review): the flat tuple presumably exists so the module can be
    traced/exported with plain tensor outputs -- confirm with the export
    script.
    """

    def __init__(self, cfg):
        """Keep a reference to ``cfg`` and build the backbone from it."""
        super().__init__()
        self.cfg = cfg
        self.backbone = build_backbone(cfg)

    def forward(self, images):
        """Run the backbone on ``images`` with autograd disabled.

        Returns:
            A 5-tuple holding entries 0 through 4 of the backbone output.
        """
        with torch.no_grad():
            feature_maps = self.backbone(images)
            # Index explicitly (rather than slice) so an output with fewer
            # than five entries still raises IndexError.
            return tuple(feature_maps[level] for level in range(5))
class FUSE_MODEL(nn.Module):
    """Wrapper that feeds five feature inputs plus fixed language data to the head.

    The head is ``VLDyHeadModule(cfg, onnx_export='rpn_head')``. The language
    features, captions and positive map are captured at construction time and
    reused unchanged on every ``forward`` call.
    """

    def __init__(self, cfg, language_dict_features, captions, positive_map):
        """Build the head and store the constant language-side inputs."""
        super().__init__()
        self.cfg = cfg
        self.rpn = VLDyHeadModule(cfg, onnx_export='rpn_head')
        self.language_dict_features = language_dict_features
        self.captions = captions
        self.positive_map = positive_map

    def forward(self, input1, input2, input3, input4, input5):
        """Call the head on the five inputs with autograd disabled.

        The inputs are packed, in order, into a list passed as the head's
        second positional argument; the first, third and seventh arguments
        are ``None``.

        Returns:
            The six values produced by the head, as a 6-tuple.
        """
        feature_list = [input1, input2, input3, input4, input5]
        with torch.no_grad():
            # Unpack into exactly six names so a head returning a different
            # number of values fails here rather than downstream.
            out_a, out_b, out_c, out_d, out_e, out_f = self.rpn(
                None,
                feature_list,
                None,
                self.language_dict_features,
                self.positive_map,
                self.captions,
                None,
            )
            return out_a, out_b, out_c, out_d, out_e, out_f
class SELECT_BBOX(nn.Module):
    """Wrapper that runs the head in 'select' mode and flattens its outputs.

    The head is ``VLDyHeadModule(cfg, onnx_export='select')``. ``forward``
    receives six inputs and returns twenty values: the first five entries of
    each of the four sequences the head produces.
    """

    def __init__(self, cfg):
        """Store ``cfg`` and build the head from it."""
        super().__init__()
        self.cfg = cfg
        self.rpn = VLDyHeadModule(cfg, onnx_export='select')

    def forward(self, input1, input2, input3, input4, input5, input6):
        """Call the head on the six inputs with autograd disabled.

        The inputs are packed, in order, into a list passed as the head's
        second positional argument; every other argument is ``None``.

        Returns:
            A 20-tuple ordered as ``box_cls[0..4]``, ``box_regression[0..4]``,
            ``centerness[0..4]``, ``dot_product_logits[0..4]``.
        """
        packed_inputs = [input1, input2, input3, input4, input5, input6]
        with torch.no_grad():
            box_cls, box_regression, centerness, dot_product_logits = self.rpn(
                None, packed_inputs, None, None, None, None, None
            )
            flattened = []
            # Group-major order; explicit indexing keeps the IndexError for
            # any group that holds fewer than five entries.
            for group in (box_cls, box_regression, centerness, dot_product_logits):
                for position in range(5):
                    flattened.append(group[position])
            return tuple(flattened)