GaitSSB@Pretrain release

This commit is contained in:
jdyjjj
2023-11-20 20:28:09 +08:00
parent 476c4adbe3
commit b24e797486
6 changed files with 506 additions and 4 deletions
+38
View File
@@ -134,3 +134,41 @@ class CommonSampler(tordata.sampler.Sampler):
def __len__(self):
return len(self.dataset)
# **************** For GaitSSB ****************
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
import random
class BilateralSampler(tordata.sampler.Sampler):
def __init__(self, dataset, batch_size, batch_shuffle=False):
self.dataset = dataset
self.batch_size = batch_size
self.batch_shuffle = batch_shuffle
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
self.dataset_length = len(self.dataset)
self.total_indices = list(range(self.dataset_length))
def __iter__(self):
random.shuffle(self.total_indices)
count = 0
batch_size = self.batch_size[0] * self.batch_size[1]
while True:
if (count + 1) * batch_size >= self.dataset_length:
count = 0
random.shuffle(self.total_indices)
sampled_indices = self.total_indices[count*batch_size:(count+1)*batch_size]
sampled_indices = sync_random_sample_list(sampled_indices, len(sampled_indices))
total_size = int(math.ceil(batch_size / self.world_size)) * self.world_size
sampled_indices += sampled_indices[:(batch_size - len(sampled_indices))]
sampled_indices = sampled_indices[self.rank:total_size:self.world_size]
count += 1
yield sampled_indices * 2
def __len__(self):
return len(self.dataset)