support CCPG
This commit is contained in:
@@ -156,3 +156,66 @@ def evaluate_rank(distmat, p_lbls, g_lbls, max_rank=50):
|
||||
all_cmc = all_cmc.sum(0) / num_valid_p
|
||||
|
||||
return all_cmc, all_AP, all_INP
|
||||
|
||||
|
||||
def evaluate_many(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
||||
num_q, num_g = distmat.shape
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||
indices = np.argsort(distmat, axis=1) # 对应位置变成从小到大的序号
|
||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(
|
||||
np.int32) # 根据indices调整顺序 g_pids[indices]
|
||||
# print(matches)
|
||||
|
||||
# compute cmc curve for each query
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
all_INP = []
|
||||
num_valid_q = 0.
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
||||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
# binary vector, positions with value 1 are correct matches
|
||||
orig_cmc = matches[q_idx][keep]
|
||||
if not np.any(orig_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
cmc = orig_cmc.cumsum()
|
||||
|
||||
pos_idx = np.where(orig_cmc == 1)
|
||||
max_pos_idx = np.max(pos_idx)
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP.append(inp)
|
||||
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_q += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = orig_cmc.sum()
|
||||
tmp_cmc = orig_cmc.cumsum()
|
||||
tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
|
||||
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||
mAP = np.mean(all_AP)
|
||||
mINP = np.mean(all_INP)
|
||||
|
||||
return all_cmc, mAP, mINP
|
||||
|
||||
Reference in New Issue
Block a user