Support skeleton (#155)
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
from ..base_model import BaseModel
|
||||
from ..backbones.resgcn import ResGCN
|
||||
from ..modules import Graph
|
||||
import torch.nn.functional as F
|
||||
|
||||
class GaitGraph1(BaseModel):
|
||||
"""
|
||||
GaitGraph1: Gaitgraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
|
||||
Paper: https://ieeexplore.ieee.org/document/9506717
|
||||
Github: https://github.com/tteepe/GaitGraph
|
||||
"""
|
||||
def build_network(self, model_cfg):
|
||||
|
||||
self.joint_format = model_cfg['joint_format']
|
||||
self.input_num = model_cfg['input_num']
|
||||
self.block = model_cfg['block']
|
||||
self.input_branch = model_cfg['input_branch']
|
||||
self.main_stream = model_cfg['main_stream']
|
||||
self.num_class = model_cfg['num_class']
|
||||
self.reduction = model_cfg['reduction']
|
||||
self.tta = model_cfg['tta']
|
||||
|
||||
## Graph Init ##
|
||||
self.graph = Graph(joint_format=self.joint_format,max_hop=3)
|
||||
self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
|
||||
## Network ##
|
||||
self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
|
||||
main_stream=self.main_stream, num_class=self.num_class,
|
||||
reduction=self.reduction, block=self.block,graph=self.A)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
ipts, labs, type_, view_, seqL = inputs
|
||||
x_input = ipts[0] # N T C V I
|
||||
# x = N, T, C, V, M -> N, C, T, V, M
|
||||
x_input = x_input.permute(0, 2, 3, 4, 1).contiguous()
|
||||
N, T, V, I, C = x_input.size()
|
||||
|
||||
pose = x_input
|
||||
if self.training:
|
||||
x_input = torch.cat([x_input[:,:int(T/2),...],x_input[:,int(T/2):,...]],dim=0) #[8, 60, 17, 1, 3]
|
||||
elif self.tta:
|
||||
data_flipped = torch.flip(x_input,dims=[1])
|
||||
x_input = torch.cat([x_input,data_flipped], dim=0)
|
||||
|
||||
x = x_input.permute(0, 3, 4, 1, 2).contiguous()
|
||||
|
||||
# resgcn
|
||||
x = self.ResGCN(x)
|
||||
x = F.normalize(x, dim=1, p=2) # norm #only for GaitGraph1 # Remove from GaitGraph2
|
||||
|
||||
if self.training:
|
||||
f1, f2 = torch.split(x, [N, N], dim=0)
|
||||
embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) #[4, 2, 128]
|
||||
|
||||
elif self.tta:
|
||||
f1, f2 = torch.split(x, [N, N], dim=0)
|
||||
embed = torch.mean(torch.stack([f1, f2]), dim=0)
|
||||
embed = embed.unsqueeze(-1)
|
||||
else:
|
||||
embed = embed.unsqueeze(-1)
|
||||
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'SupConLoss': {'features': embed , 'labels': labs}, # loss
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/pose': pose.view(N*T, 1, I*V, C).contiguous() # visualization
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': embed # for metric
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user