2c29afadf3
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
111 lines
3.8 KiB
Python
111 lines
3.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from ..base_model import BaseModel
|
|
from ..backbones.resgcn import ResGCN
|
|
from ..modules import Graph
|
|
import numpy as np
|
|
|
|
|
|
class GaitGraph2(BaseModel):
    """
    GaitGraph2: Towards a Deeper Understanding of Skeleton-based Gait Recognition
    Paper: https://openaccess.thecvf.com/content/CVPR2022W/Biometrics/papers/Teepe_Towards_a_Deeper_Understanding_of_Skeleton-Based_Gait_Recognition_CVPRW_2022_paper
    Github: https://github.com/tteepe/GaitGraph2
    """

    def build_network(self, model_cfg):
        """Build the skeleton graph and the ResGCN backbone from ``model_cfg``."""
        # Mirror every required config entry onto the model instance.
        for key in ('joint_format', 'input_num', 'block', 'input_branch',
                    'main_stream', 'num_class', 'reduction', 'tta'):
            setattr(self, key, model_cfg[key])

        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        # Fixed (non-learnable) adjacency tensor handed to the backbone.
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)

        ## Network ##
        self.ResGCN = ResGCN(
            input_num=self.input_num,
            input_branch=self.input_branch,
            main_stream=self.main_stream,
            num_class=self.num_class,
            reduction=self.reduction,
            block=self.block,
            graph=self.A,
        )

    def forward(self, inputs):
        """Run the backbone on a batch of pose sequences.

        At evaluation time with ``self.tta`` enabled, the batch is tripled by
        appending a time-reversed and a left/right-mirrored copy of each
        sequence; the three resulting embeddings are concatenated per sample.
        """
        ipts, labs, _type, _view, seqL = inputs

        x_input = ipts[0]
        # assumes layout (batch, time, joints, input-branches, channels) — matches the permute below
        N, T, V, I, C = x_input.size()
        pose = x_input
        flip_idx = self.graph.flip_idx

        use_tta = (not self.training) and self.tta
        if use_tta:
            multi_input = MultiInput(self.graph.connect_joint, self.graph.center)
            # Augmentations are rebuilt from branch 0 (raw joints, first 3 channels).
            reversed_seqs = torch.stack(
                [multi_input(x_input[n, :, :, 0, :3].flip(0)) for n in range(N)], 0)
            mirrored_seqs = torch.stack(
                [multi_input(x_input[n, :, flip_idx, 0, :3]) for n in range(N)], 0)
            x_input = torch.cat([x_input, reversed_seqs, mirrored_seqs], dim=0)

        # (N, T, V, I, C) -> (N, I, C, T, V) for the GCN backbone.
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()

        # resgcn
        x = self.ResGCN(x)

        if use_tta:
            # Fold the three augmented views back into one feature per sample.
            f1, f2, f3 = torch.split(x, [N, N, N], dim=0)
            x = torch.cat((f1, f2, f3), dim=1)

        embed = torch.unsqueeze(x, -1)

        retval = {
            'training_feat': {
                'SupConLoss': {'features': x, 'labels': labs},  # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, 1, I*V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval
|
|
|
|
class MultiInput:
|
|
def __init__(self, connect_joint, center):
|
|
self.connect_joint = connect_joint
|
|
self.center = center
|
|
|
|
def __call__(self, data):
|
|
|
|
# T, V, C -> T, V, I=3, C + 2
|
|
T, V, C = data.shape
|
|
x_new = torch.zeros((T, V, 3, C + 2), device=data.device)
|
|
|
|
# Joints
|
|
x = data
|
|
x_new[:, :, 0, :C] = x
|
|
for i in range(V):
|
|
x_new[:, i, 0, C:] = x[:, i, :2] - x[:, self.center, :2]
|
|
|
|
# Velocity
|
|
for i in range(T - 2):
|
|
x_new[i, :, 1, :2] = x[i + 1, :, :2] - x[i, :, :2]
|
|
x_new[i, :, 1, 3:] = x[i + 2, :, :2] - x[i, :, :2]
|
|
x_new[:, :, 1, 3] = x[:, :, 2]
|
|
|
|
# Bones
|
|
for i in range(V):
|
|
x_new[:, i, 2, :2] = x[:, i, :2] - x[:, self.connect_joint[i], :2]
|
|
bone_length = 0
|
|
for i in range(C - 1):
|
|
bone_length += torch.pow(x_new[:, :, 2, i], 2)
|
|
bone_length = torch.sqrt(bone_length) + 0.0001
|
|
for i in range(C - 1):
|
|
x_new[:, :, 2, C+i] = torch.acos(x_new[:, :, 2, i] / bone_length)
|
|
x_new[:, :, 2, 3] = x[:, :, 2]
|
|
|
|
data = x_new
|
|
return data
|
|
|