LidarGaitV2 open-source

This commit is contained in:
Noah
2025-06-11 14:43:19 +08:00
parent c42f2f8c07
commit 16a7c3f0bf
11 changed files with 6396 additions and 4 deletions
+11 -2
View File
@@ -33,6 +33,8 @@ class CollateFn(object):
if self.sampler == 'all' and 'frames_all_limit' in sample_config:
self.frames_all_limit = sample_config['frames_all_limit']
self.points_in_use = sample_config.get('points_in_use')
def __call__(self, batch):
batch_size = len(batch)
# currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton.
@@ -88,7 +90,14 @@ class CollateFn(object):
for i in range(feature_num):
for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices:
sampled_fras[i].append(seqs[i][j])
point_cloud_index = self.points_in_use.get('pointcloud_index')
if self.points_in_use is not None and point_cloud_index is not None and i == point_cloud_index:
points_num = self.points_in_use.get('points_num')
sample_points = (random.choices(range(len(seqs[i][j])), k=points_num)
if points_num is not None else list(range(len(seqs[i][j]))))
sampled_fras[i].append(np.asarray([seqs[i][j][p] for p in sample_points]))
else:
sampled_fras[i].append(seqs[i][j])
return sampled_fras
# f: feature_num
@@ -112,4 +121,4 @@ class CollateFn(object):
batch[-1] = np.asarray(seqL_batch)
batch[0] = fras_batch
return batch
return batch
+116 -1
View File
@@ -228,6 +228,121 @@ def get_transform(trf_cfg=None):
return transform
raise "Error type for -Transform-Cfg-"
# **************** For LidarGait++ ****************
# Shen, et al: LidarGait++: Learning Local Features and Size Awareness from LiDAR Point Clouds for 3D Gait Recognition, CVPR2025
def normalize_point_cloud(batch_data):
    """Center each point cloud at the origin and scale it into the unit sphere.
    Input:
        batch_data: BxNxC array
    Output:
        BxNxC array; each cloud is centroid-centered and divided by its
        furthest-point distance so all points fit inside the unit sphere.
    """
    centroids = np.mean(batch_data, axis=1, keepdims=True)  # shape: (B, 1, C)
    centered = batch_data - centroids
    scales = np.max(np.linalg.norm(centered, axis=2), axis=1, keepdims=True)  # shape: (B, 1)
    scales = scales.reshape(batch_data.shape[0], 1, 1)  # (B, 1, 1) for broadcasting
    # Guard against a degenerate cloud whose points all coincide: its scale is
    # 0 and the division below would emit NaNs. Such a cloud is all-zero after
    # centering, so dividing by 1 leaves it unchanged.
    scales[scales == 0] = 1.0
    return centered / scales
def dropout_point_cloud(batch_data, max_dropout_ratio=0.875, prob=0.2):
    """With probability `prob`, randomly drop points from each point cloud.

    Dropped points are replaced by their own cloud's first point, so the
    array shape is preserved.
    Input:
        batch_data: BxNx3 array
    Output:
        BxNx3 array, with dropped points replaced by the first point in each cloud.
    """
    if np.random.rand() >= prob:
        return batch_data
    num_clouds, num_points, _ = batch_data.shape
    # One dropout ratio per cloud, drawn uniformly from [0, max_dropout_ratio).
    ratios = np.random.rand(num_clouds, 1) * max_dropout_ratio
    # Boolean mask (B, N): True marks a point selected for dropping.
    drop_mask = np.random.rand(num_clouds, num_points) <= ratios
    # Each cloud's first point tiled N times fills the dropped slots.
    fillers = np.repeat(batch_data[:, :1, :], num_points, axis=1)
    return np.where(drop_mask[..., None], fillers, batch_data)
def shift_point_cloud(batch_data, shift_range=0.1, prob=0.2):
    """ Randomly shift point cloud. Shift is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted batch of point clouds (modified in place)
    """
    if np.random.rand() >= prob:
        return batch_data
    B, N, C = batch_data.shape
    # One XYZ offset per cloud, broadcast over its N points. The previous
    # (B, N, 3) draw shifted every point independently, contradicting the
    # documented per-cloud semantics and duplicating jitter_point_cloud.
    shifts = np.random.uniform(-shift_range, shift_range, (B, 1, 3))
    batch_data += shifts
    return batch_data
def scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25, prob=0.2):
    """ Randomly scale the point cloud. Scale is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, scaled batch of point clouds (modified in place)
    """
    if np.random.rand() >= prob:
        return batch_data
    num_clouds = batch_data.shape[0]
    factors = np.random.uniform(scale_low, scale_high, num_clouds)
    # Broadcast one factor over each cloud instead of looping per batch item.
    batch_data *= factors[:, None, None]
    return batch_data
def jitter_point_cloud(batch_data, std=0.01, clip=0.05, prob=0.2):
    """With probability `prob`, add clipped Gaussian noise to every point.
    Input:
        batch_data: BxNxC array
    Return:
        BxNxC array with per-point noise added (modified in place)
    """
    if np.random.rand() >= prob:
        return batch_data
    B, N, C = batch_data.shape
    # Gaussian noise per coordinate, clipped to keep outliers bounded.
    noise = np.clip(np.random.normal(loc=0.0, scale=std, size=(B, N, C)), -clip, clip)
    batch_data += noise
    return batch_data
def flip_point_cloud_y(batch_data, prob=0.25):
    """With probability `prob`, mirror every cloud across the XZ plane.

    Negates the y coordinate in place; returns the (possibly flipped) batch.
    """
    if np.random.rand() < prob:
        batch_data[:, :, 1] *= -1
    return batch_data
def getxyz(batch_data, col=2, to_ground=False):
    """Extract one coordinate column from a batch of point clouds.
    Input:
        batch_data: BxNxC array
        col: index of the coordinate to extract (default 2, i.e. the height axis)
        to_ground: if True, shift each cloud's column so its minimum is 0
    Output:
        BxNx1 array (an independent copy; the input is never modified)
    """
    B, N, C = batch_data.shape
    # Copy explicitly: the reshaped column slice is a view into batch_data,
    # so the previous in-place subtraction under to_ground=True silently
    # mutated the caller's array.
    result = batch_data[:, :, col].reshape((B, N, 1)).copy()
    if to_ground:
        result -= result.min(axis=1, keepdims=True)
    return result
class PointCloudsTransform():
    """Augmentation pipeline for batches of point clouds (BxNxC).

    Captures per-point ground-relative heights before normalization, scales
    each cloud into the unit sphere, then applies a fixed sequence of random
    augmentations. With `scale_aware`, the pre-normalization heights are
    appended as an extra channel so absolute-size information survives
    normalization.
    """
    def __init__(self, xyz_only=True, scale_aware=False, drop_prob=0, shift_prob=0, jit_prob=0, scale_prob=0, flip_prob=0):
        self.scale_aware = scale_aware
        self.xyz_only = xyz_only
        self.flip_prob = flip_prob
        self.shift_prob = shift_prob
        self.jit_prob = jit_prob
        self.scale_prob = scale_prob
        self.drop_prob = drop_prob
    def __call__(self, points):
        # Drop any extra per-point features, keeping only XYZ channels.
        if self.xyz_only:
            points = points[:, :, :3]
        # Heights must be taken before normalization erases absolute scale.
        heights = getxyz(points, col=2, to_ground=True)
        points = normalize_point_cloud(points)
        # Apply each augmentation in order with its configured probability.
        for augment, chance in (
            (flip_point_cloud_y, self.flip_prob),
            (shift_point_cloud, self.shift_prob),
            (jitter_point_cloud, self.jit_prob),
            (scale_point_cloud, self.scale_prob),
            (dropout_point_cloud, self.drop_prob),
        ):
            points = augment(points, prob=chance)
        if self.scale_aware:
            points = np.concatenate([points, heights], axis=-1)
        return points
# **************** For GaitSSB ****************
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
@@ -587,4 +702,4 @@ class MSGGTransform():
def __call__(self, x):
result=x[...,self.mask,:].copy()
return result
return result