Update collate_fn.py
This commit is contained in:
@@ -13,9 +13,10 @@ class CollateFn(object):
|
||||
self.ordered = sample_type[1]
|
||||
if self.sampler not in ['fixed', 'unfixed', 'all']:
|
||||
raise ValueError
|
||||
if self.ordered not in ['ordered', 'unordered']:
|
||||
if self.ordered not in ['ordered', 'unordered', 'allordered']:
|
||||
raise ValueError
|
||||
self.ordered = sample_type[1] == 'ordered'
|
||||
self.allordered = self.ordered and "all" in sample_type[1]
|
||||
|
||||
# fixed cases
|
||||
if self.sampler == 'fixed':
|
||||
@@ -62,8 +63,16 @@ class CollateFn(object):
|
||||
else:
|
||||
frames_num = random.choice(
|
||||
list(range(self.frames_num_min, self.frames_num_max+1)))
|
||||
if self.allordered:
|
||||
replace = seq_len < frames_num
|
||||
|
||||
if self.ordered:
|
||||
if seq_len == 0:
|
||||
get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.'
|
||||
% (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count])))
|
||||
count += 1
|
||||
indices = sorted(np.random.choice(
|
||||
indices, frames_num, replace=replace))
|
||||
elif self.ordered:
|
||||
fs_n = frames_num + self.frames_skip_num
|
||||
if seq_len < fs_n:
|
||||
it = math.ceil(fs_n / seq_len)
|
||||
|
||||
Reference in New Issue
Block a user