2c29afadf3
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
import torch
|
|
from ..base_model import BaseModel
|
|
from ..backbones.resgcn import ResGCN
|
|
from ..modules import Graph
|
|
import torch.nn.functional as F
|
|
|
|
class GaitGraph1(BaseModel):
|
|
"""
|
|
GaitGraph1: Gaitgraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
|
|
Paper: https://ieeexplore.ieee.org/document/9506717
|
|
Github: https://github.com/tteepe/GaitGraph
|
|
"""
|
|
def build_network(self, model_cfg):
|
|
|
|
self.joint_format = model_cfg['joint_format']
|
|
self.input_num = model_cfg['input_num']
|
|
self.block = model_cfg['block']
|
|
self.input_branch = model_cfg['input_branch']
|
|
self.main_stream = model_cfg['main_stream']
|
|
self.num_class = model_cfg['num_class']
|
|
self.reduction = model_cfg['reduction']
|
|
self.tta = model_cfg['tta']
|
|
|
|
## Graph Init ##
|
|
self.graph = Graph(joint_format=self.joint_format,max_hop=3)
|
|
self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
|
|
## Network ##
|
|
self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
|
|
main_stream=self.main_stream, num_class=self.num_class,
|
|
reduction=self.reduction, block=self.block,graph=self.A)
|
|
|
|
def forward(self, inputs):
|
|
|
|
ipts, labs, type_, view_, seqL = inputs
|
|
x_input = ipts[0] # N T C V I
|
|
# x = N, T, C, V, M -> N, C, T, V, M
|
|
x_input = x_input.permute(0, 2, 3, 4, 1).contiguous()
|
|
N, T, V, I, C = x_input.size()
|
|
|
|
pose = x_input
|
|
if self.training:
|
|
x_input = torch.cat([x_input[:,:int(T/2),...],x_input[:,int(T/2):,...]],dim=0) #[8, 60, 17, 1, 3]
|
|
elif self.tta:
|
|
data_flipped = torch.flip(x_input,dims=[1])
|
|
x_input = torch.cat([x_input,data_flipped], dim=0)
|
|
|
|
x = x_input.permute(0, 3, 4, 1, 2).contiguous()
|
|
|
|
# resgcn
|
|
x = self.ResGCN(x)
|
|
x = F.normalize(x, dim=1, p=2) # norm #only for GaitGraph1 # Remove from GaitGraph2
|
|
|
|
if self.training:
|
|
f1, f2 = torch.split(x, [N, N], dim=0)
|
|
embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) #[4, 2, 128]
|
|
|
|
elif self.tta:
|
|
f1, f2 = torch.split(x, [N, N], dim=0)
|
|
embed = torch.mean(torch.stack([f1, f2]), dim=0)
|
|
embed = embed.unsqueeze(-1)
|
|
else:
|
|
embed = embed.unsqueeze(-1)
|
|
|
|
retval = {
|
|
'training_feat': {
|
|
'SupConLoss': {'features': embed , 'labels': labs}, # loss
|
|
},
|
|
'visual_summary': {
|
|
'image/pose': pose.view(N*T, 1, I*V, C).contiguous() # visualization
|
|
},
|
|
'inference_feat': {
|
|
'embeddings': embed # for metric
|
|
}
|
|
}
|
|
return retval
|