add Gait3D support
This commit is contained in:
@@ -3,7 +3,7 @@ from time import strftime, localtime
|
||||
import numpy as np
|
||||
from utils import get_msg_mgr, mkdir
|
||||
|
||||
from .metric import mean_iou, cuda_dist, compute_ACC_mAP
|
||||
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank
|
||||
from .re_rank import re_ranking
|
||||
|
||||
|
||||
@@ -225,3 +225,43 @@ def evaluate_segmentation(data, dataset):
|
||||
miou = mean_iou(pred, labels)
|
||||
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
|
||||
return {"scalar/test_accuracy/mIOU": miou}
|
||||
|
||||
def evaluate_Gait3D(data, conf, metric='euc'):
|
||||
msg_mgr = get_msg_mgr()
|
||||
|
||||
features, labels, cams, time_seqs = data['embeddings'], data['labels'], data['types'], data['views']
|
||||
import json
|
||||
probe_sets = json.load(
|
||||
open('./datasets/Gait3D/Gait3D.json', 'rb'))['PROBE_SET']
|
||||
probe_mask = []
|
||||
for id, ty, sq in zip(labels, cams, time_seqs):
|
||||
if '-'.join([id, ty, sq]) in probe_sets:
|
||||
probe_mask.append(True)
|
||||
else:
|
||||
probe_mask.append(False)
|
||||
probe_mask = np.array(probe_mask)
|
||||
|
||||
# probe_features = features[:probe_num]
|
||||
probe_features = features[probe_mask]
|
||||
# gallery_features = features[probe_num:]
|
||||
gallery_features = features[~probe_mask]
|
||||
# probe_lbls = np.asarray(labels[:probe_num])
|
||||
# gallery_lbls = np.asarray(labels[probe_num:])
|
||||
probe_lbls = np.asarray(labels)[probe_mask]
|
||||
gallery_lbls = np.asarray(labels)[~probe_mask]
|
||||
|
||||
results = {}
|
||||
msg_mgr.log_info(f"The test metric you choose is {metric}.")
|
||||
dist = cuda_dist(probe_features, gallery_features, metric).cpu().numpy()
|
||||
cmc, all_AP, all_INP = evaluate_rank(dist, probe_lbls, gallery_lbls)
|
||||
|
||||
mAP = np.mean(all_AP)
|
||||
mINP = np.mean(all_INP)
|
||||
for r in [1, 5, 10]:
|
||||
results['scalar/test_accuracy/Rank-{}'.format(r)] = cmc[r - 1] * 100
|
||||
results['scalar/test_accuracy/mAP'] = mAP * 100
|
||||
results['scalar/test_accuracy/mINP'] = mINP * 100
|
||||
|
||||
# print_csv_format(dataset_name, results)
|
||||
msg_mgr.log_info(results)
|
||||
return results
|
||||
@@ -86,3 +86,73 @@ def compute_ACC_mAP(distmat, q_pids, g_pids, q_views=None, g_views=None, rank=1)
|
||||
mAP = np.mean(all_AP)
|
||||
|
||||
return ACC, mAP
|
||||
|
||||
|
||||
def evaluate_rank(distmat, p_lbls, g_lbls, max_rank=50):
|
||||
'''
|
||||
Copy from https://github.com/Gait3D/Gait3D-Benchmark/blob/72beab994c137b902d826f4b9f9e95b107bebd78/lib/utils/rank.py#L12-L63
|
||||
'''
|
||||
num_p, num_g = distmat.shape
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
|
||||
matches = (g_lbls[indices] == p_lbls[:, np.newaxis]).astype(np.int32)
|
||||
|
||||
# compute cmc curve for each probe
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
all_INP = []
|
||||
num_valid_p = 0. # number of valid probe
|
||||
|
||||
for p_idx in range(num_p):
|
||||
# compute cmc curve
|
||||
# binary vector, positions with value 1 are correct matches
|
||||
raw_cmc = matches[p_idx]
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when probe identity does not appear in gallery
|
||||
continue
|
||||
|
||||
cmc = raw_cmc.cumsum()
|
||||
|
||||
pos_idx = np.where(raw_cmc == 1) # 返回坐标,此处raw_cmc为一维矩阵,所以返回相当于index
|
||||
max_pos_idx = np.max(pos_idx)
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP.append(inp)
|
||||
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_p += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = raw_cmc.sum()
|
||||
pos_idx = np.where(raw_cmc == 1) # 返回坐标,此处raw_cmc为一维矩阵,所以返回相当于index
|
||||
max_pos_idx = np.max(pos_idx)
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP.append(inp)
|
||||
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_p += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
|
||||
assert num_valid_p > 0, 'Error: all probe identities do not appear in gallery'
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_p
|
||||
|
||||
return all_cmc, all_AP, all_INP
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
'''
|
||||
Modifed from https://github.com/Gait3D/Gait3D-Benchmark/blob/72beab994c137b902d826f4b9f9e95b107bebd78/lib/modeling/models/smplgait.py
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
||||
|
||||
|
||||
class SMPLGait(BaseModel):
|
||||
def __init__(self, cfgs, is_training):
|
||||
super().__init__(cfgs, is_training)
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
# Baseline
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||
self.TP = PackSequenceWrapper(torch.max)
|
||||
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
||||
|
||||
# for SMPL
|
||||
self.fc1 = nn.Linear(85, 128)
|
||||
self.fc2 = nn.Linear(128, 256)
|
||||
self.fc3 = nn.Linear(256, 256)
|
||||
self.bn1 = nn.BatchNorm1d(128)
|
||||
self.bn2 = nn.BatchNorm1d(256)
|
||||
self.bn3 = nn.BatchNorm1d(256)
|
||||
self.dropout2 = nn.Dropout(p=0.2)
|
||||
self.dropout3 = nn.Dropout(p=0.2)
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
|
||||
sils = ipts[0] # [n, s, h, w]
|
||||
smpls = ipts[1] # [n, s, d]
|
||||
|
||||
# extract SMPL features
|
||||
n, s, d = smpls.size()
|
||||
sps = smpls.view(-1, d)
|
||||
del smpls
|
||||
|
||||
sps = F.relu(self.bn1(self.fc1(sps)))
|
||||
sps = F.relu(self.bn2(self.dropout2(self.fc2(sps)))) # (B, 256)
|
||||
sps = F.relu(self.bn3(self.dropout3(self.fc3(sps)))) # (B, 256)
|
||||
sps = sps.reshape(n, 1, s, 16, 16)
|
||||
iden = Variable(torch.eye(16)).unsqueeze(
|
||||
0).repeat(n, 1, s, 1, 1) # [n, 1, s, 16, 16]
|
||||
if sps.is_cuda:
|
||||
iden = iden.cuda()
|
||||
sps_trans = sps + iden # [n, 1, s, 16, 16]
|
||||
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(1)
|
||||
|
||||
del ipts
|
||||
outs = self.Backbone(sils) # [n, c, s, h, w]
|
||||
outs_n, outs_c, outs_s, outs_h, outs_w = outs.size()
|
||||
|
||||
zero_tensor = Variable(torch.zeros(
|
||||
(outs_n, outs_c, outs_s, outs_h, outs_h-outs_w)))
|
||||
if outs.is_cuda:
|
||||
zero_tensor = zero_tensor.cuda()
|
||||
# [n, s, c, h, h] [n, s, c, 16, 16]
|
||||
outs = torch.cat([outs, zero_tensor], -1)
|
||||
outs = outs.reshape(outs_n*outs_c*outs_s, outs_h,
|
||||
outs_h) # [n*c*s, 16, 16]
|
||||
|
||||
sps = sps_trans.repeat(1, outs_c, 1, 1, 1).reshape(
|
||||
outs_n * outs_c * outs_s, 16, 16)
|
||||
|
||||
outs_trans = torch.bmm(outs, sps)
|
||||
outs_trans = outs_trans.reshape(outs_n, outs_c, outs_s, outs_h, outs_h)
|
||||
|
||||
# Temporal Pooling, TP
|
||||
outs_trans = self.TP(outs_trans, seqL, options={"dim": 2})[
|
||||
0] # [n, c, h, w]
|
||||
# Horizontal Pooling Matching, HPM
|
||||
feat = self.HPP(outs_trans) # [n, c, p]
|
||||
embed_1 = self.FCs(feat) # [n, c, p]
|
||||
|
||||
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
|
||||
n, _, s, h, w = sils.size()
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embed_1, 'labels': labs},
|
||||
'softmax': {'logits': logits, 'labels': labs}
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/sils': sils.view(n*s, 1, h, w)
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': embed_1
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user