fix sampler's bug
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
import random
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.utils.data as tordata
|
import torch.utils.data as tordata
|
||||||
@@ -49,6 +50,10 @@ class TripletSampler(tordata.sampler.Sampler):
|
|||||||
|
|
||||||
|
|
||||||
def sync_random_sample_list(obj_list, k):
|
def sync_random_sample_list(obj_list, k):
|
||||||
|
if len(obj_list) < k:
|
||||||
|
idx = random.choices(range(len(obj_list)), k=k)
|
||||||
|
idx = torch.tensor(idx)
|
||||||
|
else:
|
||||||
idx = torch.randperm(len(obj_list))[:k]
|
idx = torch.randperm(len(obj_list))[:k]
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
idx = idx.cuda()
|
idx = idx.cuda()
|
||||||
|
|||||||
Reference in New Issue
Block a user