fix sampler's bug

This commit is contained in:
wj1tr0y
2022-04-16 19:27:05 +08:00
committed by Junhao Liang
parent 0fb5c45272
commit a4ead0b40d
+6 -1
View File
@@ -1,4 +1,5 @@
import math
import random
import torch
import torch.distributed as dist
import torch.utils.data as tordata
@@ -49,7 +50,11 @@ class TripletSampler(tordata.sampler.Sampler):
def sync_random_sample_list(obj_list, k):
idx = torch.randperm(len(obj_list))[:k]
if len(obj_list) < k:
idx = random.choices(range(len(obj_list)), k=k)
idx = torch.tensor(idx)
else:
idx = torch.randperm(len(obj_list))[:k]
if torch.cuda.is_available():
idx = idx.cuda()
torch.distributed.broadcast(idx, src=0)