add gaitedge training code
This commit is contained in:
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils import clones, is_list_or_tuple
|
||||
from torchvision.ops import RoIAlign
|
||||
|
||||
|
||||
class HorizontalPoolingPyramid():
|
||||
@@ -182,6 +183,61 @@ class BasicConv3d(nn.Module):
|
||||
return outs
|
||||
|
||||
|
||||
class SilhouetteCropAndResize(nn.Module):
|
||||
def __init__(self, H=64, W=44, eps=1, **kwargs):
|
||||
super(SilhouetteCropAndResize, self).__init__()
|
||||
self.H, self.W, self.eps = H, W, eps
|
||||
self.Pad = nn.ZeroPad2d((int(self.W / 2), int(self.W / 2), 0, 0))
|
||||
self.RoiPool = RoIAlign((self.H, self.W), 1, sampling_ratio=-1)
|
||||
|
||||
def forward(self, feature_map, binary_mask, w_h_ratio):
|
||||
"""
|
||||
In sils: [n, c, h, w]
|
||||
w_h_ratio: [n, 1]
|
||||
Out aligned_sils: [n, c, H, W]
|
||||
"""
|
||||
n, c, h, w = feature_map.size()
|
||||
# w_h_ratio = w_h_ratio.repeat(1, 1) # [n, 1]
|
||||
w_h_ratio = w_h_ratio.view(-1, 1) # [n, 1]
|
||||
|
||||
h_sum = binary_mask.sum(-1) # [n, c, h]
|
||||
_ = (h_sum >= self.eps).float().cumsum(axis=-1) # [n, c, h]
|
||||
h_top = (_ == 0).float().sum(-1) # [n, c]
|
||||
h_bot = (_ != torch.max(_, dim=-1, keepdim=True)
|
||||
[0]).float().sum(-1) + 1. # [n, c]
|
||||
|
||||
w_sum = binary_mask.sum(-2) # [n, c, w]
|
||||
w_cumsum = w_sum.cumsum(axis=-1) # [n, c, w]
|
||||
w_h_sum = w_sum.sum(-1).unsqueeze(-1) # [n, c, 1]
|
||||
w_center = (w_cumsum < w_h_sum / 2.).float().sum(-1) # [n, c]
|
||||
|
||||
p1 = self.W - self.H * w_h_ratio
|
||||
p1 = p1 / 2.
|
||||
p1 = torch.clamp(p1, min=0) # [n, c]
|
||||
t_w = w_h_ratio * self.H / w
|
||||
p2 = p1 / t_w # [n, c]
|
||||
|
||||
height = h_bot - h_top # [n, c]
|
||||
width = height * w / h # [n, c]
|
||||
width_p = int(self.W / 2)
|
||||
|
||||
feature_map = self.Pad(feature_map)
|
||||
w_center = w_center + width_p # [n, c]
|
||||
|
||||
w_left = w_center - width / 2 - p2 # [n, c]
|
||||
w_right = w_center + width / 2 + p2 # [n, c]
|
||||
|
||||
w_left = torch.clamp(w_left, min=0., max=w+2*width_p)
|
||||
w_right = torch.clamp(w_right, min=0., max=w+2*width_p)
|
||||
|
||||
boxes = torch.cat([w_left, h_top, w_right, h_bot], dim=-1)
|
||||
# index of bbox in batch
|
||||
box_index = torch.arange(n, device=feature_map.device)
|
||||
rois = torch.cat([box_index.view(-1, 1), boxes], -1)
|
||||
crops = self.RoiPool(feature_map, rois) # [n, c, H, W]
|
||||
return crops
|
||||
|
||||
|
||||
def RmBN2dAffine(model):
|
||||
for m in model.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
|
||||
Reference in New Issue
Block a user