update tutorial for hid 2023
This commit is contained in:
@@ -189,7 +189,7 @@ def GREW_submission(data, dataset, metric='euc'):
|
||||
return
|
||||
|
||||
|
||||
def HID_submission(data, dataset, metric='euc'):
|
||||
def HID_submission(data, dataset, rerank=True, metric='euc'):
|
||||
msg_mgr = get_msg_mgr()
|
||||
msg_mgr.log_info("Evaluating HID")
|
||||
feature, label, seq_type = data['embeddings'], data['labels'], data['views']
|
||||
@@ -201,12 +201,16 @@ def HID_submission(data, dataset, metric='euc'):
|
||||
gallery_y = label[gallery_mask]
|
||||
probe_x = feature[probe_mask, :]
|
||||
probe_y = seq_type[probe_mask]
|
||||
|
||||
feat = np.concatenate([probe_x, gallery_x])
|
||||
dist = cuda_dist(feat, feat, metric).cpu().numpy()
|
||||
msg_mgr.log_info("Starting Re-ranking")
|
||||
re_rank = re_ranking(dist, probe_x.shape[0], k1=6, k2=6, lambda_value=0.3)
|
||||
idx = np.argsort(re_rank, axis=1)
|
||||
if rerank:
|
||||
feat = np.concatenate([probe_x, gallery_x])
|
||||
dist = cuda_dist(feat, feat, metric).cpu().numpy()
|
||||
msg_mgr.log_info("Starting Re-ranking")
|
||||
re_rank = re_ranking(
|
||||
dist, probe_x.shape[0], k1=6, k2=6, lambda_value=0.3)
|
||||
idx = np.argsort(re_rank, axis=1)
|
||||
else:
|
||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||
idx = dist.cpu().sort(1)[1].numpy()
|
||||
|
||||
save_path = os.path.join(
|
||||
"HID_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
|
||||
@@ -226,6 +230,7 @@ def evaluate_segmentation(data, dataset):
|
||||
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
|
||||
return {"scalar/test_accuracy/mIOU": miou}
|
||||
|
||||
|
||||
def evaluate_Gait3D(data, conf, metric='euc'):
|
||||
msg_mgr = get_msg_mgr()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user