add BiggerGait

bugjudger
2025-10-07 18:15:04 -04:00
parent ae87f04d62
commit 67940f6561
4 changed files with 577 additions and 0 deletions
@@ -188,3 +188,93 @@ class Baseline(nn.Module):
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        # return embed_1, logits, heat_mapt
        return embed_1, logits

class Baseline_Single(nn.Module):

    def __init__(self, model_cfg):
        super(Baseline_Single, self).__init__()
        # Backbone split into a pre stage (RGB stem) and a post stage, both
        # wrapped to operate on [n, c, s, h, w] sequence tensors.
        self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))
        self.post_backbone = SetBlockWrapper(Post_ResNet9(**model_cfg['backbone_cfg']))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)  # temporal pooling
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model."""
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            Backbone = nn.ModuleList([self.get_backbone(cfg)
                                      for cfg in backbone_cfg])
            return Backbone
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def pre_forward(self, appearance, *args, **kwargs):
        # Backbone-only pass (same first stage as forward / test_1).
        outs = self.pre_rgb(appearance, *args, **kwargs)  # [n, c, s, h, w]
        outs = self.post_backbone(outs, *args, **kwargs)
        return outs

    def forward(self, appearance, seqL, *args, **kwargs):
        outs = self.pre_rgb(appearance, *args, **kwargs)  # [n, c, s, h, w]
        outs = self.post_backbone(outs, *args, **kwargs)
        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        outs = self.HPP(outs)  # [n, c, p]
        embed_1 = self.FCs(outs)  # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        return embed_1, logits
    def test_1(self, appearance, *args, **kwargs):
        # Stage 1 of inference: backbone features only, [n, c, s, h, w].
        outs = self.pre_rgb(appearance, *args, **kwargs)
        outs = self.post_backbone(outs, *args, **kwargs)
        return outs

    def test_2(self, outs, seqL):
        # Stage 2 of inference: temporal pooling, HPP, and the two heads.
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        outs = self.HPP(outs)  # [n, c, p]
        embed_1 = self.FCs(outs)  # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        return embed_1, logits
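
# Hedged sketch (illustrative, not part of this commit): the shape flow of
# Baseline_Single.forward, reproduced with plain torch ops. The temporal
# pooling is assumed to be a max over the sequence axis (PackSequenceWrapper
# around torch.max, dim=2), the horizontal pyramid is approximated by
# per-strip max+mean pooling, and bin_num=[16, 8, 4, 2, 1] is a hypothetical
# config value.
#
#   import torch
#   n, c, s, h, w = 2, 64, 30, 16, 11
#   feats = torch.randn(n, c, s, h, w)        # stand-in for backbone output
#   tp = feats.max(dim=2).values              # [n, c, h, w], temporal pooling
#   parts = []
#   for b in [16, 8, 4, 2, 1]:                # pyramid bins over height*width
#       z = tp.view(n, c, b, -1)              # b strips per bin
#       parts.append(z.max(-1).values + z.mean(-1))  # [n, c, b]
#   outs = torch.cat(parts, dim=-1)           # [n, c, p] with p = 31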

class Baseline_Share(nn.Module):

    def __init__(self, model_cfg):
        super(Baseline_Share, self).__init__()
        self.head_num = model_cfg['head_num']
        self.num_FPN = model_cfg['total_layer_num'] // model_cfg['group_layer_num']
        # head_num distinct gait heads; Gait_List maps each of the num_FPN
        # feature groups onto one of them, so consecutive groups share weights.
        self.real_gait = nn.ModuleList([
            Baseline_Single(model_cfg) for _ in range(self.head_num)
        ])
        self.Gait_List = nn.ModuleList([
            self.real_gait[i // (self.num_FPN // self.head_num)]
            for i in range(self.num_FPN)
        ])

    def forward(self, x, seqL):
        x = self.test_1(x)
        embed_list, log_list = self.test_2(x, seqL)
        return embed_list, log_list
    def test_1(self, x, *args, **kwargs):
        # x: [n, c, s, h, w]; split the channels into num_FPN groups, run each
        # group through its (possibly shared) backbone, then re-concatenate.
        n, c, s, h, w = x.shape
        x_list = list(torch.chunk(x, self.num_FPN, dim=1))
        for i in range(self.num_FPN):
            x_list[i] = self.Gait_List[i].test_1(x_list[i], *args, **kwargs)
        x = torch.cat(x_list, dim=1)
        return x

    def test_2(self, x, seqL):
        # x: [n, c, s, h, w]; re-chunk and run each group's pooling and heads.
        # embed_1: [n, c, p]
        x_list = torch.chunk(x, self.num_FPN, dim=1)
        embed_list = []
        log_list = []
        for i in range(self.num_FPN):
            embed_1, logits = self.Gait_List[i].test_2(x_list[i], seqL)
            embed_list.append(embed_1)
            log_list.append(logits)
        return embed_list, log_list
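
A minimal sketch (illustrative, not part of this commit) of how Baseline_Share
routes its num_FPN channel chunks onto head_num shared Baseline_Single heads:
chunk i is handled by head i // (num_FPN // head_num), so consecutive chunks
reuse the same weights. The config numbers below are hypothetical.

head_num = 2
num_FPN = 8   # e.g. total_layer_num=16, group_layer_num=2
for i in range(num_FPN):
    print(f"chunk {i} -> shared head {i // (num_FPN // head_num)}")
# chunks 0-3 -> shared head 0, chunks 4-7 -> shared head 1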