Files
OpenGait/opengait/modeling/models/gaitgraph1.py
T
Dongyang Jin 2c29afadf3 Support skeleton (#155)
* pose

* pose

* pose

* pose

* 你的提交消息

* pose

* pose

* Delete train1.sh

* pretreatment

* configs

* pose

* reference

* Update gaittr.py

* naming

* naming

* Update transform.py

* update for datasets

* update README

* update name and README

* update

* Update transform.py
2023-09-27 16:20:00 +08:00

76 lines
2.8 KiB
Python

import torch
from ..base_model import BaseModel
from ..backbones.resgcn import ResGCN
from ..modules import Graph
import torch.nn.functional as F
class GaitGraph1(BaseModel):
    """
    GaitGraph1: Gaitgraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
    Paper:    https://ieeexplore.ieee.org/document/9506717
    Github:   https://github.com/tteepe/GaitGraph

    The model wraps a ResGCN backbone that operates on a skeleton graph. The
    backbone output is L2-normalized and returned both as the training feature
    (for ``SupConLoss``) and as the inference embedding.
    """

    def build_network(self, model_cfg):
        """Read the model configuration and build the graph and the backbone.

        Args:
            model_cfg: mapping that must provide the keys ``joint_format``,
                ``input_num``, ``block``, ``input_branch``, ``main_stream``,
                ``num_class``, ``reduction`` and ``tta``.
        """
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        # When true, evaluation also feeds a flipped copy of each sample and
        # averages the two embeddings (see forward()).
        self.tta = model_cfg['tta']

        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        # The graph's A is handed to the backbone as a float32 tensor that
        # does not track gradients.
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)

        ## Network ##
        self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
                             main_stream=self.main_stream, num_class=self.num_class,
                             reduction=self.reduction, block=self.block, graph=self.A)

    def forward(self, inputs):
        """Compute embeddings for a batch of skeleton sequences.

        Args:
            inputs: 5-tuple ``(ipts, labs, type_, view_, seqL)``. Only
                ``ipts[0]`` (the pose tensor) and ``labs`` are used here.

        Returns:
            dict with the keys ``training_feat``, ``visual_summary`` and
            ``inference_feat``.

            * training: ``embed`` stacks the features of the two temporal
              halves of every sample along a new dim 1
              (e.g. ``[4, 2, 128]`` per the original author's note).
            * evaluation with ``tta``: ``embed`` is the mean of the features
              of the original and the flipped input, with a trailing
              singleton dim.
            * evaluation without ``tta``: ``embed`` is the normalized
              backbone output with a trailing singleton dim.
        """
        ipts, labs, _, _, _ = inputs
        x_input = ipts[0]
        # NOTE(review): the permuted tensor is unpacked below as
        # (N, T, V, I, C), which implies the raw input is laid out as
        # (N, C, T, V, I). The original inline comment said "N T C V I" --
        # confirm against the data transform.
        x_input = x_input.permute(0, 2, 3, 4, 1).contiguous()
        N, T, V, I, C = x_input.size()
        pose = x_input  # kept unchanged for the visual summary

        if self.training:
            # Split every sequence into two halves along dim 1 and stack them
            # on the batch dim, so each sample contributes two views
            # (e.g. [8, 60, 17, 1, 3] per the original author's note).
            # NOTE(review): presumably T is expected to be even; with an odd T
            # the halves differ in length and torch.cat raises.
            half = T // 2
            x_input = torch.cat([x_input[:, :half, ...], x_input[:, half:, ...]], dim=0)
        elif self.tta:
            # Test-time augmentation: append a copy flipped along dim 1.
            data_flipped = torch.flip(x_input, dims=[1])
            x_input = torch.cat([x_input, data_flipped], dim=0)

        # Reorder to the layout passed to the backbone: dims (0, 3, 4, 1, 2).
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()

        # resgcn
        x = self.ResGCN(x)
        # L2-normalize along dim 1 (only for GaitGraph1; removed in GaitGraph2).
        x = F.normalize(x, dim=1, p=2)

        if self.training:
            # First N rows belong to the first half, last N to the second.
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        elif self.tta:
            # Average the features of the original and the flipped input.
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.mean(torch.stack([f1, f2]), dim=0)
            embed = embed.unsqueeze(-1)
        else:
            # Plain inference: use the normalized backbone output directly.
            # (Previously this branch read `embed` before it was assigned and
            # raised UnboundLocalError.)
            embed = x.unsqueeze(-1)

        retval = {
            'training_feat': {
                'SupConLoss': {'features': embed, 'labels': labs},  # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N * T, 1, I * V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval