import argparse
import os
import numpy as np
import torch
import onnx
from torch import nn
from numpy import random
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.modeling.detector.onnx_model import SWIN_BACKBONE, FUSE_MODEL, SELECT_BBOX
from maskrcnn_benchmark.engine.inference import create_positive_dict, create_queries_and_maps, create_queries_and_maps_from_dataset
from maskrcnn_benchmark.modeling.language_backbone import build_language_backbone
from transformers import AutoTokenizer
# Compatibility shim: recent NumPy releases no longer expose the ``np.float``
# alias. Re-create it as ``np.float64`` when it is missing so that code which
# still refers to ``np.float`` keeps working.
try:
    np.float
except AttributeError:
    np.float = np.float64
class lang(nn.Module):
    """Tokenize raw captions and run them through the language backbone.

    The backbone is created with ``build_language_backbone(cfg)`` and the
    tokenizer is loaded from ``cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.language_backbone = build_language_backbone(cfg)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)

    def forward(self, captions):
        """Encode ``captions`` and return the language backbone's output.

        Args:
            captions: batch of caption strings handed to the tokenizer's
                ``batch_encode_plus``.

        Returns:
            Whatever ``self.language_backbone`` returns for a dict holding
            the ``input_ids`` and ``attention_mask`` of the tokenized batch.
        """
        # Every caption is padded/truncated to exactly 256 tokens and the
        # encoded batch is kept on the CPU.
        tokenized = self.tokenizer.batch_encode_plus(
            captions,
            max_length=256,
            padding='max_length',
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True,
        ).to('cpu')

        tokenizer_input = {
            "input_ids": tokenized.input_ids,
            "attention_mask": tokenized.attention_mask,
        }
        return self.language_backbone(tokenizer_input)
def create_language_dict_features(cfg, weight):
    """Compute language features for the captions of the first test dataset.

    Args:
        cfg: project configuration, passed to the data loader factory, the
            caption/positive-map helper, the ``lang`` model and the
            checkpointer.
        weight: checkpoint path given to ``DetectronCheckpointer.load``.

    Returns:
        Tuple ``(language_dict_features, captions, positive_map)`` where
        ``positive_map`` is the first entry of the positive maps returned by
        ``create_queries_and_maps_from_dataset``.
    """
    # Only the dataset of the first evaluation loader is used.
    dataset = make_data_loader(cfg, is_train=False)[0].dataset
    captions, all_positive_map_label_to_token = create_queries_and_maps_from_dataset(dataset, cfg)
    positive_map = all_positive_map_label_to_token[0]

    lang_model = lang(cfg)
    checkpointer = DetectronCheckpointer(cfg, lang_model)
    checkpointer.load(weight, force=True)

    # Put the model in inference mode before computing the features, matching
    # what main() does for the exported models.
    lang_model.eval()
    with torch.no_grad():
        language_dict_features = lang_model(captions)
    return language_dict_features, captions, positive_map
def _spatial_dynamic_axes(names):
    """Return a ``dynamic_axes`` dict marking dims 2 and 3 of each name as dynamic."""
    return {name: {2: '-1', 3: '-1'} for name in names}


def _dummy_feature_maps():
    """Return the five random 256-channel feature maps used as export inputs.

    NOTE(review): the spatial sizes look like strides 8..128 of the
    800x1216 dummy image used for the backbone export -- confirm.
    """
    shapes = ((100, 152), (50, 76), (25, 38), (13, 19), (7, 10))
    return tuple(torch.rand(1, 256, height, width) for height, width in shapes)


def _export_onnx(model, dummy_input, path, input_names, output_names, dynamic_axes):
    """Export ``model`` to ``path`` (opset 11), creating the target directory first."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.onnx.export(
        model,
        dummy_input,
        path,
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
        dynamic_axes=dynamic_axes,
    )


def main():
    """Export one stage of the detector to ONNX.

    Command-line arguments:
        --config-file: path to the YAML config merged into ``cfg``.
        --weight: checkpoint loaded into the selected model.
        --model_type: which stage to export; one of ``backbone``,
            ``rpn_head`` or ``select`` (required).
        opts: remaining arguments, merged into ``cfg`` as overrides.

    The ONNX file is written under ``./<model_type>/model/``. For
    ``rpn_head`` the positive map is additionally saved to
    ``./rpn_head/positive_map.npy``.
    """
    parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
    parser.add_argument(
        "--config-file",
        default="configs/grounding/e2e_dyhead_SwinT_S_FPN_1x_od_grounding_eval.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight",
        help="pth to model",
        default="glip_tiny_model_o365_goldg.pth",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # Required and restricted to the supported stages: otherwise a missing or
    # misspelled value would leave ``model`` undefined further down.
    parser.add_argument(
        "--model_type",
        help="convert model type",
        choices=("backbone", "rpn_head", "select"),
        required=True,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if args.model_type == "backbone":
        model = SWIN_BACKBONE(cfg)
    elif args.model_type == "rpn_head":
        language_dict_features, captions, positive_map = create_language_dict_features(cfg, args.weight)
        # np.save does not create missing directories.
        os.makedirs('./rpn_head', exist_ok=True)
        np.save('./rpn_head/positive_map.npy', positive_map)
        model = FUSE_MODEL(cfg, language_dict_features, captions, positive_map)
    else:  # "select" -- guaranteed by the argparse choices above
        model = SELECT_BBOX(cfg)

    checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(args.weight, force=True)
    model.eval()

    if args.model_type == "backbone":
        image = torch.rand(1, 3, 800, 1216)
        _export_onnx(
            model,
            image,
            './backbone/model/glip_backbone.onnx',
            input_names=['images'],
            output_names=['feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5'],
            dynamic_axes=_spatial_dynamic_axes(['images']),
        )
    elif args.model_type == "rpn_head":
        feature_names = ['feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5']
        _export_onnx(
            model,
            _dummy_feature_maps(),
            './rpn_head/model/glip_rpn_head.onnx',
            input_names=feature_names,
            output_names=['o1', 'o2', 'o3', 'o4', 'o5', 'o6'],
            dynamic_axes=_spatial_dynamic_axes(feature_names),
        )
    else:  # "select"
        map_names = ['input_1', 'input_2', 'input_3', 'input_4', 'input_5']
        # The sixth input keeps a fixed (1, 256, 768) shape and is therefore
        # not listed in dynamic_axes.
        dummy_input = _dummy_feature_maps() + (torch.rand(1, 256, 768),)
        _export_onnx(
            model,
            dummy_input,
            './select/model/glip_select.onnx',
            input_names=map_names + ['input_6'],
            output_names=[f'o{index}' for index in range(1, 21)],
            dynamic_axes=_spatial_dynamic_axes(map_names),
        )
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()