add gaitedge training code

This commit is contained in:
darkliang
2022-07-17 13:47:50 +08:00
parent 4205c5f283
commit b183455eb8
17 changed files with 814 additions and 11 deletions
+56
View File
@@ -3,6 +3,7 @@ import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign
class HorizontalPoolingPyramid():
@@ -182,6 +183,61 @@ class BasicConv3d(nn.Module):
return outs
class SilhouetteCropAndResize(nn.Module):
def __init__(self, H=64, W=44, eps=1, **kwargs):
super(SilhouetteCropAndResize, self).__init__()
self.H, self.W, self.eps = H, W, eps
self.Pad = nn.ZeroPad2d((int(self.W / 2), int(self.W / 2), 0, 0))
self.RoiPool = RoIAlign((self.H, self.W), 1, sampling_ratio=-1)
def forward(self, feature_map, binary_mask, w_h_ratio):
"""
In sils: [n, c, h, w]
w_h_ratio: [n, 1]
Out aligned_sils: [n, c, H, W]
"""
n, c, h, w = feature_map.size()
# w_h_ratio = w_h_ratio.repeat(1, 1) # [n, 1]
w_h_ratio = w_h_ratio.view(-1, 1) # [n, 1]
h_sum = binary_mask.sum(-1) # [n, c, h]
_ = (h_sum >= self.eps).float().cumsum(axis=-1) # [n, c, h]
h_top = (_ == 0).float().sum(-1) # [n, c]
h_bot = (_ != torch.max(_, dim=-1, keepdim=True)
[0]).float().sum(-1) + 1. # [n, c]
w_sum = binary_mask.sum(-2) # [n, c, w]
w_cumsum = w_sum.cumsum(axis=-1) # [n, c, w]
w_h_sum = w_sum.sum(-1).unsqueeze(-1) # [n, c, 1]
w_center = (w_cumsum < w_h_sum / 2.).float().sum(-1) # [n, c]
p1 = self.W - self.H * w_h_ratio
p1 = p1 / 2.
p1 = torch.clamp(p1, min=0) # [n, c]
t_w = w_h_ratio * self.H / w
p2 = p1 / t_w # [n, c]
height = h_bot - h_top # [n, c]
width = height * w / h # [n, c]
width_p = int(self.W / 2)
feature_map = self.Pad(feature_map)
w_center = w_center + width_p # [n, c]
w_left = w_center - width / 2 - p2 # [n, c]
w_right = w_center + width / 2 + p2 # [n, c]
w_left = torch.clamp(w_left, min=0., max=w+2*width_p)
w_right = torch.clamp(w_right, min=0., max=w+2*width_p)
boxes = torch.cat([w_left, h_top, w_right, h_bot], dim=-1)
# index of bbox in batch
box_index = torch.arange(n, device=feature_map.device)
rois = torch.cat([box_index.view(-1, 1), boxes], -1)
crops = self.RoiPool(feature_map, rois) # [n, c, H, W]
return crops
def RmBN2dAffine(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):