lidargaitv2 open-source

This commit is contained in:
Noah
2025-06-11 14:43:19 +08:00
parent c42f2f8c07
commit 16a7c3f0bf
11 changed files with 6396 additions and 4 deletions
+11 -2
View File
@@ -33,6 +33,8 @@ class CollateFn(object):
if self.sampler == 'all' and 'frames_all_limit' in sample_config:
self.frames_all_limit = sample_config['frames_all_limit']
self.points_in_use = sample_config.get('points_in_use')
def __call__(self, batch):
batch_size = len(batch)
# currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton.
@@ -88,7 +90,14 @@ class CollateFn(object):
for i in range(feature_num):
for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices:
sampled_fras[i].append(seqs[i][j])
point_cloud_index = self.points_in_use.get('pointcloud_index')
if self.points_in_use is not None and point_cloud_index is not None and i == point_cloud_index:
points_num = self.points_in_use.get('points_num')
sample_points = (random.choices(range(len(seqs[i][j])), k=points_num)
if points_num is not None else list(range(len(seqs[i][j]))))
sampled_fras[i].append(np.asarray([seqs[i][j][p] for p in sample_points]))
else:
sampled_fras[i].append(seqs[i][j])
return sampled_fras
# f: feature_num
@@ -112,4 +121,4 @@ class CollateFn(object):
batch[-1] = np.asarray(seqL_batch)
batch[0] = fras_batch
return batch
return batch