Fix sampler bug: in sync_random_sample_list, fall back to random.choices (sampling with replacement) when len(obj_list) < k, instead of unconditionally taking the first k entries of torch.randperm.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data as tordata
|
||||
@@ -49,7 +50,11 @@ class TripletSampler(tordata.sampler.Sampler):
|
||||
|
||||
|
||||
def sync_random_sample_list(obj_list, k):
    """Pick k indices into obj_list, kept identical across distributed ranks.

    Every rank draws its own candidate indices, then rank 0's draw is
    broadcast to all ranks so the whole process group agrees on one sample.

    Args:
        obj_list: sequence to sample from (only its length is used here).
        k: number of indices to draw.
    """
    n = len(obj_list)
    if n < k:
        # Fewer items than requested: sample positions with replacement,
        # then wrap in a tensor so the broadcast below works uniformly.
        picked = random.choices(range(n), k=k)
        idx = torch.tensor(picked)
    else:
        # Enough items: take the first k entries of a random permutation
        # (sampling without replacement).
        idx = torch.randperm(n)[:k]
    if torch.cuda.is_available():
        # NOTE(review): presumably the NCCL backend is in use, which
        # requires CUDA tensors for collectives — confirm against setup.
        idx = idx.cuda()
    # Overwrite every rank's local draw with rank 0's, syncing the sample.
    torch.distributed.broadcast(idx, src=0)
|
||||
|
||||
Reference in New Issue
Block a user