Support skeleton (#155)
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
This commit is contained in:
@@ -253,3 +253,443 @@ def RmBN2dAffine(model):
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.requires_grad = False
|
||||
m.bias.requires_grad = False
|
||||
|
||||
|
||||
'''
|
||||
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/components/units
|
||||
'''
|
||||
|
||||
class Graph():
    """
    Skeleton graph: turns a joint layout into stacked, normalised adjacency
    matrices, one matrix per kept hop distance.

    # Thanks to YAN Sijie for the released code on Github (https://github.com/yysijie/st-gcn)

    Args:
        joint_format: one of 'coco', 'coco-no-head', 'alphapose', 'openpose'.
        max_hop: largest hop distance that is still assigned a matrix.
        dilation: step between the hop distances that are kept.

    Attributes:
        num_node, edge, connect_joint, parts: the values returned by
            ``_get_edge``.
        center: index set per joint format in ``_get_edge``.
        flip_idx: index permutation; only set for 'coco', 'alphapose' and
            'openpose'.
        A: array of shape (number of kept hops, num_node, num_node).

    Raises:
        ValueError: if ``joint_format`` is not one of the supported names.
    """

    def __init__(self, joint_format='coco', max_hop=2, dilation=1):
        self.joint_format = joint_format
        self.max_hop = max_hop
        self.dilation = dilation

        # get edges
        self.num_node, self.edge, self.connect_joint, self.parts = self._get_edge()

        # get adjacency matrix
        self.A = self._get_adjacency()

    def __str__(self):
        # __str__ must return a str; returning the ndarray itself raises TypeError.
        return str(self.A)

    def _get_edge(self):
        """Return (num_node, edge, connect_joint, parts) for ``self.joint_format``.

        ``edge`` is the list of self links followed by the neighbour links.
        Also sets ``self.center`` and, where defined, ``self.flip_idx``.
        """
        if self.joint_format == 'coco':
            # keypoints = {
            #     0: "nose",
            #     1: "left_eye",
            #     2: "right_eye",
            #     3: "left_ear",
            #     4: "right_ear",
            #     5: "left_shoulder",
            #     6: "right_shoulder",
            #     7: "left_elbow",
            #     8: "right_elbow",
            #     9: "left_wrist",
            #     10: "right_wrist",
            #     11: "left_hip",
            #     12: "right_hip",
            #     13: "left_knee",
            #     14: "right_knee",
            #     15: "left_ankle",
            #     16: "right_ankle"
            # }
            num_node = 17
            neighbor_link = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6),
                             (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12), (11, 12),
                             (11, 13), (13, 15), (12, 14), (14, 16)]
            self.center = 0
            self.flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
            connect_joint = np.array([5, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14])
            parts = [
                np.array([5, 7, 9]),        # left_arm
                np.array([6, 8, 10]),       # right_arm
                np.array([11, 13, 15]),     # left_leg
                np.array([12, 14, 16]),     # right_leg
                np.array([0, 1, 2, 3, 4]),  # head
            ]

        elif self.joint_format == 'coco-no-head':
            num_node = 12
            neighbor_link = [(0, 1),
                             (0, 2), (2, 4), (1, 3), (3, 5), (0, 6), (1, 7), (6, 7),
                             (6, 8), (8, 10), (7, 9), (9, 11)]
            self.center = 0
            connect_joint = np.array([3, 1, 0, 2, 4, 0, 6, 8, 10, 7, 9, 11])
            parts = [
                np.array([0, 2, 4]),   # left_arm
                np.array([1, 3, 5]),   # right_arm
                np.array([6, 8, 10]),  # left_leg
                np.array([7, 9, 11]),  # right_leg
            ]

        elif self.joint_format == 'alphapose' or self.joint_format == 'openpose':
            num_node = 18
            neighbor_link = [(0, 1), (0, 14), (0, 15), (14, 16), (15, 17),
                             (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (6, 7),
                             (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13)]
            self.center = 1
            self.flip_idx = [0, 1, 5, 6, 7, 2, 3, 4, 11, 12, 13, 8, 9, 10, 15, 14, 17, 16]
            connect_joint = np.array([1, 1, 1, 2, 3, 1, 5, 6, 2, 8, 9, 5, 11, 12, 0, 0, 14, 15])
            parts = [
                np.array([5, 6, 7]),                # left_arm
                np.array([2, 3, 4]),                # right_arm
                np.array([11, 12, 13]),             # left_leg
                np.array([8, 9, 10]),               # right_leg
                np.array([0, 1, 14, 15, 16, 17]),   # head
            ]

        else:
            # The instance stores the layout name as `joint_format`; there is
            # no `dataset` attribute, so report the name that was passed in.
            message = 'Error: Do NOT exist this dataset: {}!'.format(self.joint_format)
            logging.info('')
            logging.error(message)
            raise ValueError(message)

        self_link = [(i, i) for i in range(num_node)]
        edge = self_link + neighbor_link
        return num_node, edge, connect_joint, parts

    def _get_hop_distance(self):
        """Return a (num_node, num_node) array of hop distances.

        Entries farther apart than ``max_hop`` stay at ``np.inf``.
        """
        A = np.zeros((self.num_node, self.num_node))
        for i, j in self.edge:
            A[j, i] = 1
            A[i, j] = 1
        hop_dis = np.zeros((self.num_node, self.num_node)) + np.inf
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(self.max_hop + 1)]
        arrive_mat = (np.stack(transfer_mat) > 0)
        # Walk from the largest hop down so the smallest reachable hop is
        # the value that remains in each cell.
        for d in range(self.max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def _get_adjacency(self):
        """Build the per-hop adjacency stack from the normalised adjacency."""
        hop_dis = self._get_hop_distance()
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1
        normalize_adjacency = self._normalize_digraph(adjacency)
        A = np.zeros((len(valid_hop), self.num_node, self.num_node))
        for i, hop in enumerate(valid_hop):
            # Each slice keeps only the entries at exactly this hop distance.
            A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
        return A

    def _normalize_digraph(self, A):
        """Scale every column of ``A`` by the inverse of its column sum.

        Columns whose sum is zero are left as zeros.
        """
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i] ** (-1)
        AD = np.dot(A, Dn)
        return AD
|
||||
|
||||
|
||||
class TemporalBasicBlock(nn.Module):
    """
    TemporalConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    A (temporal_window_size x 1) convolution followed by batch norm, with an
    optional shortcut, and a ReLU applied to the sum.

    Args:
        channels: number of input and output channels.
        temporal_window_size: kernel extent along dim 2 of the input.
        stride: stride along dim 2 of the input.
        residual: if False the shortcut contributes 0; otherwise it is the
            identity (stride == 1) or a strided 1x1 conv + batch norm.
        reduction, get_res, tcn_stride: accepted but not used by this block.
    """

    def __init__(self, channels, temporal_window_size, stride=1, residual=False,reduction=0,get_res=False,tcn_stride=False):
        super().__init__()

        pad = ((temporal_window_size - 1) // 2, 0)

        if residual and stride != 1:
            # The shortcut must shrink dim 2 the same way the main conv does.
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (stride, 1)),
                nn.BatchNorm2d(channels),
            )
        elif residual:
            self.residual = lambda feat: feat
        else:
            self.residual = lambda feat: 0

        self.conv = nn.Conv2d(channels, channels, (temporal_window_size, 1), (stride, 1), pad)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        """Return relu(bn(conv(x)) + shortcut(x) + res_module)."""
        shortcut = self.residual(x)
        out = self.bn(self.conv(x))
        return self.relu(out + shortcut + res_module)
|
||||
|
||||
|
||||
class TemporalBottleneckBlock(nn.Module):
    """
    TemporalConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    1x1 down-projection, (temporal_window_size x 1) convolution and 1x1
    up-projection, each followed by batch norm, with an optional shortcut.

    Args:
        channels: number of input and output channels.
        temporal_window_size: kernel extent along dim 2 of the input.
        stride: stride of the middle conv along dim 2; replaced by 2 whenever
            a convolutional shortcut is built.
        residual: only consulted when ``get_res`` is False; selects between a
            zero shortcut and an identity / convolutional shortcut.
        reduction: divisor giving the inner channel count.
        get_res: if True, always build the stride-2 convolutional shortcut.
        tcn_stride: accepted for interface compatibility; the value passed in
            is discarded.
    """

    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=4,get_res=False, tcn_stride=False):
        super().__init__()
        pad = ((temporal_window_size - 1) // 2, 0)
        inner = channels // reduction

        # NOTE(review): the caller's tcn_stride is never honoured; the flag is
        # derived purely from which shortcut gets built below.
        use_conv_shortcut = get_res or (residual and stride != 1)

        if use_conv_shortcut:
            # The shortcut stride is fixed at 2, so the main path is forced
            # to stride 2 as well to keep both outputs the same size.
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (2, 1)),
                nn.BatchNorm2d(channels),
            )
            stride = 2
        elif residual:
            self.residual = lambda feat: feat
        else:
            self.residual = lambda feat: 0

        self.conv_down = nn.Conv2d(channels, inner, 1)
        self.bn_down = nn.BatchNorm2d(inner)
        self.conv = nn.Conv2d(inner, inner, (temporal_window_size, 1), (stride, 1), pad)
        self.bn = nn.BatchNorm2d(inner)
        self.conv_up = nn.Conv2d(inner, channels, 1)
        self.bn_up = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        """Run the bottleneck and add the block shortcut and ``res_module``."""
        shortcut = self.residual(x)

        out = self.relu(self.bn_down(self.conv_down(x)))
        out = self.relu(self.bn(self.conv(out)))
        out = self.bn_up(self.conv_up(out))

        return self.relu(out + shortcut + res_module)
|
||||
|
||||
|
||||
|
||||
class SpatialGraphConv(nn.Module):
    """
    SpatialGraphConv_Basic_Block
    Arxiv: https://arxiv.org/abs/1801.07455
    Github: https://github.com/yysijie/st-gcn

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        max_graph_distance: largest hop distance used; the layer keeps one
            weight group per distance 0..max_graph_distance.
    """

    def __init__(self, in_channels, out_channels, max_graph_distance):
        super().__init__()

        # one spatial class per graph distance (0, 1, ..., max_graph_distance)
        self.s_kernel_size = max_graph_distance + 1

        # a single 1x1 conv yields the features of every spatial class at once
        self.gcn = nn.Conv2d(in_channels, out_channels * self.s_kernel_size, 1)

    def forward(self, x, A):
        """Apply the 1x1 conv, then mix nodes with the first
        ``s_kernel_size`` matrices of ``A``.
        """
        num_classes = self.s_kernel_size

        # nodes of the same class share weights
        feat = self.gcn(x)

        # separate the class axis from the channel axis
        n, kc, t, v = feat.size()
        feat = feat.view(n, num_classes, kc // num_classes, t, v).contiguous()

        # contract over class (k) and source node (v)
        mixed = torch.einsum('nkctv,kvw->nctw', feat, A[:num_classes])
        return mixed.contiguous()
|
||||
|
||||
class SpatialBasicBlock(nn.Module):
    """
    SpatialGraphConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    A spatial graph convolution followed by batch norm, with an optional
    shortcut, and a ReLU applied to the sum.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        max_graph_distance: forwarded to ``SpatialGraphConv``.
        residual: if False the shortcut contributes 0; otherwise it is the
            identity (equal channel counts) or a 1x1 conv + batch norm.
        reduction: accepted but not used by this block.
    """

    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False,reduction=0):
        super().__init__()

        if residual and in_channels != out_channels:
            # Project the input so it can be added to the block output.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )
        elif residual:
            self.residual = lambda feat: feat
        else:
            self.residual = lambda feat: 0

        self.conv = SpatialGraphConv(in_channels, out_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """Return relu(bn(conv(x, A)) + shortcut(x))."""
        shortcut = self.residual(x)
        out = self.bn(self.conv(x, A))
        return self.relu(out + shortcut)
|
||||
|
||||
class SpatialBottleneckBlock(nn.Module):
    """
    SpatialGraphConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    1x1 down-projection, spatial graph convolution and 1x1 up-projection,
    each followed by batch norm, with an optional shortcut.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        max_graph_distance: forwarded to ``SpatialGraphConv``.
        residual: if False the shortcut contributes 0; otherwise it is the
            identity (equal channel counts) or a 1x1 conv + batch norm.
        reduction: divisor giving the inner channel count from out_channels.
    """

    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=4):
        super().__init__()

        inner = out_channels // reduction

        if residual and in_channels != out_channels:
            # Project the input so it can be added to the block output.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )
        elif residual:
            self.residual = lambda feat: feat
        else:
            self.residual = lambda feat: 0

        self.conv_down = nn.Conv2d(in_channels, inner, 1)
        self.bn_down = nn.BatchNorm2d(inner)
        self.conv = SpatialGraphConv(inner, inner, max_graph_distance)
        self.bn = nn.BatchNorm2d(inner)
        self.conv_up = nn.Conv2d(inner, out_channels, 1)
        self.bn_up = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """Run the bottleneck on ``x`` with graph ``A`` and add the shortcut."""
        shortcut = self.residual(x)

        out = self.relu(self.bn_down(self.conv_down(x)))
        out = self.relu(self.bn(self.conv(out, A)))
        out = self.bn_up(self.conv_up(out))

        return self.relu(out + shortcut)
|
||||
|
||||
class SpatialAttention(nn.Module):
    """
    This class implements Spatial Transformer.
    Function adapted from: https://github.com/leaderj1001/Attention-Augmented-Conv2d

    Multi-head self-attention over the flattened last two dimensions of a
    4-D input. Queries, keys and values come from one shared convolution.

    Args:
        in_channels: channels of the input.
        out_channel: channels of the output (also the value width ``dv``).
        A: indexable with at least three entries; the first three are summed
            and stored as ``self.A``. It is not read in ``forward``.
        num_point: stored as ``self.num_point``; not read in this class.
        dk_factor: ``dk = int(dk_factor * out_channel)`` is the total
            query/key width.
        kernel_size: kernel size of the q/k/v convolution.
        Nh: number of heads; must divide both ``dk`` and ``dv``.
        num: stored as ``self.num``; not read in this class.
        stride: stride of the q/k/v convolution, 1 or 2.
    """

    def __init__(self, in_channels, out_channel, A, num_point, dk_factor=0.25, kernel_size=1, Nh=8, num=4, stride=1):
        super(SpatialAttention, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dk = int(dk_factor * out_channel)
        self.dv = int(out_channel)
        self.num = num
        self.Nh = Nh
        self.num_point = num_point
        self.A = A[0] + A[1] + A[2]
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."

        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                  stride=stride,
                                  padding=self.padding)

        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the attention output with ``dv`` channels.

        ``x`` is 4-D with ``in_channels`` channels. The output's last two
        dimensions are those produced by ``qkv_conv``, which differ from the
        input's when the convolution is strided or not size-preserving.
        """
        # flat_q, flat_k: (N, Nh, dk // Nh, T * V)
        # flat_v:         (N, Nh, dv // Nh, T * V)
        flat_q, flat_k, flat_v, q, _, _ = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)

        # Take the sizes from the projected tensor rather than from x: the
        # q/k/v convolution may have changed them (e.g. stride=2).
        N, _, _, T, V = q.size()

        # (N, Nh, T*V, dkh) @ (N, Nh, dkh, T*V) -> (N, Nh, T*V, T*V)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        weights = F.softmax(logits, dim=-1)

        # (N, Nh, T*V, T*V) @ (N, Nh, T*V, dvh) -> (N, Nh, T*V, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (N, self.Nh, T, V, self.dv // self.Nh))
        attn_out = attn_out.permute(0, 1, 4, 2, 3)

        # merge heads back into the channel axis: (N, dv, T, V)
        attn_out = self.combine_heads_2d(attn_out)

        # final 1x1 output projection
        attn_out = self.attn_out(attn_out)
        return attn_out

    def compute_flat_qkv(self, x, dk, dv, Nh):
        """Project ``x`` and return ``(flat_q, flat_k, flat_v, q, k, v)``.

        ``q``, ``k``, ``v`` are split into ``Nh`` heads, with ``q`` scaled by
        ``(dk // Nh) ** -0.5``. The ``flat_*`` tensors have the last two
        dimensions merged.
        """
        qkv = self.qkv_conv(x)
        N, _, T, V = qkv.size()

        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q = q * (dkh ** -0.5)
        flat_q = torch.reshape(q, (N, Nh, dkh, T * V))
        flat_k = torch.reshape(k, (N, Nh, dkh, T * V))
        # Use the Nh argument throughout so the value shape matches the
        # head split performed above.
        flat_v = torch.reshape(v, (N, Nh, dv // Nh, T * V))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        """Reshape (B, C, T, V) into (B, Nh, C // Nh, T, V)."""
        B, channels, T, V = x.size()
        ret_shape = (B, Nh, channels // Nh, T, V)
        split = torch.reshape(x, ret_shape)
        return split

    def combine_heads_2d(self, x):
        """Reshape (B, Nh, dv, T, V) into (B, Nh * dv, T, V)."""
        batch, Nh, dv, T, V = x.size()
        ret_shape = (batch, Nh * dv, T, V)
        return torch.reshape(x, ret_shape)
|
||||
|
||||
Reference in New Issue
Block a user