Update modules.py
This commit is contained in:
@@ -861,3 +861,84 @@ class BasicBlock3D(nn.Module):
|
|||||||
out = self.relu(out)
|
out = self.relu(out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from https://github.com/autonomousvision/unimatch
|
||||||
|
class FlowFunc(nn.Module):
|
||||||
|
def __init__(self, radius=3, padding_mode='zeros'):
|
||||||
|
super(FlowFunc, self).__init__()
|
||||||
|
self.radius = radius
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
|
||||||
|
def coords_grid(self, n, h, w, device=None):
|
||||||
|
assert device is not None
|
||||||
|
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [h, w]
|
||||||
|
stacks = [x, y]
|
||||||
|
grid = torch.stack(stacks, dim=0).float() # [2, h, w]
|
||||||
|
grid = grid[None].repeat(n, 1, 1, 1) # [n, 2, h, w]
|
||||||
|
return grid.to(device)
|
||||||
|
|
||||||
|
def generate_window_grid(self, h_min, h_max, w_min, w_max, len_h, len_w, device=None):
|
||||||
|
assert device is not None
|
||||||
|
x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w),
|
||||||
|
torch.linspace(h_min, h_max, len_h)],
|
||||||
|
)
|
||||||
|
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [h, w, 2]
|
||||||
|
return grid.to(device)
|
||||||
|
|
||||||
|
def normalize_coords(self, coords, h, w):
|
||||||
|
# coords: [n*s, h, w, 2]
|
||||||
|
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
|
||||||
|
return (coords - c) / c # [-1, 1]
|
||||||
|
|
||||||
|
def forward(self, feature0, feature1):
|
||||||
|
'''
|
||||||
|
features: [n, c, s, h, w]
|
||||||
|
'''
|
||||||
|
n = feature0.size(0)
|
||||||
|
s = feature1.size(2)
|
||||||
|
feature0 = rearrange(feature0, 'n c s h w -> (n s) c h w')
|
||||||
|
feature1 = rearrange(feature1, 'n c s h w -> (n s) c h w')
|
||||||
|
|
||||||
|
n_s, c, h, w = feature1.size()
|
||||||
|
coords_init = self.coords_grid(n_s, h, w, feature1.device) # [n*s, 2, h, w]
|
||||||
|
coords = coords_init.view(n_s, 2, -1).permute(0, 2, 1) # [n*s, h*w, 2]
|
||||||
|
|
||||||
|
local_h = 2 * self.radius + 1
|
||||||
|
local_w = 2 * self.radius + 1
|
||||||
|
|
||||||
|
window_grid = self.generate_window_grid(-self.radius, self.radius, -self.radius, self.radius,
|
||||||
|
local_h, local_w, device=feature0.device) # [2r+1, 2r+1, 2]
|
||||||
|
window_grid = window_grid.reshape(-1, 2).repeat(n_s, 1, 1, 1) # [n*s, 1, (2r+1)**2, 2]
|
||||||
|
sample_coords = coords.unsqueeze(-2) + window_grid # [n*s, h*w, (2r+1)**2, 2]
|
||||||
|
|
||||||
|
sample_coords_softmax = sample_coords
|
||||||
|
# exclude coords that are out of image space
|
||||||
|
valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [n*s, h*w, (2r+1)**2]
|
||||||
|
valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [n*s, h*w, (2r+1)**2]
|
||||||
|
valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
|
||||||
|
|
||||||
|
# normalize coordinates to [-1, 1]
|
||||||
|
sample_coords_norm = self.normalize_coords(sample_coords, h, w) # [-1, 1]
|
||||||
|
window_feature = F.grid_sample(feature1.contiguous(), sample_coords_norm.contiguous(),
|
||||||
|
padding_mode=self.padding_mode, align_corners=True
|
||||||
|
).permute(0, 2, 1, 3).contiguous() # [n*s, h*w, c, (2r+1)**2]
|
||||||
|
feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(n_s, h * w, 1, c) # [n*s, h*w, 1, c]
|
||||||
|
|
||||||
|
corr = torch.matmul(feature0_view, window_feature).view(n_s, h * w, -1) / (c ** 0.5) # [n*s, h*w, (2r+1)**2]
|
||||||
|
|
||||||
|
# mask invalid locations
|
||||||
|
corr[~valid] = float("-inf")
|
||||||
|
# corr[~valid] = -1e9
|
||||||
|
|
||||||
|
prob = F.softmax(corr, -1) # [n*s, h*w, (2r+1)**2]
|
||||||
|
|
||||||
|
correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
|
||||||
|
n_s, h, w, 2).permute(0, 3, 1, 2) # [n*s, 2, h, w]
|
||||||
|
|
||||||
|
flow = correspondence - coords_init # [n*s, 2, h, w]
|
||||||
|
flow = rearrange(flow, '(n s) c h w -> n c s h w', n=n)
|
||||||
|
correspondence = rearrange(correspondence, '(n s) c h w -> n c s h w', n=n)
|
||||||
|
|
||||||
|
return flow
|
||||||
|
|||||||
Reference in New Issue
Block a user