GaitSSB@Pretrain release
This commit is contained in:
@@ -134,3 +134,41 @@ class CommonSampler(tordata.sampler.Sampler):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
# **************** For GaitSSB ****************
|
||||
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
|
||||
import random
|
||||
class BilateralSampler(tordata.sampler.Sampler):
|
||||
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.batch_shuffle = batch_shuffle
|
||||
|
||||
self.world_size = dist.get_world_size()
|
||||
self.rank = dist.get_rank()
|
||||
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.total_indices = list(range(self.dataset_length))
|
||||
|
||||
def __iter__(self):
|
||||
random.shuffle(self.total_indices)
|
||||
count = 0
|
||||
batch_size = self.batch_size[0] * self.batch_size[1]
|
||||
while True:
|
||||
if (count + 1) * batch_size >= self.dataset_length:
|
||||
count = 0
|
||||
random.shuffle(self.total_indices)
|
||||
|
||||
sampled_indices = self.total_indices[count*batch_size:(count+1)*batch_size]
|
||||
sampled_indices = sync_random_sample_list(sampled_indices, len(sampled_indices))
|
||||
|
||||
total_size = int(math.ceil(batch_size / self.world_size)) * self.world_size
|
||||
sampled_indices += sampled_indices[:(batch_size - len(sampled_indices))]
|
||||
|
||||
sampled_indices = sampled_indices[self.rank:total_size:self.world_size]
|
||||
count += 1
|
||||
|
||||
yield sampled_indices * 2
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
Reference in New Issue
Block a user