Update modules.py

This commit is contained in:
Dongyang Jin
2025-08-14 23:01:51 +08:00
committed by GitHub
parent 347d157f9c
commit 285bd71f92
+81
View File
@@ -861,3 +861,84 @@ class BasicBlock3D(nn.Module):
out = self.relu(out) out = self.relu(out)
return out return out
# Modified from https://github.com/autonomousvision/unimatch
class FlowFunc(nn.Module):
def __init__(self, radius=3, padding_mode='zeros'):
super(FlowFunc, self).__init__()
self.radius = radius
self.padding_mode = padding_mode
def coords_grid(self, n, h, w, device=None):
assert device is not None
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [h, w]
stacks = [x, y]
grid = torch.stack(stacks, dim=0).float() # [2, h, w]
grid = grid[None].repeat(n, 1, 1, 1) # [n, 2, h, w]
return grid.to(device)
def generate_window_grid(self, h_min, h_max, w_min, w_max, len_h, len_w, device=None):
assert device is not None
x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w),
torch.linspace(h_min, h_max, len_h)],
)
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [h, w, 2]
return grid.to(device)
def normalize_coords(self, coords, h, w):
# coords: [n*s, h, w, 2]
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
return (coords - c) / c # [-1, 1]
def forward(self, feature0, feature1):
'''
features: [n, c, s, h, w]
'''
n = feature0.size(0)
s = feature1.size(2)
feature0 = rearrange(feature0, 'n c s h w -> (n s) c h w')
feature1 = rearrange(feature1, 'n c s h w -> (n s) c h w')
n_s, c, h, w = feature1.size()
coords_init = self.coords_grid(n_s, h, w, feature1.device) # [n*s, 2, h, w]
coords = coords_init.view(n_s, 2, -1).permute(0, 2, 1) # [n*s, h*w, 2]
local_h = 2 * self.radius + 1
local_w = 2 * self.radius + 1
window_grid = self.generate_window_grid(-self.radius, self.radius, -self.radius, self.radius,
local_h, local_w, device=feature0.device) # [2r+1, 2r+1, 2]
window_grid = window_grid.reshape(-1, 2).repeat(n_s, 1, 1, 1) # [n*s, 1, (2r+1)**2, 2]
sample_coords = coords.unsqueeze(-2) + window_grid # [n*s, h*w, (2r+1)**2, 2]
sample_coords_softmax = sample_coords
# exclude coords that are out of image space
valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [n*s, h*w, (2r+1)**2]
valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [n*s, h*w, (2r+1)**2]
valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
# normalize coordinates to [-1, 1]
sample_coords_norm = self.normalize_coords(sample_coords, h, w) # [-1, 1]
window_feature = F.grid_sample(feature1.contiguous(), sample_coords_norm.contiguous(),
padding_mode=self.padding_mode, align_corners=True
).permute(0, 2, 1, 3).contiguous() # [n*s, h*w, c, (2r+1)**2]
feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(n_s, h * w, 1, c) # [n*s, h*w, 1, c]
corr = torch.matmul(feature0_view, window_feature).view(n_s, h * w, -1) / (c ** 0.5) # [n*s, h*w, (2r+1)**2]
# mask invalid locations
corr[~valid] = float("-inf")
# corr[~valid] = -1e9
prob = F.softmax(corr, -1) # [n*s, h*w, (2r+1)**2]
correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
n_s, h, w, 2).permute(0, 3, 1, 2) # [n*s, 2, h, w]
flow = correspondence - coords_init # [n*s, 2, h, w]
flow = rearrange(flow, '(n s) c h w -> n c s h w', n=n)
correspondence = rearrange(correspondence, '(n s) c h w -> n c s h w', n=n)
return flow