@@ -97,10 +97,9 @@ class SemanticSegmentor(nn.Module):
The prediction has shape KxHxW that represents the logits of
each class for each pixel.
"""
- images = [x["image"].to(self.device) for x in batched_inputs]
+ '''images = [x["image"].to(self.device) for x in batched_inputs]
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
images = ImageList.from_tensors(images, self.backbone.size_divisibility)
-
features = self.backbone(images.tensor)
if "sem_seg" in batched_inputs[0]:
@@ -109,8 +108,13 @@ class SemanticSegmentor(nn.Module):
targets, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
).tensor
else:
- targets = None
+ targets = None'''
+ batched_inputs = batched_inputs.to(self.device)
+ raw_input = batched_inputs
+ features = self.backbone(raw_input)
+ targets = None
results, losses = self.sem_seg_head(features, targets)
+ return results
if self.training:
return losses
@@ -25,6 +25,91 @@ def _as_tensor(x):
return torch.as_tensor(x)
+def bilinear_grid_sample(im, grid, align_corners=False):
+ """Given an input and a flow-field grid, computes the output using input
+ values and pixel locations from grid. Supported only bilinear interpolation
+ method to sample the input pixels.
+
+ Args:
+ im (torch.Tensor): Input feature map, shape (N, C, H, W)
+ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+ align_corners {bool}: If set to True, the extrema (-1 and 1) are
+ considered as referring to the center points of the input's
+ corner pixels. If set to False, they are instead considered as
+ referring to the corner points of the input's corner pixels,
+ making the sampling more resolution agnostic.
+ Returns:
+ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+ """
+ n, c, h, w = im.shape
+ gn, gh, gw, _ = grid.shape
+ assert n == gn
+
+ x = grid[:, :, :, 0]
+ y = grid[:, :, :, 1]
+
+ if align_corners:
+ x = ((x + 1) / 2) * (w - 1)
+ y = ((y + 1) / 2) * (h - 1)
+ else:
+ x = ((x + 1) * w - 1) / 2
+ y = ((y + 1) * h - 1) / 2
+
+ x = x.view(n, -1)
+ y = y.view(n, -1)
+
+ x0 = torch.floor(x).long()
+ y0 = torch.floor(y).long()
+ x1 = x0 + 1
+ y1 = y0 + 1
+
+ wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+ wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+ wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+ wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+ # Apply default for grid_sample function zero padding
+ im_padded = torch.cat([im, torch.zeros(im.size(0), im.size(1), 1, im.size(3))], dim=2)
+ im_padded = torch.cat([im_padded, torch.zeros(im.size(0), im_padded.size(1), im_padded.size(2), 1)], dim=3)
+ im_padded = torch.cat([torch.zeros(im.size(0), im_padded.size(1), im_padded.size(2), 1), im_padded], dim=3)
+ im_padded = torch.cat([torch.zeros(im.size(0), im_padded.size(1), 1, im_padded.size(3)), im_padded], dim=2)
+ padded_h = h + 2
+ padded_w = w + 2
+ # save points positions after padding
+ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+ # Clip coordinates to padded image size
+ x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+ x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+ y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+ y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+ im_padded = im_padded.view(n, c, -1)
+
+ x0_y0_tmp = (x0 + y0 * padded_w).unsqueeze(1).to(dtype=torch.int32)
+ s = torch.zeros(x0_y0_tmp.size(0), c, x0_y0_tmp.size(2))
+ x0_y0 = x0_y0_tmp.expand_as(s).to(dtype=torch.int64)
+ x0_y1_tmp = (x0 + y1 * padded_w).unsqueeze(1).to(dtype=torch.int32)
+ s = torch.zeros(x0_y1_tmp.size(0), c, x0_y1_tmp.size(2))
+ x0_y1 = x0_y1_tmp.expand_as(s).to(dtype=torch.int64)
+ x1_y0_tmp = (x1 + y0 * padded_w).unsqueeze(1).to(dtype=torch.int32)
+ s = torch.zeros(x1_y0_tmp.size(0), c, x1_y0_tmp.size(2))
+ x1_y0 = x1_y0_tmp.expand_as(s).to(dtype=torch.int64)
+ x1_y1_tmp = (x1 + y1 * padded_w).unsqueeze(1).to(dtype=torch.int32)
+ s = torch.zeros(x1_y1_tmp.size(0), c, x1_y1_tmp.size(2))
+ x1_y1 = x1_y1_tmp.expand_as(s).to(dtype=torch.int64)
+
+ Ia = torch.gather(im_padded, 2, x0_y0)
+ Ib = torch.gather(im_padded, 2, x0_y1)
+ Ic = torch.gather(im_padded, 2, x1_y0)
+ Id = torch.gather(im_padded, 2, x1_y1)
+
+ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
def point_sample(input, point_coords, **kwargs):
"""
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
@@ -44,8 +129,11 @@ def point_sample(input, point_coords, **kwargs):
add_dim = False
if point_coords.dim() == 3:
add_dim = True
- point_coords = point_coords.unsqueeze(2)
- output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+ s = torch.zeros(point_coords.size(0), point_coords.size(1), 1, point_coords.size(2))
+ point_coords = point_coords.reshape(point_coords.size(0), point_coords.size(1), 1, point_coords.size(2))
+ #output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+ # copy from mmcv
+ output = bilinear_grid_sample(input, 2.0 * point_coords - 1.0, align_corners=False)
if add_dim:
output = output.squeeze(3)
return output
@@ -147,8 +235,11 @@ def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
num_points = min(H * W, num_points)
point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]
point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)
- point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
- point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
+ point_indices = point_indices.to(torch.int32)
+ x = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
+ y = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
+ point_coords = torch.stack([x, y], dim=-1)
+ point_indices = point_indices.to(torch.int64)
return point_indices, point_coords
@@ -126,7 +126,9 @@ class PointRendSemSegHead(nn.Module):
# put sem seg point predictions to the right places on the upsampled grid.
N, C, H, W = sem_seg_logits.shape
- point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
+ point_indices = point_indices.unsqueeze(1)
+ s = torch.zeros(point_indices.size(0), C, point_indices.size(2))
+ point_indices = point_indices.to(dtype=torch.int32).expand_as(s).to(dtype=torch.int64)
sem_seg_logits = (
sem_seg_logits.reshape(N, C, H * W)
.scatter_(2, point_indices, point_logits)