Files
Dongyang Jin 2c29afadf3 Support skeleton (#155)
* pose

* pose

* pose

* pose

* 你的提交消息

* pose

* pose

* Delete train1.sh

* pretreatment

* configs

* pose

* reference

* Update gaittr.py

* naming

* naming

* Update transform.py

* update for datasets

* update README

* update name and README

* update

* Update transform.py
2023-09-27 16:20:00 +08:00

111 lines
3.8 KiB
Python

import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..backbones.resgcn import ResGCN
from ..modules import Graph
import numpy as np
class GaitGraph2(BaseModel):
    """
    GaitGraph2: Towards a Deeper Understanding of Skeleton-based Gait Recognition
    Paper: https://openaccess.thecvf.com/content/CVPR2022W/Biometrics/papers/Teepe_Towards_a_Deeper_Understanding_of_Skeleton-Based_Gait_Recognition_CVPRW_2022_paper
    Github: https://github.com/tteepe/GaitGraph2

    A ResGCN backbone over a skeleton graph. When ``tta`` is enabled in the
    model config, evaluation-mode forwards also embed a time-reversed copy and
    a joint-reordered (``graph.flip_idx``) copy of every sequence, and the
    three feature vectors are concatenated.
    """

    def build_network(self, model_cfg):
        """Read hyper-parameters from ``model_cfg`` and build the graph and backbone.

        Args:
            model_cfg: mapping that must provide ``joint_format``, ``input_num``,
                ``block``, ``input_branch``, ``main_stream``, ``num_class``,
                ``reduction`` and ``tta``.
        """
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        self.tta = model_cfg['tta']

        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        # Adjacency from the graph, handed to the backbone as a float32 tensor.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # Module.to()/cuda() will not move this copy; presumably ResGCN keeps
        # its own device-managed copy — confirm.
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)

        ## Network ##
        self.ResGCN = ResGCN(
            input_num=self.input_num,
            input_branch=self.input_branch,
            main_stream=self.main_stream,
            num_class=self.num_class,
            reduction=self.reduction,
            block=self.block,
            graph=self.A,
        )

    def forward(self, inputs):
        """Embed a batch of skeleton sequences.

        Args:
            inputs: tuple ``(ipts, labs, type_, view_, seqL)``; only ``ipts[0]``
                and ``labs`` are used. ``ipts[0]`` is a 5-D tensor unpacked
                as (N, T, V, I, C).

        Returns:
            dict with
            - ``training_feat``: ``{'SupConLoss': {'features', 'labels'}}``,
            - ``visual_summary``: the input reshaped to (N*T, 1, I*V, C),
            - ``inference_feat``: ``{'embeddings': features.unsqueeze(-1)}``.
            With TTA active (eval mode and ``self.tta``), the features of the
            original, time-reversed and joint-reordered copies are
            concatenated along dim 1.
        """
        ipts, labs, _, _, _ = inputs
        x_input = ipts[0]
        N, T, V, I, C = x_input.size()
        pose = x_input
        use_tta = not self.training and self.tta

        if use_tta:
            # Batch the two augmented copies behind the originals: (3N, T, V, I, C).
            x_input = torch.cat([x_input, *self._tta_views(x_input)], dim=0)

        # (N, T, V, I, C) -> (N, I, C, T, V) for the backbone.
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()
        x = self.ResGCN(x)

        if use_tta:
            # Undo the batching: put the three views' features side by side.
            f1, f2, f3 = torch.split(x, [N, N, N], dim=0)
            x = torch.cat((f1, f2, f3), dim=1)

        embed = torch.unsqueeze(x, -1)
        retval = {
            'training_feat': {
                'SupConLoss': {'features': x, 'labels': labs},  # loss
            },
            'visual_summary': {
                # reshape (not view) so a non-contiguous input does not raise.
                'image/pose': pose.reshape(N * T, 1, I * V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval

    def _tta_views(self, x_input):
        """Build the two test-time-augmented copies of a batch.

        Each sample's branch 0 / first three channels (``[:, :, 0, :3]``) is
        re-expanded through ``MultiInput``: once with the time axis reversed,
        and once with joints reordered by ``graph.flip_idx`` (presumably a
        left/right swap — confirm against Graph).

        Returns:
            (time_reversed, joint_flipped), each stacked along a new dim 0.
            They are concatenated with ``x_input`` in ``forward``, so
            ``x_input`` must have the same (I, C) as MultiInput's output.
        """
        multi_input = MultiInput(self.graph.connect_joint, self.graph.center)
        flip_idx = self.graph.flip_idx
        time_reversed = torch.stack(
            [multi_input(sample[:, :, 0, :3].flip(0)) for sample in x_input], 0)
        joint_flipped = torch.stack(
            [multi_input(sample[:, flip_idx, 0, :3]) for sample in x_input], 0)
        return time_reversed, joint_flipped
class MultiInput:
    """Expand one skeleton sequence into joint / velocity / bone branches.

    Maps a (T, V, 3) tensor to (T, V, 3, 5), where dim 2 selects the branch:

    * 0 - joints: the three input channels, then x/y relative to ``center``.
    * 1 - velocity: one-frame x/y difference in channels 0-1, input channel 2
      in channel 3, two-frame y difference in channel 4. The differences are
      zero for the last two frames.
    * 2 - bones: x/y offset to ``connect_joint[v]`` in channels 0-1, input
      channel 2 in channel 3, ``acos(dy / (length + 1e-4))`` in channel 4.

    Channel 2 of branches 1 and 2 is always zero (see NOTE in ``__call__``).
    """

    def __init__(self, connect_joint, center):
        # connect_joint[v]: the joint that joint v's bone offset is measured
        # against; indexed for every v in range(V).
        self.connect_joint = connect_joint
        # Index of the joint used as origin for the relative joint coordinates.
        self.center = center

    def __call__(self, data):
        """Build the multi-branch tensor for one sequence.

        Args:
            data: tensor of shape (T, V, C) with C == 3. Channels 0-1 are used
                as 2-D coordinates; channel 2 is copied through unchanged
                (presumably keypoint confidence — confirm with the data source).

        Returns:
            Tensor of shape (T, V, 3, C + 2) with ``data``'s dtype and device.

        Raises:
            ValueError: if C != 3; the channel layout is hard-wired for it.
        """
        T, V, C = data.shape
        if C != 3:
            raise ValueError(f"MultiInput expects 3 input channels, got {C}")
        x_new = torch.zeros((T, V, 3, C + 2), dtype=data.dtype, device=data.device)
        xy = data[:, :, :2]

        # Joints: raw channels, then coordinates relative to the center joint.
        x_new[:, :, 0, :C] = data
        x_new[:, :, 0, C:] = xy - xy[:, self.center].unsqueeze(1)

        # Velocity: 1- and 2-frame forward differences for every frame that
        # has a frame two steps ahead; the last two frames stay zero.
        steps = max(T - 2, 0)
        x_new[:steps, :, 1, :2] = xy[1:steps + 1] - xy[:steps]
        x_new[:steps, :, 1, 3:] = xy[2:steps + 2] - xy[:steps]
        # NOTE(review): this overwrites channel 3, which already holds the
        # x-component of the 2-frame difference, and leaves channel 2 at zero;
        # channel 2 looks like the intended target (as in branch 0). Kept as-is
        # because forward() concatenates these tensors with ipts[0] and runs
        # them through the same ResGCN, so the layout must match whatever
        # produced ipts[0] — confirm against that transform before changing.
        x_new[:, :, 1, 3] = data[:, :, 2]

        # Bones: offset from each joint to its connected joint, then the angle
        # of that offset against each axis.
        parent = torch.as_tensor(
            [int(self.connect_joint[v]) for v in range(V)],
            dtype=torch.long, device=data.device)
        bone = xy - xy[:, parent]
        x_new[:, :, 2, :2] = bone
        # The +1e-4 keeps the ratio finite for zero-length bones.
        bone_length = torch.sqrt(bone[..., 0] ** 2 + bone[..., 1] ** 2) + 0.0001
        x_new[:, :, 2, C:] = torch.acos(bone / bone_length.unsqueeze(-1))
        # NOTE(review): same channel-3 overwrite as the velocity branch above
        # (replaces the x-angle); kept for layout compatibility.
        x_new[:, :, 2, 3] = data[:, :, 2]
        return x_new