import torch
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply `matrix` to every row vector stored in `vectors4`.

    Computes `vectors4 @ matrix.T`, which is the row-vector form of
    left-multiplying each vector by `matrix` (MxM applied to NxM, giving NxM).
    """
    return vectors4 @ matrix.T
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Scale each vector along the last dimension to unit length.

    The norm is taken over the last dimension with `keepdim=True` so it
    broadcasts against `vectors`. No epsilon is added to the divisor.
    """
    lengths = vectors.norm(dim=-1, keepdim=True)
    return vectors / lengths
def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of `x` and `y` taken along the last dimension.

    The inputs are multiplied elementwise (with broadcasting) and the
    product is summed over the last dimension, which is dropped.
    """
    return torch.sum(x * y, dim=-1)
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with an axis-aligned cube centered at the origin whose
    corners are at -box_side_length/2 and +box_side_length/2 on every axis.

    Both inputs are detached and flattened to (-1, 3), so their last
    dimension must hold 3 components and the results carry no gradient.

    Returns a pair (tmin, tmax) of entry and exit distances along each ray,
    each shaped [*rays_o.shape[:-1], 1].
    Rays that miss the box get tmin = -1 and tmax = -2.
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    out_shape = (*rays_o.shape[:-1], 1)
    origins = rays_o.detach().reshape(-1, 3)
    directions = rays_d.detach().reshape(-1, 3)

    half_side = box_side_length / 2
    # Row 0 holds the minimum corner, row 1 the maximum corner.
    bounds = torch.tensor(
        [[-1 * half_side] * 3, [1 * half_side] * 3],
        dtype=origins.dtype,
        device=origins.device,
    )

    inv_dir = 1 / directions
    # 1 where the ray travels in the negative direction along an axis: the
    # max-corner plane is then the one crossed first.
    negative = (inv_dir < 0).long()

    def slab(axis):
        # Distances at which the ray crosses the two planes bounding `axis`.
        first_plane = bounds.index_select(0, negative[..., axis])[..., axis]
        second_plane = bounds.index_select(0, 1 - negative[..., axis])[..., axis]
        t_near = (first_plane - origins[..., axis]) * inv_dir[..., axis]
        t_far = (second_plane - origins[..., axis]) * inv_dir[..., axis]
        return t_near, t_far

    hit = torch.ones(origins.shape[:-1], dtype=bool, device=origins.device)
    tmin, tmax = slab(0)
    for axis in (1, 2):
        t_near, t_far = slab(axis)
        # The ray misses when the running interval and this slab's interval
        # do not overlap.
        hit &= ~((tmin > t_far) | (t_near > tmax))
        tmin = torch.max(tmin, t_near)
        tmax = torch.min(tmax, t_far)

    # Sentinel values for rays that never enter the box.
    miss = ~hit
    tmin[miss] = -1
    tmax[miss] = -2
    return tmin.reshape(out_shape), tmax.reshape(out_shape)
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly
    spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    Args:
        start: Tensor of first values; its device is used for the result.
        stop: Tensor of last values, broadcastable against `start`.
        num: Number of samples to generate along the new leading dimension.

    Returns:
        Tensor whose slice [i] is start + i / (num - 1) * (stop - start).
        With num == 1 the single slice equals `start`.
    """
    # Build the interpolation weights in float64 for float64 inputs so they
    # are not rounded to float32 first; everything else keeps float32 weights.
    step_dtype = torch.float64 if start.dtype == torch.float64 else torch.float32
    # Guard the divisor: with num == 1 the plain (num - 1) gives 0 / 0 = NaN.
    denominator = max(num - 1, 1)
    steps = torch.arange(num, dtype=step_dtype, device=start.device) / denominator
    # Append one singleton dimension per dimension of `start` so the weights
    # broadcast over it.
    steps = steps.reshape(-1, *([1] * start.ndim))
    return start[None] + steps * (stop - start)[None]