diff --git a/opengait/data/collate_fn.py b/opengait/data/collate_fn.py index 6767b04..a78c73d 100644 --- a/opengait/data/collate_fn.py +++ b/opengait/data/collate_fn.py @@ -13,9 +13,10 @@ class CollateFn(object): self.ordered = sample_type[1] if self.sampler not in ['fixed', 'unfixed', 'all']: raise ValueError - if self.ordered not in ['ordered', 'unordered']: + if self.ordered not in ['ordered', 'unordered', 'allordered']: raise ValueError self.ordered = sample_type[1] == 'ordered' + self.allordered = self.ordered and "all" in sample_type[1] # fixed cases if self.sampler == 'fixed': @@ -62,8 +63,16 @@ class CollateFn(object): else: frames_num = random.choice( list(range(self.frames_num_min, self.frames_num_max+1))) + if self.allordered: + replace = seq_len < frames_num - if self.ordered: + if seq_len == 0: + get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.' + % (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) + count += 1 + indices = sorted(np.random.choice( + indices, frames_num, replace=replace)) + elif self.ordered: fs_n = frames_num + self.frames_skip_num if seq_len < fs_n: it = math.ceil(fs_n / seq_len)