444 lines
14 KiB
Python
Executable File
444 lines
14 KiB
Python
Executable File
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from easydict import EasyDict as edict
|
|
from sklearn.decomposition import PCA
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Small constant added to denominators in this file (min-max normalization,
# masked mean) so that a zero range / empty mask does not divide by zero.
EPS = 1e-6
|
|
|
|
def nearest_sample2d(im, x, y, return_inbounds=False):
    """Sample ``im`` at pixel coordinates, taking the nearest neighbouring pixel.

    For every query point the four integer pixels surrounding ``(x, y)`` are
    looked up (indices clamped to the image), and the one with the largest
    bilinear weight -- i.e. the closest one -- is returned.

    Args:
        im: ``(B, C, H, W)`` feature map shared by all points, or
            ``(B, N, C, H, W)`` where point ``n`` reads from ``im[:, n]``.
        x: ``(B, N)`` horizontal pixel coordinates (cast to float).
        y: ``(B, N)`` vertical pixel coordinates (cast to float).
        return_inbounds: if true, additionally return a validity mask.

    Returns:
        ``(B, C, N)`` sampled values. With ``return_inbounds`` a tuple
        ``(samples, inbounds)`` where ``inbounds`` is a ``(B, N)`` float mask
        equal to 1 where ``-0.5 < x < W - 0.5`` and ``-0.5 < y < H - 0.5``.
    """
    per_point_maps = len(im.shape) == 5
    if per_point_maps:
        B, N, C, H, W = list(im.shape)
    else:
        B, C, H, W = list(im.shape)
        N = list(x.shape)[1]

    x = x.float()
    y = y.float()

    # Integer corners of the cell containing each point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Clamp only the lookup indices; the weights below use the raw corners.
    x0_clip = torch.clamp(x0, 0, W - 1)
    x1_clip = torch.clamp(x1, 0, W - 1)
    y0_clip = torch.clamp(y0, 0, H - 1)
    y1_clip = torch.clamp(y1, 0, H - 1)

    # Offset of each batch element in the flattened (B*H*W) pixel axis.
    base = torch.arange(0, B, dtype=torch.int64, device=x.device).reshape(B, 1) * (W * H)
    row0 = base + y0_clip * W
    row1 = base + y1_clip * W
    corner_indices = (
        row0 + x0_clip,  # y0, x0
        row0 + x1_clip,  # y0, x1
        row1 + x0_clip,  # y1, x0
        row1 + x1_clip,  # y1, x1
    )

    if per_point_maps:
        # (B*H*W, N, C): pick, for point n, channel vector of map n at its pixel.
        # Paired indexing avoids materialising a (B, N, N, C) intermediate.
        im_flat = im.permute(0, 3, 4, 1, 2).reshape(B * H * W, N, C)
        point_ids = torch.arange(N, device=x.device)
        corners = [im_flat[idx, point_ids] for idx in corner_indices]
    else:
        # (B*H*W, C): move channels last so one flat index selects a pixel.
        im_flat = im.permute(0, 2, 3, 1).reshape(B * H * W, C)
        corners = [im_flat[idx] for idx in corner_indices]
    # each entry of ``corners`` is (B, N, C)

    # Bilinear weights of the four corners, same order as ``corner_indices``.
    wx0 = x1.float() - x
    wx1 = x - x0.float()
    wy0 = y1.float() - y
    wy1 = y - y0.float()
    weights = torch.stack([wx0 * wy0, wx1 * wy0, wx0 * wy1, wx1 * wy1], dim=-1)  # B, N, 4

    # The corner with the largest weight is the nearest pixel.
    nearest = weights.max(dim=-1)[1]  # B, N
    stacked = torch.stack(corners, dim=-1)  # B, N, C, 4
    output = stacked.gather(-1, nearest[..., None, None].repeat(1, 1, C, 1)).squeeze(-1)

    # (B, N, C) -> (B, C, N)
    output = output.view(B, -1, C).permute(0, 2, 1)

    if return_inbounds:
        x_valid = (x > -0.5) & (x < W - 0.5)
        y_valid = (y > -0.5) & (y < H - 0.5)
        inbounds = (x_valid & y_valid).float().reshape(B, N)
        return output, inbounds

    return output  # B, C, N
|
|
|
|
def smart_cat(tensor1, tensor2, dim):
    """Concatenate ``tensor2`` onto ``tensor1`` along ``dim``.

    ``tensor1`` may be ``None`` (e.g. an accumulator that has not been
    initialised yet); in that case ``tensor2`` itself is returned unchanged.
    """
    nothing_accumulated = tensor1 is None
    return tensor2 if nothing_accumulated else torch.cat([tensor1, tensor2], dim=dim)
|
|
|
|
|
|
def normalize_single(d):
    """Min-max rescale a tensor of any shape using its global min and max.

    The denominator is ``EPS + (max - min)`` so a constant tensor maps to
    zeros instead of dividing by zero.
    """
    lowest, highest = torch.min(d), torch.max(d)
    value_range = EPS + (highest - lowest)
    return (d - lowest) / value_range
|
|
|
|
|
|
def normalize(d):
    """Min-max normalize every batch element of ``d`` independently.

    Args:
        d: tensor whose first dimension is the batch; remaining dims are free.

    Returns:
        Tensor of the same shape as ``d`` on the same device as ``d``, in the
        default floating dtype, where each ``out[b]`` is
        ``normalize_single(d[b])``.
    """
    # Allocate straight on d's device. Going through CPU + ``.cuda()`` would
    # land on the *current* CUDA device (not necessarily d's) and would leave
    # the result on CPU for non-CUDA accelerators.
    out = torch.zeros(d.size(), device=d.device)
    for b in range(d.size(0)):
        out[b] = normalize_single(d[b])
    return out
|
|
|
|
|
|
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
    """Build a batched pixel-coordinate grid of size ``B x Y x X``.

    Args:
        B: batch size (the grid is simply repeated B times).
        Y: number of rows; row coordinates run 0 .. Y-1.
        X: number of columns; column coordinates run 0 .. X-1.
        stack: if true return a single ``(B, Y, X, 2)`` tensor in (x, y)
            order, matching the convention of ``F.grid_sample``.
        norm: accepted but not used by this implementation.
        device: device on which the grid is created.

    Returns:
        ``(grid_y, grid_x)``, each ``(B, Y, X)``, or the stacked xy grid.
    """
    target = torch.device(device)

    row_coords = torch.linspace(0.0, Y - 1, Y, device=target)
    col_coords = torch.linspace(0.0, X - 1, X, device=target)

    # Broadcast the 1-D axes to full (B, Y, X) grids.
    grid_y = torch.reshape(row_coords, [1, Y, 1]).repeat(B, 1, X)
    grid_x = torch.reshape(col_coords, [1, 1, X]).repeat(B, Y, 1)

    if not stack:
        return grid_y, grid_x

    # x first, then y (see torch.nn.functional.grid_sample).
    return torch.stack([grid_x, grid_y], dim=-1)
|
|
|
|
|
|
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    """Mean of ``x`` over the entries selected by ``mask``.

    Computes ``sum(x * mask) / (EPS + sum(mask))``, either over the whole
    tensor (``dim=None``) or over ``dim`` (an axis or list of axes).

    ``x`` and ``mask`` are expected to have the same shape: every pair of
    leading dimension sizes is asserted equal, so broadcasting is rejected.
    """
    for size_x, size_mask in zip(x.size(), mask.size()):
        assert size_x == size_mask  # some shape mismatch!

    masked = x * mask

    if dim is not None:
        total = torch.sum(masked, dim=dim, keepdim=keepdim)
        count = torch.sum(mask, dim=dim, keepdim=keepdim)
    else:
        total = torch.sum(masked)
        count = torch.sum(mask)

    # EPS keeps an all-zero mask from dividing by zero.
    return total / (EPS + count)
|
|
|
|
|
|
def bilinear_sample2d(im, x, y, return_inbounds=False):
    """Bilinearly sample ``im`` at pixel coordinates ``(x, y)``.

    The four integer pixels surrounding each query point are looked up
    (indices clamped to the image) and blended with bilinear weights computed
    from the unclamped corners.

    Args:
        im: ``(B, C, H, W)`` feature map shared by all points, or
            ``(B, N, C, H, W)`` where point ``n`` reads from ``im[:, n]``.
        x: ``(B, N)`` horizontal pixel coordinates (cast to float).
        y: ``(B, N)`` vertical pixel coordinates (cast to float).
        return_inbounds: if true, additionally return a validity mask.

    Returns:
        ``(B, C, N)`` sampled values. With ``return_inbounds`` a tuple
        ``(samples, inbounds)`` where ``inbounds`` is a ``(B, N)`` float mask
        equal to 1 where ``-0.5 < x < W - 0.5`` and ``-0.5 < y < H - 0.5``.
    """
    per_point_maps = len(im.shape) == 5
    if per_point_maps:
        B, N, C, H, W = list(im.shape)
    else:
        B, C, H, W = list(im.shape)
        N = list(x.shape)[1]

    x = x.float()
    y = y.float()

    # Integer corners of the cell containing each point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Clamp only the lookup indices; the weights below use the raw corners.
    x0_clip = torch.clamp(x0, 0, W - 1)
    x1_clip = torch.clamp(x1, 0, W - 1)
    y0_clip = torch.clamp(y0, 0, H - 1)
    y1_clip = torch.clamp(y1, 0, H - 1)

    # Offset of each batch element in the flattened (B*H*W) pixel axis.
    base = torch.arange(0, B, dtype=torch.int64, device=x.device).reshape(B, 1) * (W * H)
    row0 = base + y0_clip * W
    row1 = base + y1_clip * W
    idx_y0_x0 = row0 + x0_clip
    idx_y0_x1 = row0 + x1_clip
    idx_y1_x0 = row1 + x0_clip
    idx_y1_x1 = row1 + x1_clip

    if per_point_maps:
        # (B*H*W, N, C): pick, for point n, channel vector of map n at its pixel.
        # Paired indexing avoids materialising a (B, N, N, C) intermediate.
        im_flat = im.permute(0, 3, 4, 1, 2).reshape(B * H * W, N, C)
        point_ids = torch.arange(N, device=x.device)

        def lookup(idx):
            return im_flat[idx, point_ids]

    else:
        # (B*H*W, C): move channels last so one flat index selects a pixel.
        im_flat = im.permute(0, 2, 3, 1).reshape(B * H * W, C)

        def lookup(idx):
            return im_flat[idx]

    # Each lookup is (B, N, C).
    i_y0_x0 = lookup(idx_y0_x0)
    i_y0_x1 = lookup(idx_y0_x1)
    i_y1_x0 = lookup(idx_y1_x0)
    i_y1_x1 = lookup(idx_y1_x1)

    # Bilinear weights, each (B, N, 1) so they broadcast over channels.
    wx0 = (x1.float() - x).unsqueeze(2)
    wx1 = (x - x0.float()).unsqueeze(2)
    wy0 = (y1.float() - y).unsqueeze(2)
    wy1 = (y - y0.float()).unsqueeze(2)

    output = (
        wx0 * wy0 * i_y0_x0
        + wx1 * wy0 * i_y0_x1
        + wx0 * wy1 * i_y1_x0
        + wx1 * wy1 * i_y1_x1
    )

    # (B, N, C) -> (B, C, N)
    output = output.view(B, -1, C).permute(0, 2, 1)

    if return_inbounds:
        x_valid = (x > -0.5) & (x < W - 0.5)
        y_valid = (y > -0.5) & (y < H - 0.5)
        inbounds = (x_valid & y_valid).float().reshape(B, N)
        return output, inbounds

    return output  # B, C, N
|
|
|
|
|
|
def procrustes_analysis(X0, X1, Weight):  # [B,N,3]
    """Estimate the rigid alignment (rotation + centroids) between two point sets.

    Both sets are centred by their mean over ``dim=1``; the rotation is
    obtained from the SVD of ``X0c^T @ X1c`` (in double precision, since
    float loses accuracy), with a sign flip of row 2 when the determinant is
    negative. Scale is not estimated.

    Returns:
        An ``edict`` with ``t0`` and ``t1`` (first entry of each mean) and the
        rotation ``R``. X1 is aligned to X0 as ``(X1 - t1) @ R.t() + t0``.

    NOTE(review): ``Weight`` is never read. The header comment says the
    inputs are ``[B,N,3]`` but ``.t()`` only accepts tensors with at most two
    dimensions -- confirm the intended input shape with the callers.
    """
    # translation: centroids and centred point sets
    centroid0 = X0.mean(dim=1, keepdim=True)
    centroid1 = X1.mean(dim=1, keepdim=True)
    centred0 = X0 - centroid0
    centred1 = X1 - centroid1

    # rotation (use double for SVD, float loses precision)
    cross_cov = (centred0.t() @ centred1).double()
    U, _, V = cross_cov.svd(some=True)
    R = (U @ V.t()).float()
    if R.det() < 0:
        # flip one axis when the determinant is negative
        R[2] *= -1

    return edict(t0=centroid0[0], t1=centroid1[0], R=R)
|
|
|
|
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border", interp_mode="bilinear"):
    r"""Sample a tensor at pixel coordinates via ``F.grid_sample``.

    This is a thin wrapper around :func:`torch.nn.functional.grid_sample`
    that takes coordinates in pixel units instead of the normalised
    :math:`[-1, 1]` range.

    For a 4D :attr:`input` :math:`(B, C, H, W)`, :attr:`coords` is
    :math:`(B, H_o, W_o, 2)` holding points :math:`(x_i, y_i)`.

    For a 5D :attr:`input` :math:`(B, C, T, H, W)`, the points are triplets
    :math:`(t_i, x_i, y_i)`. Note this differs from `grid_sample()`, which
    expects :math:`(x_i, y_i, t_i)`; the reordering is done here.

    With `align_corners=True`, :math:`x` ranges over :math:`[0, W-1]`, where
    0 is the centre of the left-most pixel and :math:`W-1` the centre of the
    right-most pixel. With `align_corners=False`, :math:`x` ranges over
    :math:`[0, W]`, where 0 is the left edge of the left-most pixel and
    :math:`W` the right edge of the right-most pixel. The same holds for
    :math:`y` (:math:`[0, H-1]` / :math:`[0, H]`) and :math:`t`
    (:math:`[0, T-1]` / :math:`[0, T]`).

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.
        interp_mode (str, optional): forwarded to `grid_sample` as `mode`.
            Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled points.
    """
    extents = input.shape[2:]

    assert len(extents) in [2, 3]

    if len(extents) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    # Per-axis factor mapping pixel units to a span of 2 (later shifted to [-1, 1]).
    # ``extents`` is (..., H, W) while coords are (x, y, ...), hence the reversal.
    if align_corners:
        factors = [2 / max(extent - 1, 1) for extent in reversed(extents)]
    else:
        factors = [2 / extent for extent in reversed(extents)]

    scale = torch.tensor(factors, device=coords.device)
    grid = coords * scale - 1

    return F.grid_sample(
        input,
        grid,
        align_corners=align_corners,
        padding_mode=padding_mode,
        mode=interp_mode,
    )
|
|
|
|
|
|
def sample_features4d(input, coords, interp_mode="bilinear"):
    r"""Sample spatial features at a list of points.

    :attr:`input` is a 4D feature tensor :math:`(B, C, H, W)` and
    :attr:`coords` is a :math:`(B, R, 2)` tensor of points
    :math:`(x_i, y_i)`, using the same pixel convention as
    :func:`bilinear_sampler` with `align_corners=True`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.
        interp_mode (str, optional): interpolation mode forwarded to
            :func:`bilinear_sampler`. Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled features, one per point, of shape :math:`(B, R, C)`.
    """
    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2, so the points form an R x 1 "output image"
    grid = coords.unsqueeze(2)

    # B C R 1
    sampled = bilinear_sampler(input, grid, interp_mode=interp_mode)

    # B C R 1 -> B R C
    channels_last = sampled.permute(0, 2, 1, 3)
    flat_width = sampled.shape[1] * sampled.shape[3]
    return channels_last.view(B, -1, flat_width)
|
|
|
|
|
|
def sample_features5d(input, coords, interp_mode="bilinear"):
    r"""Sample spatio-temporal features at a grid of points.

    Works like :func:`sample_features4d`, but :attr:`input` is a 5D tensor
    :math:`(B, T, C, H, W)` and :attr:`coords` is a :math:`(B, R1, R2, 3)`
    tensor of spatio-temporal points :math:`(t_i, x_i, y_i)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.
        interp_mode (str, optional): interpolation mode forwarded to
            :func:`bilinear_sampler`. Defaults to `"bilinear"`.

    Returns:
        Tensor: sampled features of shape :math:`(B, R1, R2, C)`.
    """
    B, _, _, _, _ = input.shape

    # B T C H W -> B C T H W (the layout bilinear_sampler expects)
    volume = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    grid = coords.unsqueeze(3)

    # B C R1 R2 1
    sampled = bilinear_sampler(volume, grid, interp_mode=interp_mode)

    # B C R1 R2 1 -> B R1 R2 C
    n_channels, r1, r2 = sampled.shape[1], sampled.shape[2], sampled.shape[3]
    return sampled.permute(0, 2, 3, 1, 4).view(B, r1, r2, n_channels)
|
|
|
|
def vis_PCA(fmaps, save_dir):
    """
    Save a 3-component PCA visualization of a feature map as an image.

    args:
        fmaps: feature maps 1 C H W (only the first batch entry is used)
        save_dir: output location, passed directly to ``plt.imsave``
    """
    reducer = PCA(n_components=3)

    features = fmaps[0, ...]  # C H W
    lowest = features.min()
    highest = features.max()
    scaled = (features - lowest) / (highest - lowest)

    height, width = features.shape[1:]

    # C H W -> (H*W) C: one row per pixel, one column per channel
    per_pixel = scaled.reshape(scaled.shape[0], -1).permute(1, 0)
    projected = reducer.fit_transform(per_pixel.detach().cpu().numpy())

    # back to an H x W image with the 3 components as colour channels
    image = projected.reshape(height, width, 3)
    image = (image - image.min()) / (image.max() - image.min())

    plt.imsave(save_dir, image)