From c3caba81c85f2a845780aef83da2834c5bf81b22 Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Tue, 21 Dec 2021 23:07:20 +0800 Subject: [PATCH] Fixed some hard to understand code --- lib/data/collate_fn.py | 4 +++- lib/data/sampler.py | 17 +++++++++-------- lib/modeling/modules.py | 6 +++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/lib/data/collate_fn.py b/lib/data/collate_fn.py index 389cf71..8917bf5 100644 --- a/lib/data/collate_fn.py +++ b/lib/data/collate_fn.py @@ -3,6 +3,7 @@ import random import numpy as np from utils import get_msg_mgr + class CollateFn(object): def __init__(self, label_set, sample_config): self.label_set = label_set @@ -34,6 +35,7 @@ class CollateFn(object): def __call__(self, batch): batch_size = len(batch) + # currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton. feature_num = len(batch[0][0]) seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], [] @@ -78,7 +80,7 @@ class CollateFn(object): if seq_len == 0: get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.' - %(str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) + % (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) count += 1 indices = np.random.choice( diff --git a/lib/data/sampler.py b/lib/data/sampler.py index 8849557..d91d3b6 100644 --- a/lib/data/sampler.py +++ b/lib/data/sampler.py @@ -14,7 +14,7 @@ class TripletSampler(tordata.sampler.Sampler): self.rank = dist.get_rank() def __iter__(self): - while (True): + while True: sample_indices = [] pid_list = sync_random_sample_list( self.dataset.label_set, self.batch_size[0]) @@ -29,10 +29,11 @@ class TripletSampler(tordata.sampler.Sampler): sample_indices = sync_random_sample_list( sample_indices, len(sample_indices)) - _ = self.batch_size[0] * self.batch_size[1] - total_size = int(math.ceil(_ / self.world_size) - ) * self.world_size - sample_indices += sample_indices[:(_ - len(sample_indices))] + total_batch_size = self.batch_size[0] * self.batch_size[1] + total_size = int(math.ceil(total_batch_size / + self.world_size)) * self.world_size + sample_indices += sample_indices[:( + total_batch_size - len(sample_indices))] sample_indices = sample_indices[self.rank:total_size:self.world_size] yield sample_indices @@ -66,10 +67,10 @@ class InferenceSampler(tordata.sampler.Sampler): world_size, batch_size)) if batch_size != 1: - _ = math.ceil(self.size / batch_size) * \ + complement_size = math.ceil(self.size / batch_size) * \ batch_size - indices += indices[:(_ - self.size)] - self.size = _ + indices += indices[:(complement_size - self.size)] + self.size = complement_size batch_size_per_rank = int(self.batch_size / world_size) indx_batch_per_rank = [] diff --git a/lib/modeling/modules.py b/lib/modeling/modules.py index 34c5663..a5d2b6b 100644 --- a/lib/modeling/modules.py +++ b/lib/modeling/modules.py @@ -43,9 +43,9 @@ class SetBlockWrapper(nn.Module): """ n, s, c, h, w = x.size() x = self.forward_block(x.view(-1, c, h, w), *args, **kwargs) - _ = x.size() - _ = [n, s] + [*_[1:]] - return x.view(*_) + input_size = x.size() + output_size = [n, s] + [*input_size[1:]] + return x.view(*output_size) class PackSequenceWrapper(nn.Module):