lidargaitv2 open-source
This commit is contained in:
@@ -33,6 +33,8 @@ class CollateFn(object):
|
||||
if self.sampler == 'all' and 'frames_all_limit' in sample_config:
|
||||
self.frames_all_limit = sample_config['frames_all_limit']
|
||||
|
||||
self.points_in_use = sample_config.get('points_in_use')
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size = len(batch)
|
||||
# currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton.
|
||||
@@ -88,7 +90,14 @@ class CollateFn(object):
|
||||
|
||||
for i in range(feature_num):
|
||||
for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices:
|
||||
sampled_fras[i].append(seqs[i][j])
|
||||
point_cloud_index = self.points_in_use.get('pointcloud_index')
|
||||
if self.points_in_use is not None and point_cloud_index is not None and i == point_cloud_index:
|
||||
points_num = self.points_in_use.get('points_num')
|
||||
sample_points = (random.choices(range(len(seqs[i][j])), k=points_num)
|
||||
if points_num is not None else list(range(len(seqs[i][j]))))
|
||||
sampled_fras[i].append(np.asarray([seqs[i][j][p] for p in sample_points]))
|
||||
else:
|
||||
sampled_fras[i].append(seqs[i][j])
|
||||
return sampled_fras
|
||||
|
||||
# f: feature_num
|
||||
@@ -112,4 +121,4 @@ class CollateFn(object):
|
||||
batch[-1] = np.asarray(seqL_batch)
|
||||
|
||||
batch[0] = fras_batch
|
||||
return batch
|
||||
return batch
|
||||
Reference in New Issue
Block a user