fix sampler's bug

This commit is contained in:
wj1tr0y
2022-04-16 19:27:05 +08:00
committed by Junhao Liang
parent 0fb5c45272
commit a4ead0b40d
+5
View File
@@ -1,4 +1,5 @@
import math
import random
import torch
import torch.distributed as dist
import torch.utils.data as tordata
@@ -49,6 +50,10 @@ class TripletSampler(tordata.sampler.Sampler):
def sync_random_sample_list(obj_list, k):
if len(obj_list) < k:
idx = random.choices(range(len(obj_list)), k=k)
idx = torch.tensor(idx)
else:
idx = torch.randperm(len(obj_list))[:k]
if torch.cuda.is_available():
idx = idx.cuda()