import numpy as np
import cv2
import torch
import torch.nn.functional as F
import time
import warnings
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
from .alexnet import SiameseAlexNet
from .config import config
from .custom_transforms import ToTensor
from .utils import get_exemplar_image, get_pyramid_instance_image, get_instance_image, show_image
# Limit PyTorch CPU intra-op parallelism to one thread. This executes at import
# time and applies to the whole process, not only to this tracker.
# NOTE(review): presumably to avoid CPU oversubscription / keep per-frame timing
# stable; confirm the intent before moving it out of module scope.
torch.set_num_threads(1)
class SiamFCTracker:
    """SiamFC single-object tracker running a SiameseAlexNet on an NPU device.

    Call ``init`` once with the first frame and the target box, then ``update``
    for every following frame. ``track`` drives both over a list of image files.
    """

    def __init__(self, model_path, npu_id=0, is_deterministic=False):
        """Load the network weights and move the model to ``npu:<npu_id>``.

        Args:
            model_path: path of a state dict readable by ``torch.load``.
            npu_id: index of the NPU device to run on.
            is_deterministic: stored on the instance; not read by this class.
        """
        self.npu_id = npu_id
        self.name = 'SiamFC'
        self.is_deterministic = is_deterministic
        self.model = SiameseAlexNet()
        # Load on the host first: a checkpoint saved from an accelerator would
        # otherwise be restored onto that accelerator (or fail if it is absent).
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        # NOTE(review): torch.npu is presumably provided by the Ascend torch_npu
        # plugin imported elsewhere — confirm.
        torch.npu.set_device(npu_id)
        # Built once and reused for every host -> device transfer.
        self.device = torch.device("npu:{}".format(self.npu_id))
        self.model = self.model.to(self.device)
        self.model.eval()
        self.transforms = transforms.Compose([ToTensor()])

    def _cosine_window(self, size):
        """Return a 2-D Hanning window of shape ``size`` normalised to sum to 1.

        Args:
            size: (rows, cols); each entry is cast to int.

        Returns:
            float32 array of shape (rows, cols).
        """
        hann_rows = np.hanning(int(size[0]))[:, np.newaxis]
        hann_cols = np.hanning(int(size[1]))[np.newaxis, :]
        cos_window = hann_rows.dot(hann_cols).astype(np.float32)
        cos_window /= np.sum(cos_window)
        return cos_window

    @staticmethod
    def _scale_factors(scale_step, num_scale):
        """Return the ``num_scale`` search-scale multipliers ``scale_step ** k``.

        The exponents are ``k = i - num_scale // 2`` for ``i in range(num_scale)``,
        so index ``num_scale // 2`` is always the unit scale, matching the
        unpenalised entry of ``self.penalty``. For odd ``num_scale`` this is the
        symmetric range ``-(num_scale // 2) .. num_scale // 2``.
        """
        # float exponents so an integer scale_step still allows negative powers
        exponents = np.arange(num_scale, dtype=np.float64) - num_scale // 2
        return scale_step ** exponents

    @staticmethod
    def _normalize_response(response_map):
        """Shift ``response_map`` to a minimum of 0 and scale it to sum to 1.

        The float array is modified in place and returned. A flat map (all
        values equal) becomes all zeros instead of NaN.
        """
        response_map -= response_map.min()
        total = response_map.sum()
        if total > 0:
            response_map /= total
        return response_map

    def init(self, frame, box):
        """Initialise the tracker on the first frame.

        Args:
            frame: H x W x C image array. ``track`` passes ``cv2.imread`` output
                (BGR channel order).
            box: one-based target box [x, y, width, height].
        """
        # Zero-based (xmin, ymin, xmin + w, ymin + h), used for the exemplar crop.
        self.bbox = (box[0] - 1, box[1] - 1, box[0] - 1 + box[2], box[1] - 1 + box[3])
        # Zero-based target centre (x, y) and target size (w, h).
        self.pos = np.array([box[0] - 1 + box[2] / 2, box[1] - 1 + box[3] / 2],
                            dtype=np.float64)
        self.target_sz = np.array([box[2], box[3]], dtype=np.float64)
        # Per-channel mean of the first frame, handed to the crop helpers.
        self.img_mean = tuple(map(int, frame.mean(axis=(0, 1))))
        exemplar_img, scale_z, s_z = get_exemplar_image(frame, self.bbox,
                                                         config.exemplar_size,
                                                         config.context_amount,
                                                         self.img_mean)
        exemplar_img = self.transforms(exemplar_img)[None, :, :, :]
        with torch.no_grad():
            # Return value is discarded: update() later calls the model with only
            # the instance input, so the model must keep state from this call.
            self.model((exemplar_img.to(self.device), None))
        # Every scale except the unchanged middle one is penalised.
        self.penalty = np.ones(config.num_scale) * config.scale_penalty
        self.penalty[config.num_scale // 2] = 1
        self.interp_response_sz = config.response_up_stride * config.response_sz
        self.cosine_window = self._cosine_window((self.interp_response_sz,
                                                  self.interp_response_sz))
        self.scales = self._scale_factors(config.scale_step, config.num_scale)
        # Search-region side in frame pixels; update() clamps it to [0.2, 5] x
        # this initial value.
        self.s_x = s_z + (config.instance_size - config.exemplar_size) / scale_z
        self.min_s_x = 0.2 * self.s_x
        self.max_s_x = 5 * self.s_x

    def update(self, frame):
        """Locate the target in ``frame`` starting from the previous state.

        Args:
            frame: H x W x C image array. ``track`` passes ``cv2.imread`` output
                (BGR channel order).

        Returns:
            np.ndarray [x, y, width, height]: one-based box of the target.
        """
        size_x_scales = self.s_x * self.scales
        pyramid = get_pyramid_instance_image(frame, self.pos, config.instance_size,
                                             size_x_scales, self.img_mean)
        instance_imgs = torch.cat([self.transforms(x)[None, :, :, :] for x in pyramid], dim=0)
        with torch.no_grad():
            response_maps = self.model((None, instance_imgs.to(self.device)))
        response_maps = response_maps.detach().cpu().numpy()
        # One 2-D map per scale. reshape (not squeeze) keeps the scale axis when
        # num_scale == 1.
        response_maps = response_maps.reshape(-1, *response_maps.shape[-2:])
        up_size = (self.interp_response_sz, self.interp_response_sz)
        # interpolation must be a keyword: cv2.resize's third positional
        # parameter is `dst`, so passing the flag positionally does not select
        # bicubic interpolation.
        response_maps_up = [cv2.resize(x, up_size, interpolation=cv2.INTER_CUBIC)
                            for x in response_maps]
        penalized_peaks = np.array([x.max() for x in response_maps_up]) * self.penalty
        scale_idx = penalized_peaks.argmax()
        response_map = self._normalize_response(response_maps_up[scale_idx])
        response_map = ((1 - config.window_influence) * response_map
                        + config.window_influence * self.cosine_window)
        max_r, max_c = np.unravel_index(response_map.argmax(), response_map.shape)
        # Peak offset from the map centre, in upsampled-response pixels ...
        disp_response_interp = np.array([max_c, max_r]) - (self.interp_response_sz - 1) / 2.
        # ... converted to instance-image pixels ...
        disp_response_input = (disp_response_interp * config.total_stride
                               / config.response_up_stride)
        scale = self.scales[scale_idx]
        # ... and then to frame pixels.
        disp_response_frame = disp_response_input * (self.s_x * scale) / config.instance_size
        self.pos += disp_response_frame
        # Damped scale update shared by the search region and the target size.
        scale_update = (1 - config.scale_lr) + config.scale_lr * scale
        self.s_x *= scale_update
        self.s_x = max(self.min_s_x, min(self.max_s_x, self.s_x))
        self.target_sz = scale_update * self.target_sz
        box = np.array([
            self.pos[0] + 1 - self.target_sz[0] / 2,
            self.pos[1] + 1 - self.target_sz[1] / 2,
            self.target_sz[0], self.target_sz[1]])
        return box

    def track(self, img_files, box, visualize=False):
        """Run the tracker over a sequence of image files.

        Args:
            img_files: sequence of image paths; the first one initialises the tracker.
            box: one-based initial box [x, y, width, height] for the first frame.
            visualize: if True, call ``show_image`` with each frame and its box.

        Returns:
            (boxes, times): ``boxes`` is a (len(img_files), 4) array of one-based
            [x, y, width, height] boxes; ``times[f]`` is the wall-clock seconds
            spent in init/update for frame ``f``.

        Raises:
            IOError: if an image file cannot be read.
        """
        frame_num = len(img_files)
        boxes = np.zeros((frame_num, 4))
        times = np.zeros(frame_num)
        if frame_num == 0:
            return boxes, times
        boxes[0] = box
        for f, img_file in enumerate(img_files):
            # cv2.imread yields BGR and is passed on unchanged.
            # NOTE(review): confirm the model was trained on BGR input.
            img = cv2.imread(img_file, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise IOError("failed to read image: {}".format(img_file))
            begin = time.time()
            if f == 0:
                self.init(img, box)
            else:
                boxes[f, :] = self.update(img)
            times[f] = time.time() - begin
            if visualize:
                show_image(img, boxes[f, :])
        return boxes, times