import sys
import torch
import torch.nn as nn
import torch.onnx
import torch.nn.functional as F
# Fixed multiplier applied to the raw cross-correlation score in the
# two-input path of SiameseAlexNet.forward.
response_scale = 1e-3


class SiameseAlexNet(nn.Module):
    """AlexNet-style feature extractor shared by exemplar and search inputs.

    ``forward`` takes a single pair ``x = (exemplar, instance)`` and picks a
    path based on which elements are ``None``:

    * both given: extract features for both, then correlate sample ``i`` of
      the instance features with sample ``i`` of the exemplar features. The
      result is scaled by ``response_scale`` and shifted by ``corr_bias``.
    * only ``exemplar``: extract its features and repeat them ``num_scales``
      times along dim 0. The result is cached on ``self.exemplar`` and
      returned.
    * only ``instance``: reshape the input to ``(num_scales, 3, H, W)``
      (e.g. a ``(1, 9, H, W)`` tensor when ``num_scales == 3``), extract
      features, and flatten them into a single ``(1, num_scales * C, h, w)``
      tensor.
    """

    def __init__(self, num_scales=3):
        """
        Args:
            num_scales: number of exemplar copies / search crops used by the
                single-input paths. Defaults to 3, the value that was
                previously hard-coded. Presumably one per search scale;
                TODO confirm.

        Raises:
            ValueError: if ``num_scales`` is smaller than 1.
        """
        super(SiameseAlexNet, self).__init__()
        if num_scales < 1:
            raise ValueError("num_scales must be >= 1, got {}".format(num_scales))
        # Plain int, not a parameter/buffer, so state_dict keys are unchanged.
        self.num_scales = num_scales
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, 11, 2),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(96, 256, 5, 1, groups=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(256, 384, 3, 1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, 3, 1, groups=2),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, 3, 1, groups=2)
        )
        # Learned scalar added to every correlation score.
        self.corr_bias = nn.Parameter(torch.zeros(1))
        # Populated by the exemplar-only path of forward().
        self.exemplar = None

    def init_weights(self):
        """Re-initialise convolution and batch-norm weights in place.

        Conv weights get Kaiming-normal init (``mode='fan_out'``,
        ``nonlinearity='relu'``). Batch-norm scale is set to 1 and shift to 0.
        Conv biases are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run one of the three paths described in the class docstring.

        Args:
            x: a pair ``(exemplar, instance)``; at least one must be non-None.

        Returns:
            Two inputs: scores of shape ``(N, 1, h, w)``.
            Exemplar only: features repeated ``num_scales`` times along dim 0.
            Instance only: features flattened to ``(1, num_scales * C, h, w)``.

        Raises:
            ValueError: if both elements of ``x`` are ``None``.
        """
        exemplar, instance = x
        if exemplar is not None and instance is not None:
            exemplar = self.features(exemplar)
            instance = self.features(instance)
            N, C, H, W = instance.shape
            # Fold the batch into channels. A conv with groups=N then
            # correlates instance sample i only with exemplar kernel i.
            instance = instance.view(1, -1, H, W)
            score = F.conv2d(instance, exemplar, groups=N) * response_scale + self.corr_bias
            # (1, N, h, w) -> (N, 1, h, w)
            return score.transpose(0, 1)
        if exemplar is not None:
            feat = self.features(exemplar)
            self.exemplar = torch.cat([feat] * self.num_scales, dim=0)
            return self.exemplar
        if instance is None:
            raise ValueError(
                "SiameseAlexNet.forward needs at least one of (exemplar, instance); "
                "got (None, None)")
        _, _, H, W = instance.shape
        # The second dim is 3 to match the RGB input of the first Conv2d.
        instance = instance.reshape(self.num_scales, 3, H, W)
        instance = self.features(instance)
        N, C, H, W = instance.shape
        return instance.view(1, N * C, H, W)
def _load_eval_model(input_file):
    """Build a SiameseAlexNet, load the checkpoint at *input_file* on CPU, return it in eval mode."""
    model = SiameseAlexNet()
    # NOTE(review): torch.load unpickles the file, so only convert checkpoints
    # from trusted sources.
    model.load_state_dict(torch.load(input_file, map_location='cpu'))
    model.eval()
    return model


def _export_onnx(model, dummy_input, output_file, opset_version):
    """Export *model* to *output_file* with the input/output names shared by both converters."""
    # dummy_input is kept as a list, not a tuple. Presumably torch.onnx.export
    # then passes it as ONE positional argument, matching forward(self, x).
    # TODO confirm before changing it to a tuple.
    torch.onnx.export(model, dummy_input, output_file,
                      input_names=["actual_input_1"], output_names=["output1"],
                      opset_version=opset_version)


def exemplar_convert(input_file, output_file, opset_version=11):
    """Export the exemplar-only path of the model to ONNX.

    Traces ``forward((exemplar, None))`` with a random (1, 3, 127, 127) input.

    Args:
        input_file: path of the PyTorch state_dict checkpoint.
        output_file: destination path for the ONNX model.
        opset_version: ONNX opset to target (default 11, as before).
    """
    model = _load_eval_model(input_file)
    dummy_input = [torch.randn(1, 3, 127, 127), None]
    _export_onnx(model, dummy_input, output_file, opset_version)


def search_convert(input_file, output_file, opset_version=11):
    """Export the instance-only (search) path of the model to ONNX.

    Traces ``forward((None, instance))`` with a random (1, 9, 255, 255) input.

    Args:
        input_file: path of the PyTorch state_dict checkpoint.
        output_file: destination path for the ONNX model.
        opset_version: ONNX opset to target (default 11, as before).
    """
    model = _load_eval_model(input_file)
    dummy_input = [None, torch.randn(1, 9, 255, 255)]
    _export_onnx(model, dummy_input, output_file, opset_version)
if __name__ == "__main__":
    # CLI: <script> CHECKPOINT EXEMPLAR_ONNX SEARCH_ONNX
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 4:
        sys.exit("usage: {} <checkpoint> <exemplar.onnx> <search.onnx>".format(sys.argv[0]))
    input_file, output_file_exemplar, output_file_search = sys.argv[1:4]
    exemplar_convert(input_file, output_file_exemplar)
    search_convert(input_file, output_file_search)