import argparse
import torch
from ais_bench.infer.interface import InferSession
from maskrcnn_benchmark.modeling.detector.onnx_model import SWIN_BACKBONE, FUSE_MODEL, LANG
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
def om_infer(
    lang,
    mask,
    image,
    *,
    device_id=0,
    backbone_model="./glip_backbone.om",
    rpn_model="./glip_rpn.om",
    lang_model="./glip_language.om",
):
    """Run the exported GLIP .om models through ais_bench ``InferSession``.

    The backbone model consumes ``image`` and yields five feature maps.
    The language model consumes ``lang`` and ``mask`` and yields two tensors.
    All seven are then fed, in that order, to the RPN/fusion model.

    Args:
        lang: language input (``main`` passes random integer ids of shape (1, 256)).
        mask: language mask (``main`` passes ``torch.ones(1, 256)``).
        image: image input (``main`` passes a random (1, 3, 784, 1344) tensor).
        device_id: device index handed to every ``InferSession``; defaults to 0
            as before.
        backbone_model: path to the backbone .om file.
        rpn_model: path to the RPN/fusion .om file.
        lang_model: path to the language .om file.

    Returns:
        Whatever ``InferSession.infer`` returns for the RPN model. ``main``
        indexes it per output and converts each item to a torch tensor.
    """
    # Sessions are created in the same order as before: backbone, rpn, language.
    session_backbone = InferSession(device_id, backbone_model)
    session_rpn = InferSession(device_id, rpn_model)
    session_lang = InferSession(device_id, lang_model)

    # Explicit unpacking fails loudly if a model's output count is not as expected.
    f1, f2, f3, f4, f5 = session_backbone.infer([image])
    l1, l2 = session_lang.infer([lang, mask])
    return session_rpn.infer([f1, f2, f3, f4, f5, l1, l2])
def cpu_infer(cfg, weight, lang, mask, image):
    """Run the PyTorch reference pipeline for the same inputs as ``om_infer``.

    The steps are:
      1. Instantiate ``SWIN_BACKBONE``, ``FUSE_MODEL`` and ``LANG`` from ``cfg``,
         in that order.
      2. Load the checkpoint at ``weight`` into each model with
         ``DetectronCheckpointer(...).load(weight, force=True)``.
      3. Put each model into eval mode.
      4. Run one forward pass under ``torch.no_grad()``.

    Args:
        cfg: config node used to build the models and checkpointers.
        weight: checkpoint path loaded into all three models.
        lang: input to the language model (first argument).
        mask: input to the language model (second argument).
        image: input to the backbone.

    Returns:
        The output of the fusion/RPN model, called with the backbone's five
        outputs followed by the language model's two outputs.
    """

    def _prepare(model_cls):
        # Build the model, load the shared checkpoint into it, freeze it for inference.
        net = model_cls(cfg)
        DetectronCheckpointer(cfg, net).load(weight, force=True)
        net.eval()
        return net

    backbone = _prepare(SWIN_BACKBONE)
    fusion = _prepare(FUSE_MODEL)
    language = _prepare(LANG)

    with torch.no_grad():
        # Fixed-arity unpacking doubles as a sanity check on each model's outputs.
        v1, v2, v3, v4, v5 = backbone(image)
        t1, t2 = language(lang, mask)
        return fusion(v1, v2, v3, v4, v5, t1, t2)
def main():
    """Compare the PyTorch and .om GLIP pipelines on one random input.

    Steps:
      1. Parse ``--config_file`` and ``--weight``, merge the config file into
         the global ``cfg``, and freeze it.
      2. Build random language ids, an all-ones mask and a random image.
      3. Feed the same tensors to ``cpu_infer`` and ``om_infer``.
      4. Print the cosine similarity of every pair of corresponding outputs.
    """
    parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
    parser.add_argument(
        "--config_file",
        default="configs/grounding/e2e_dyhead_SwinT_S_FPN_1x_od_grounding_eval.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight",
        help="pth to model",
        default="glip_tiny_model_o365_goldg.pth",
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    lang = torch.randint(1, 500, (1, 256))
    mask = torch.ones(1, 256)
    image = torch.rand(1, 3, 784, 1344)
    x = cpu_infer(cfg, args.weight, lang, mask, image)
    y = om_infer(lang, mask, image)

    # Compare every output both backends actually produced. The old hard-coded
    # range(15) raised IndexError when fewer were returned and skipped any extras.
    if len(x) != len(y):
        print(
            f"warning: CPU produced {len(x)} outputs, OM produced {len(y)}; "
            f"comparing the first {min(len(x), len(y))}"
        )
    for i, (cpu_out, om_out) in enumerate(zip(x, y)):
        # as_tensor accepts numpy arrays as well as tensors. Casting both sides
        # to float32 keeps cosine_similarity happy if the OM side is not float32
        # (e.g. fp16 — assumption, depends on how the .om was exported).
        cpu_vec = torch.as_tensor(cpu_out).reshape(-1).float()
        om_vec = torch.as_tensor(om_out).reshape(-1).float()
        similarity = torch.cosine_similarity(cpu_vec, om_vec, dim=0)
        print(f"第{i}个输出的余弦相似度为:")
        print(similarity)
if __name__ == "__main__":
    # Only run the CPU-vs-OM comparison when executed directly, not on import.
    main()