116 lines
4.5 KiB
Python
116 lines
4.5 KiB
Python
import math
|
|
import random
|
|
import numpy as np
|
|
from utils import get_msg_mgr
|
|
|
|
|
|
class CollateFn(object):
|
|
def __init__(self, label_set, sample_config):
|
|
self.label_set = label_set
|
|
sample_type = sample_config['sample_type']
|
|
sample_type = sample_type.split('_')
|
|
self.sampler = sample_type[0]
|
|
self.ordered = sample_type[1]
|
|
if self.sampler not in ['fixed', 'unfixed', 'all']:
|
|
raise ValueError
|
|
if self.ordered not in ['ordered', 'unordered']:
|
|
raise ValueError
|
|
self.ordered = sample_type[1] == 'ordered'
|
|
|
|
# fixed cases
|
|
if self.sampler == 'fixed':
|
|
self.frames_num_fixed = sample_config['frames_num_fixed']
|
|
|
|
# unfixed cases
|
|
if self.sampler == 'unfixed':
|
|
self.frames_num_max = sample_config['frames_num_max']
|
|
self.frames_num_min = sample_config['frames_num_min']
|
|
|
|
if self.sampler != 'all' and self.ordered:
|
|
self.frames_skip_num = sample_config['frames_skip_num']
|
|
|
|
self.frames_all_limit = -1
|
|
if self.sampler == 'all' and 'frames_all_limit' in sample_config:
|
|
self.frames_all_limit = sample_config['frames_all_limit']
|
|
|
|
def __call__(self, batch):
|
|
batch_size = len(batch)
|
|
# currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton.
|
|
feature_num = len(batch[0][0])
|
|
seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], []
|
|
|
|
for bt in batch:
|
|
seqs_batch.append(bt[0])
|
|
labs_batch.append(self.label_set.index(bt[1][0]))
|
|
typs_batch.append(bt[1][1])
|
|
vies_batch.append(bt[1][2])
|
|
|
|
global count
|
|
count = 0
|
|
|
|
def sample_frames(seqs):
|
|
global count
|
|
sampled_fras = [[] for i in range(feature_num)]
|
|
seq_len = len(seqs[0])
|
|
indices = list(range(seq_len))
|
|
|
|
if self.sampler in ['fixed', 'unfixed']:
|
|
if self.sampler == 'fixed':
|
|
frames_num = self.frames_num_fixed
|
|
else:
|
|
frames_num = random.choice(
|
|
list(range(self.frames_num_min, self.frames_num_max+1)))
|
|
|
|
if self.ordered:
|
|
fs_n = frames_num + self.frames_skip_num
|
|
if seq_len < fs_n:
|
|
it = math.ceil(fs_n / seq_len)
|
|
seq_len = seq_len * it
|
|
indices = indices * it
|
|
|
|
start = random.choice(list(range(0, seq_len - fs_n + 1)))
|
|
end = start + fs_n
|
|
idx_lst = list(range(seq_len))
|
|
idx_lst = idx_lst[start:end]
|
|
idx_lst = sorted(np.random.choice(
|
|
idx_lst, frames_num, replace=False))
|
|
indices = [indices[i] for i in idx_lst]
|
|
else:
|
|
replace = seq_len < frames_num
|
|
|
|
if seq_len == 0:
|
|
get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.'
|
|
% (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count])))
|
|
|
|
count += 1
|
|
indices = np.random.choice(
|
|
indices, frames_num, replace=replace)
|
|
|
|
for i in range(feature_num):
|
|
for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices:
|
|
sampled_fras[i].append(seqs[i][j])
|
|
return sampled_fras
|
|
|
|
# f: feature_num
|
|
# b: batch_size
|
|
# p: batch_size_per_gpu
|
|
# g: gpus_num
|
|
fras_batch = [sample_frames(seqs) for seqs in seqs_batch] # [b, f]
|
|
batch = [fras_batch, labs_batch, typs_batch, vies_batch, None]
|
|
|
|
if self.sampler == "fixed":
|
|
fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)]
|
|
for j in range(feature_num)] # [f, b]
|
|
else:
|
|
seqL_batch = [[len(fras_batch[i][0])
|
|
for i in range(batch_size)]] # [1, p]
|
|
|
|
def my_cat(k): return np.concatenate(
|
|
[fras_batch[i][k] for i in range(batch_size)], 0)
|
|
fras_batch = [[my_cat(k)] for k in range(feature_num)] # [f, g]
|
|
|
|
batch[-1] = np.asarray(seqL_batch)
|
|
|
|
batch[0] = fras_batch
|
|
return batch
|