# Provenance: added to opengait/modeling/modules.py by commit
# 285bd71f92f853cc98946ac85c4e694ecac714d1 ("Update modules.py",
# Dongyang Jin, Thu, 14 Aug 2025 23:01:51 +0800).
# Modified from https://github.com/autonomousvision/unimatch
class FlowFunc(nn.Module):
    """Soft local-window correspondence ("flow") between two feature sequences.

    For every spatial location of ``feature0`` the features of ``feature1``
    inside a ``(2 * radius + 1) x (2 * radius + 1)`` window centred on the same
    location are sampled and correlated with the ``feature0`` vector.  A
    softmax over the window turns the correlations into matching
    probabilities, and the returned flow is the probability-weighted mean of
    the window coordinates minus the location's own coordinate.

    Args:
        radius: Half-size of the square search window, in pixels.
        padding_mode: ``padding_mode`` forwarded to ``F.grid_sample`` for
            window positions that fall outside the feature map.
    """

    def __init__(self, radius=3, padding_mode='zeros'):
        super(FlowFunc, self).__init__()
        self.radius = radius
        self.padding_mode = padding_mode

    def coords_grid(self, n, h, w, device=None):
        """Return pixel coordinates as a float tensor of shape ``[n, 2, h, w]``.

        Channel 0 holds the x (column) index and channel 1 the y (row) index.
        ``device`` is mandatory despite its ``None`` default.
        """
        assert device is not None
        # Explicit indexing='ij' keeps the historical behaviour and avoids the
        # deprecation warning torch emits when the argument is omitted.
        # Building directly on `device` avoids a host-to-device copy per call.
        ys, xs = torch.meshgrid(
            torch.arange(h, device=device),
            torch.arange(w, device=device),
            indexing='ij',
        )  # each [h, w]
        grid = torch.stack([xs, ys], dim=0).float()  # [2, h, w]
        return grid[None].repeat(n, 1, 1, 1)  # [n, 2, h, w]

    def generate_window_grid(self, h_min, h_max, w_min, w_max, len_h, len_w, device=None):
        """Return window offsets as a float tensor of shape ``[len_h, len_w, 2]``.

        ``grid[i, j]`` is ``(x_offset_j, y_offset_i)`` with the x offsets
        evenly spaced in ``[w_min, w_max]`` and the y offsets in
        ``[h_min, h_max]``.  ``device`` is mandatory despite its default.
        """
        assert device is not None
        ys, xs = torch.meshgrid(
            torch.linspace(h_min, h_max, len_h, device=device),
            torch.linspace(w_min, w_max, len_w, device=device),
            indexing='ij',
        )  # each [len_h, len_w]
        return torch.stack((xs, ys), dim=-1).float()  # [len_h, len_w, 2]

    def normalize_coords(self, coords, h, w):
        """Map pixel coordinates ``(x, y)`` in the last dim of ``coords`` to [-1, 1].

        Uses the ``align_corners=True`` convention: x == 0 maps to -1 and
        x == w - 1 maps to +1 (likewise for y with ``h``).
        """
        half_extent = torch.tensor(
            [(w - 1) / 2., (h - 1) / 2.],
            dtype=torch.float32, device=coords.device)
        # A size-1 axis has a zero half-extent; dividing by it would produce
        # NaN/inf coordinates.  Any finite value addresses the single pixel
        # under align_corners=True, so divide by 1 on such an axis instead.
        scale = torch.where(half_extent > 0, half_extent,
                            torch.ones_like(half_extent))
        return (coords - half_extent) / scale  # [-1, 1]

    def forward(self, feature0, feature1):
        """Estimate the soft flow from ``feature0`` to ``feature1``.

        Args:
            feature0: Tensor of shape ``[n, c, s, h, w]``.
            feature1: Tensor of the same shape as ``feature0``.

        Returns:
            Tensor of shape ``[n, 2, s, h, w]``; channel 0 is the x
            displacement and channel 1 the y displacement, in pixels.

        Raises:
            ValueError: If the inputs are not 5-D or their shapes differ.
        """
        if feature0.dim() != 5 or feature0.shape != feature1.shape:
            raise ValueError(
                'FlowFunc expects two [n, c, s, h, w] tensors of identical '
                'shape, got {} and {}'.format(
                    tuple(feature0.shape), tuple(feature1.shape)))

        n, c, s, h, w = feature1.shape
        n_s = n * s
        # Fold the sequence axis into the batch: 'n c s h w -> (n s) c h w'.
        feature0 = feature0.permute(0, 2, 1, 3, 4).reshape(n_s, c, h, w)
        feature1 = feature1.permute(0, 2, 1, 3, 4).reshape(n_s, c, h, w)

        coords_init = self.coords_grid(n_s, h, w, feature1.device)  # [n*s, 2, h, w]
        coords = coords_init.view(n_s, 2, -1).permute(0, 2, 1)  # [n*s, h*w, 2]

        side = 2 * self.radius + 1
        window_grid = self.generate_window_grid(
            -self.radius, self.radius, -self.radius, self.radius,
            side, side, device=feature0.device)  # [2r+1, 2r+1, 2]
        # Broadcast over batch and pixels rather than materialising copies.
        window_grid = window_grid.reshape(1, 1, -1, 2)  # [1, 1, (2r+1)**2, 2]
        sample_coords = coords.unsqueeze(-2) + window_grid  # [n*s, h*w, (2r+1)**2, 2]

        # Window positions outside the image must not take part in the softmax.
        xs = sample_coords[..., 0]
        ys = sample_coords[..., 1]
        valid = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)  # [n*s, h*w, (2r+1)**2]

        # grid_sample expects coordinates normalised to [-1, 1].
        sample_coords_norm = self.normalize_coords(sample_coords, h, w)
        window_feature = F.grid_sample(
            feature1.contiguous(), sample_coords_norm.contiguous(),
            padding_mode=self.padding_mode, align_corners=True,
        ).permute(0, 2, 1, 3).contiguous()  # [n*s, h*w, c, (2r+1)**2]
        feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(
            n_s, h * w, 1, c)  # [n*s, h*w, 1, c]

        # Scaled dot-product similarity between each pixel and its window.
        corr = torch.matmul(feature0_view, window_feature).view(
            n_s, h * w, -1) / (c ** 0.5)  # [n*s, h*w, (2r+1)**2]

        # -inf makes the softmax assign exactly zero weight to invalid
        # positions.  The window centre is always valid, so no row is all -inf.
        corr = corr.masked_fill(~valid, float('-inf'))
        prob = F.softmax(corr, -1)  # [n*s, h*w, (2r+1)**2]

        # Expected matching coordinate under the window distribution.
        correspondence = torch.matmul(
            prob.unsqueeze(-2), sample_coords,
        ).squeeze(-2).view(n_s, h, w, 2).permute(0, 3, 1, 2)  # [n*s, 2, h, w]

        flow = correspondence - coords_init  # [n*s, 2, h, w]
        # Unfold the batch again: '(n s) c h w -> n c s h w'.
        return flow.reshape(n, s, 2, h, w).permute(0, 2, 1, 3, 4)