import torch
from torch import nn
from ..backbone import build_backbone
from ..rpn.vldyhead import VLDyHeadModule
from maskrcnn_benchmark.modeling.language_backbone import build_language_backbone
class LANG(nn.Module):
def __init__(self, cfg):
super(LANG, self).__init__()
self.language_backbone = build_language_backbone(cfg)
def forward(self, input_ids, attention_mask):
tokenizer_input = {"input_ids": input_ids,
"attention_mask": attention_mask}
language_dict_features = self.language_backbone(tokenizer_input)
return language_dict_features['hidden'], language_dict_features['masks']
class SWIN_BACKBONE(nn.Module):
def __init__(self, cfg):
super(SWIN_BACKBONE, self).__init__()
self.cfg = cfg
self.backbone = build_backbone(cfg)
def forward(self, images):
with torch.no_grad():
visual_features = self.backbone(images)
return visual_features[0], visual_features[1], visual_features[2], visual_features[3], visual_features[4]
class FUSE_MODEL(nn.Module):
def __init__(self, cfg):
super(FUSE_MODEL, self).__init__()
self.cfg = cfg
self.rpn = VLDyHeadModule(cfg, onnx_export='rpn_head')
def forward(self, input1, input2, input3, input4, input5, hidden, mask):
language_dict_features = {}
language_dict_features['hidden'] = hidden
language_dict_features['masks'] = mask
language_dict_features['embedded'] = None
visual_features = []
visual_features.append(input1)
visual_features.append(input2)
visual_features.append(input3)
visual_features.append(input4)
visual_features.append(input5)
with torch.no_grad():
box_regression, centerness, dot_product_logits = self.rpn(None, visual_features, None, language_dict_features, None, None, None)
return box_regression[0], box_regression[1], box_regression[2], \
box_regression[3], box_regression[4], centerness[0], centerness[1], centerness[2], centerness[3], centerness[4], \
dot_product_logits[0], dot_product_logits[1], dot_product_logits[2], dot_product_logits[3], dot_product_logits[4]