Update collate_fn.py

This commit is contained in:
Dongyang Jin
2025-08-14 23:21:32 +08:00
committed by GitHub
parent f27576ef98
commit 23ac81dc32
+11 -2
View File
@@ -13,9 +13,10 @@ class CollateFn(object):
self.ordered = sample_type[1] self.ordered = sample_type[1]
if self.sampler not in ['fixed', 'unfixed', 'all']: if self.sampler not in ['fixed', 'unfixed', 'all']:
raise ValueError raise ValueError
if self.ordered not in ['ordered', 'unordered']: if self.ordered not in ['ordered', 'unordered', 'allordered']:
raise ValueError raise ValueError
self.ordered = sample_type[1] == 'ordered' self.ordered = sample_type[1] == 'ordered'
self.allordered = self.ordered and "all" in sample_type[1]
# fixed cases # fixed cases
if self.sampler == 'fixed': if self.sampler == 'fixed':
@@ -62,8 +63,16 @@ class CollateFn(object):
else: else:
frames_num = random.choice( frames_num = random.choice(
list(range(self.frames_num_min, self.frames_num_max+1))) list(range(self.frames_num_min, self.frames_num_max+1)))
if self.allordered:
replace = seq_len < frames_num
if self.ordered: if seq_len == 0:
get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.'
% (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count])))
count += 1
indices = sorted(np.random.choice(
indices, frames_num, replace=replace))
elif self.ordered:
fs_n = frames_num + self.frames_skip_num fs_n = frames_num + self.frames_skip_num
if seq_len < fs_n: if seq_len < fs_n:
it = math.ceil(fs_n / seq_len) it = math.ceil(fs_n / seq_len)