add BiggerGait
This commit is contained in:
@@ -188,3 +188,93 @@ class Baseline(nn.Module):
|
||||
_, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
# return embed_1, logits, heat_mapt
|
||||
return embed_1, logits
|
||||
|
||||
|
||||
class Baseline_Single(nn.Module):
|
||||
def __init__(self, model_cfg):
|
||||
super(Baseline_Single, self).__init__()
|
||||
self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))
|
||||
self.post_backbone = SetBlockWrapper(Post_ResNet9(**model_cfg['backbone_cfg']))
|
||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||
self.TP = PackSequenceWrapper(torch.max)
|
||||
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
||||
|
||||
def get_backbone(self, backbone_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||
return Backbone(**valid_args)
|
||||
if is_list(backbone_cfg):
|
||||
Backbone = nn.ModuleList([self.get_backbone(cfg)
|
||||
for cfg in backbone_cfg])
|
||||
return Backbone
|
||||
raise ValueError(
|
||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||
|
||||
def pre_forward(self, appearance, *args, **kwargs):
|
||||
outs = self.pre_rgb(appearance, *args, **kwargs) # [n, c, s, h, w]
|
||||
outs = self.post_backbone(outs, *args, **kwargs)
|
||||
return outs
|
||||
|
||||
def forward(self, appearance, seqL, *args, **kwargs):
|
||||
outs = self.pre_rgb(appearance, *args, **kwargs) # [n, c, s, h, w]
|
||||
outs = self.post_backbone(outs, *args, **kwargs)
|
||||
# Temporal Pooling, TP
|
||||
outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
|
||||
# Horizontal Pooling Matching, HPM
|
||||
outs = self.HPP(outs) # [n, c, p]
|
||||
embed_1 = self.FCs(outs) # [n, c, p]
|
||||
_, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
return embed_1, logits
|
||||
|
||||
def test_1(self, appearance, *args, **kwargs):
|
||||
outs = self.pre_rgb(appearance, *args, **kwargs) # [n, c, s, h, w]
|
||||
outs = self.post_backbone(outs, *args, **kwargs)
|
||||
return outs
|
||||
|
||||
def test_2(self, outs, seqL):
|
||||
outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
|
||||
outs = self.HPP(outs) # [n, c, p]
|
||||
embed_1 = self.FCs(outs) # [n, c, p]
|
||||
_, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
return embed_1, logits
|
||||
|
||||
class Baseline_Share(nn.Module):
|
||||
def __init__(self, model_cfg):
|
||||
super(Baseline_Share, self).__init__()
|
||||
self.head_num = model_cfg['head_num']
|
||||
self.num_FPN = model_cfg['total_layer_num'] // model_cfg['group_layer_num']
|
||||
self.real_gait = nn.ModuleList([
|
||||
Baseline_Single(model_cfg) for _ in range(self.head_num)
|
||||
])
|
||||
self.Gait_List = nn.ModuleList([
|
||||
self.real_gait[_ // (self.num_FPN // self.head_num)] for _ in range(self.num_FPN)
|
||||
])
|
||||
|
||||
def forward(self, x, seqL):
|
||||
x = self.test_1(x)
|
||||
embed_list, log_list = self.test_2(x, seqL)
|
||||
return embed_list, log_list
|
||||
|
||||
def test_1(self, x, *args, **kwargs):
|
||||
# x: [n, c, s, h, w]
|
||||
n,c,s,h,w = x.shape
|
||||
x_list = list(torch.chunk(x, self.num_FPN, dim=1))
|
||||
for i in range(self.num_FPN):
|
||||
x_list[i] = self.Gait_List[i].test_1(x_list[i], *args, **kwargs)
|
||||
x = torch.concat(x_list, dim=1)
|
||||
return x
|
||||
|
||||
def test_2(self, x, seqL):
|
||||
# x: [n, c, s, h, w]
|
||||
# embed_1: [n, c, p]
|
||||
x_list = torch.chunk(x, self.num_FPN, dim=1)
|
||||
embed_list = []
|
||||
log_list = []
|
||||
for i in range(self.num_FPN):
|
||||
embed_1, logits = self.Gait_List[i].test_2(x_list[i], seqL)
|
||||
embed_list.append(embed_1)
|
||||
log_list.append(logits)
|
||||
return embed_list, log_list
|
||||
|
||||
Reference in New Issue
Block a user