From 6cd13064f3d842ffc3854a6d7efdbf3959094b9e Mon Sep 17 00:00:00 2001
From: crosstyan
Date: Wed, 18 Jun 2025 10:35:23 +0800
Subject: [PATCH] fix: various

- Added the `LeastMeanSquareVelocityFilter` to improve tracking velocity estimation using historical 3D poses.
- Updated the `triangulate_one_point_from_multiple_views_linear` and `triangulate_points_from_multiple_views_linear` functions to enhance documentation and ensure proper handling of input parameters.
- Refined the logic in triangulation functions to ensure correct handling of homogeneous coordinates.
- Improved error handling in the `LastDifferenceVelocityFilter` to assert non-negative time deltas, enhancing robustness.
---
 app/tracking/__init__.py | 18 ++++++++++++++++--
 playground.py            | 25 ++++++++++++++++---------
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py
index b56c2a3..3ea1550 100644
--- a/app/tracking/__init__.py
+++ b/app/tracking/__init__.py
@@ -114,6 +114,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
 
     def predict(self, timestamp: datetime) -> TrackingPrediction:
         delta_t_s = (timestamp - self._last_timestamp).total_seconds()
+        assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
         if self._last_velocity is None:
             return TrackingPrediction(
                 velocity=None,
@@ -127,6 +128,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
 
     def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
         delta_t_s = (timestamp - self._last_timestamp).total_seconds()
+        assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
         self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
         self._last_keypoints = keypoints
         self._last_timestamp = timestamp
@@ -160,8 +162,20 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
         historical_timestamps: Sequence[datetime],
         max_samples: int = 10,
     ):
-        assert len(historical_3d_poses) == len(historical_timestamps)
+        """
+        Args:
+            historical_3d_poses: sequence of 3D poses, at least one element is required
+            historical_timestamps: sequence of timestamps, whose length is the same as `historical_3d_poses`
+            max_samples: maximum number of samples to keep
+        """
+        assert (N := len(historical_3d_poses)) == len(
+            historical_timestamps
+        ), f"the length of `historical_3d_poses` and `historical_timestamps` must be the same; got {N} and {len(historical_timestamps)}"
+        if N < 1:
+            raise ValueError("at least one historical 3D pose is required")
+
         temp = zip(historical_3d_poses, historical_timestamps)
+        # sorted by timestamp
         temp_sorted = sorted(temp, key=lambda x: x[1])
         self._historical_3d_poses = deque(
             map(lambda x: x[0], temp_sorted), maxlen=max_samples
@@ -187,6 +201,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
 
         latest_timestamp = self._historical_timestamps[-1]
         delta_t_s = (timestamp - latest_timestamp).total_seconds()
+        assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
 
         if self._velocity is None:
             return TrackingPrediction(
@@ -228,7 +243,6 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
         keypoints_reshaped = keypoints.reshape(n_samples, -1)
 
         # Use JAX's lstsq to solve the least squares problem
-        # This is more numerically stable than manually computing pseudoinverse
         coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
 
         # Coefficients shape is [2, J*3]
diff --git a/playground.py b/playground.py
index e6ef055..b58424b 100644
--- a/playground.py
+++ b/playground.py
@@ -64,6 +64,7 @@ from app.tracking import (
     TrackingID,
     AffinityResult,
     LastDifferenceVelocityFilter,
+    LeastMeanSquareVelocityFilter,
     Tracking,
     TrackingState,
 )
@@ -438,12 +439,12 @@ def triangulate_one_point_from_multiple_views_linear(
 ) -> Float[Array, "3"]:
     """
     Args:
-        proj_matrices: 形状为(N, 3, 4)的投影矩阵序列
-        points: 形状为(N, 2)的点坐标序列
-        confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0]
+        proj_matrices: (N, 3, 4) projection matrices
+        points: (N, 2) image coordinates per view
+        confidences: (N,) optional per-view confidences in [0,1]
 
     Returns:
-        point_3d: 形状为(3,)的三角测量得到的3D点
+        (3,) triangulated 3D point
     """
     assert len(proj_matrices) == len(points)
 
@@ -469,7 +470,7 @@ def triangulate_one_point_from_multiple_views_linear(
 
     # replace the Python `if` with a jnp.where
     point_3d_homo = jnp.where(
-        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
+        point_3d_homo[3] <= 0,  # predicate (scalar bool tracer)
         -point_3d_homo,  # if True
         point_3d_homo,  # if False
     )
@@ -493,14 +494,14 @@ def triangulate_points_from_multiple_views_linear(
         confidences: (N, P, 1) optional per-view confidences in [0,1]
 
     Returns:
-        (P, 3) 3D point for each of the P tracks
+        (P, 3) 3D point for each of the P points
     """
     N, P, _ = points.shape
     assert proj_matrices.shape[0] == N
     if confidences is None:
         conf = jnp.ones((N, P), dtype=jnp.float32)
     else:
-        conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
+        conf = confidences
 
     # vectorize your one-point routine over P
     vmap_triangulate = jax.vmap(
@@ -578,7 +579,7 @@ def triangulate_one_point_from_multiple_views_linear_time_weighted(
 
     # Ensure homogeneous coordinate is positive
     point_3d_homo = jnp.where(
-        point_3d_homo[3] < 0,
+        point_3d_homo[3] <= 0,
         -point_3d_homo,
         point_3d_homo,
     )
@@ -679,6 +680,8 @@ def group_by_cluster_by_camera(
             eld = r[el.camera.id]
             preserved = max([eld, el], key=lambda x: x.timestamp)
             r[el.camera.id] = preserved
+        else:
+            r[el.camera.id] = el
 
     return pmap(r)
 
@@ -714,7 +717,11 @@ class GlobalTrackingState:
             tracking = Tracking(
                 id=next_id,
                 state=tracking_state,
-                velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
+                # velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
+                velocity_filter=LeastMeanSquareVelocityFilter(
+                    historical_3d_poses=[kps_3d],
+                    historical_timestamps=[latest_timestamp],
+                ),
             )
             self._trackings[next_id] = tracking
             self._last_id = next_id
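Note for reviewers: the new `LeastMeanSquareVelocityFilter` estimates velocity by fitting a degree-1 model over the buffered samples with `jnp.linalg.lstsq` (the `[2, J*3]` coefficient layout is visible in the hunk above). The standalone sketch below illustrates that fit outside the repo; the helper name `fit_velocity`, the `(S, J, 3)` shapes, and the example data are illustrative assumptions, not code from this patch.

```python
import jax.numpy as jnp


def fit_velocity(times_s: jnp.ndarray, poses: jnp.ndarray) -> jnp.ndarray:
    """Least-squares velocity from S timestamped poses.

    times_s: (S,) seconds relative to the oldest sample; poses: (S, J, 3).
    """
    S = poses.shape[0]
    X = jnp.stack([jnp.ones(S), times_s], axis=1)        # (S, 2) design matrix [1, t]
    Y = poses.reshape(S, -1)                              # (S, J*3) flattened keypoints
    coeffs, _, _, _ = jnp.linalg.lstsq(X, Y, rcond=None)  # coeffs: (2, J*3)
    return coeffs[1].reshape(-1, 3)                       # slope row = velocity, (J, 3)


# Two joints translating at 1 m/s along x; the fitted slope recovers that velocity.
times = jnp.array([0.0, 0.5, 1.0])
poses = jnp.stack([t * jnp.array([1.0, 0.0, 0.0]) + jnp.zeros((2, 3)) for t in times])
print(fit_velocity(times, poses))  # ~ [[1. 0. 0.] [1. 0. 0.]]
```

Because the slope is fit over several samples, the estimate should be less sensitive to per-frame jitter than the single backward difference used by `LastDifferenceVelocityFilter`.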