Update collate_fn.py
This commit is contained in:
@@ -13,9 +13,10 @@ class CollateFn(object):
|
|||||||
self.ordered = sample_type[1]
|
self.ordered = sample_type[1]
|
||||||
if self.sampler not in ['fixed', 'unfixed', 'all']:
|
if self.sampler not in ['fixed', 'unfixed', 'all']:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
if self.ordered not in ['ordered', 'unordered']:
|
if self.ordered not in ['ordered', 'unordered', 'allordered']:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
self.ordered = sample_type[1] == 'ordered'
|
self.ordered = sample_type[1] == 'ordered'
|
||||||
|
self.allordered = self.ordered and "all" in sample_type[1]
|
||||||
|
|
||||||
# fixed cases
|
# fixed cases
|
||||||
if self.sampler == 'fixed':
|
if self.sampler == 'fixed':
|
||||||
@@ -62,8 +63,16 @@ class CollateFn(object):
|
|||||||
else:
|
else:
|
||||||
frames_num = random.choice(
|
frames_num = random.choice(
|
||||||
list(range(self.frames_num_min, self.frames_num_max+1)))
|
list(range(self.frames_num_min, self.frames_num_max+1)))
|
||||||
|
if self.allordered:
|
||||||
|
replace = seq_len < frames_num
|
||||||
|
|
||||||
if self.ordered:
|
if seq_len == 0:
|
||||||
|
get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.'
|
||||||
|
% (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count])))
|
||||||
|
count += 1
|
||||||
|
indices = sorted(np.random.choice(
|
||||||
|
indices, frames_num, replace=replace))
|
||||||
|
elif self.ordered:
|
||||||
fs_n = frames_num + self.frames_skip_num
|
fs_n = frames_num + self.frames_skip_num
|
||||||
if seq_len < fs_n:
|
if seq_len < fs_n:
|
||||||
it = math.ceil(fs_n / seq_len)
|
it = math.ceil(fs_n / seq_len)
|
||||||
|
|||||||
Reference in New Issue
Block a user