import torch
import numpy as np
import os.path as osp
import os
import argparse
from network import MGN
from data import Data
from opt import opt
def build_model():
    """Build an MGN network, load weights from ``opt.weight`` and switch to eval mode.

    Returns:
        The ``MGN`` instance with the checkpoint's state dict loaded, moved to
        ``npu:0`` when ``opt.npu`` is truthy, and set to evaluation mode.
    """
    model = MGN()
    # Deserialize onto the CPU first. Without ``map_location`` torch.load restores
    # each tensor onto the device it was saved from, which fails on a host that
    # lacks that device (e.g. a CUDA-trained checkpoint loaded on an NPU box).
    state_dict = torch.load(opt.weight, map_location="cpu")
    model.load_state_dict(state_dict)
    if opt.npu:
        model = model.to("npu:0")
    model.eval()
    return model
def get_raw_data(data):
    """Return the first ``(inputs, targets)`` batch from ``data.query_loader``.

    Each call starts a new iterator over the loader and takes its first item.
    ``StopIteration`` propagates if the loader yields nothing, and ``ValueError``
    is raised if that item does not unpack into exactly two values.
    """
    loader_iter = iter(data.query_loader)
    first_batch = next(loader_iter)
    inputs, targets = first_batch
    return inputs, targets
def extract_one_batch_feature(model, inputs):
    """Compute L2-normalized features for one batch using flip test-time augmentation.

    The model is run on ``inputs`` and on a copy reversed along dim 3. Element 0
    of each output is moved to the CPU, the two are summed, and the sum is
    divided by its L2 norm along dim 1.

    Args:
        model: callable network whose output is indexable; item 0 is taken as
            the feature tensor (presumably shaped (batch, feat_dim) — TODO confirm).
        inputs: image batch tensor; dim 3 is the one flipped (presumably width
            of an (N, C, H, W) batch — TODO confirm).

    Returns:
        CPU tensor of the summed, L2-normalized features.
    """
    if opt.npu:
        inputs = inputs.npu()
    # Pure inference: skip building autograd graphs for both forward passes.
    with torch.no_grad():
        f1 = model(inputs)[0].cpu()
        # torch.flip reverses dim 3 on whatever device ``inputs`` already lives on.
        # The previous index_select used a CPU ``torch.arange`` index, which does
        # not match an NPU-resident ``inputs``.
        flipped = torch.flip(inputs, dims=[3])
        f2 = model(flipped)[0].cpu()
    ff = f1 + f2
    fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
    # NOTE(review): an all-zero feature row would divide by zero here; behavior
    # kept as-is.
    return ff.div(fnorm.expand_as(ff))
def _main():
    """Load the query data and the model, then print features for the first query batch."""
    dataset = Data()
    query_inputs, _query_targets = get_raw_data(dataset)
    net = build_model()
    features = extract_one_batch_feature(net, query_inputs).numpy()
    print(features)


if __name__ == '__main__':
    _main()