fix: various
- Added the `LeastMeanSquareVelocityFilter` to improve tracking velocity estimation using historical 3D poses.
- Updated the `triangulate_one_point_from_multiple_views_linear` and `triangulate_points_from_multiple_views_linear` functions to enhance documentation and ensure proper handling of input parameters.
- Refined the logic in triangulation functions to ensure correct handling of homogeneous coordinates.
- Improved error handling in the `LastDifferenceVelocityFilter` to assert non-negative time deltas, enhancing robustness.
This commit is contained in:
@ -114,6 +114,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
|
|
||||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
|
||||||
if self._last_velocity is None:
|
if self._last_velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=None,
|
velocity=None,
|
||||||
@ -127,6 +128,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
|
|
||||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
|
||||||
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
|
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
|
||||||
self._last_keypoints = keypoints
|
self._last_keypoints = keypoints
|
||||||
self._last_timestamp = timestamp
|
self._last_timestamp = timestamp
|
||||||
@ -160,8 +162,20 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
historical_timestamps: Sequence[datetime],
|
historical_timestamps: Sequence[datetime],
|
||||||
max_samples: int = 10,
|
max_samples: int = 10,
|
||||||
):
|
):
|
||||||
assert len(historical_3d_poses) == len(historical_timestamps)
|
"""
|
||||||
|
Args:
|
||||||
|
historical_3d_poses: sequence of 3D poses, at least one element is required
|
||||||
|
historical_timestamps: sequence of timestamps, whose length is the same as `historical_3d_poses`
|
||||||
|
max_samples: maximum number of samples to keep
|
||||||
|
"""
|
||||||
|
assert (N := len(historical_3d_poses)) == len(
|
||||||
|
historical_timestamps
|
||||||
|
), f"the length of `historical_3d_poses` and `historical_timestamps` must be the same; got {N} and {len(historical_timestamps)}"
|
||||||
|
if N < 1:
|
||||||
|
raise ValueError("at least one historical 3D pose is required")
|
||||||
|
|
||||||
temp = zip(historical_3d_poses, historical_timestamps)
|
temp = zip(historical_3d_poses, historical_timestamps)
|
||||||
|
# sorted by timestamp
|
||||||
temp_sorted = sorted(temp, key=lambda x: x[1])
|
temp_sorted = sorted(temp, key=lambda x: x[1])
|
||||||
self._historical_3d_poses = deque(
|
self._historical_3d_poses = deque(
|
||||||
map(lambda x: x[0], temp_sorted), maxlen=max_samples
|
map(lambda x: x[0], temp_sorted), maxlen=max_samples
|
||||||
@ -187,6 +201,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
latest_timestamp = self._historical_timestamps[-1]
|
latest_timestamp = self._historical_timestamps[-1]
|
||||||
|
|
||||||
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
||||||
|
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
|
||||||
|
|
||||||
if self._velocity is None:
|
if self._velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
@ -228,7 +243,6 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
keypoints_reshaped = keypoints.reshape(n_samples, -1)
|
keypoints_reshaped = keypoints.reshape(n_samples, -1)
|
||||||
|
|
||||||
# Use JAX's lstsq to solve the least squares problem
|
# Use JAX's lstsq to solve the least squares problem
|
||||||
# This is more numerically stable than manually computing pseudoinverse
|
|
||||||
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
|
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
|
||||||
|
|
||||||
# Coefficients shape is [2, J*3]
|
# Coefficients shape is [2, J*3]
|
||||||
|
|||||||
@ -64,6 +64,7 @@ from app.tracking import (
|
|||||||
TrackingID,
|
TrackingID,
|
||||||
AffinityResult,
|
AffinityResult,
|
||||||
LastDifferenceVelocityFilter,
|
LastDifferenceVelocityFilter,
|
||||||
|
LeastMeanSquareVelocityFilter,
|
||||||
Tracking,
|
Tracking,
|
||||||
TrackingState,
|
TrackingState,
|
||||||
)
|
)
|
||||||
@ -438,12 +439,12 @@ def triangulate_one_point_from_multiple_views_linear(
|
|||||||
) -> Float[Array, "3"]:
|
) -> Float[Array, "3"]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
proj_matrices: 形状为(N, 3, 4)的投影矩阵序列
|
proj_matrices: (N, 3, 4) projection matrices
|
||||||
points: 形状为(N, 2)的点坐标序列
|
points: (N, 2) image-coordinates per view
|
||||||
confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0]
|
confidences: (N,) optional per-view confidences in [0,1]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
point_3d: 形状为(3,)的三角测量得到的3D点
|
(3,) 3D point
|
||||||
"""
|
"""
|
||||||
assert len(proj_matrices) == len(points)
|
assert len(proj_matrices) == len(points)
|
||||||
|
|
||||||
@ -469,7 +470,7 @@ def triangulate_one_point_from_multiple_views_linear(
|
|||||||
|
|
||||||
# replace the Python `if` with a jnp.where
|
# replace the Python `if` with a jnp.where
|
||||||
point_3d_homo = jnp.where(
|
point_3d_homo = jnp.where(
|
||||||
point_3d_homo[3] < 0, # predicate (scalar bool tracer)
|
point_3d_homo[3] <= 0, # predicate (scalar bool tracer)
|
||||||
-point_3d_homo, # if True
|
-point_3d_homo, # if True
|
||||||
point_3d_homo, # if False
|
point_3d_homo, # if False
|
||||||
)
|
)
|
||||||
@ -493,14 +494,14 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
confidences: (N, P, 1) optional per-view confidences in [0,1]
|
confidences: (N, P, 1) optional per-view confidences in [0,1]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(P, 3) 3D point for each of the P tracks
|
(P, 3) 3D point for each of the P
|
||||||
"""
|
"""
|
||||||
N, P, _ = points.shape
|
N, P, _ = points.shape
|
||||||
assert proj_matrices.shape[0] == N
|
assert proj_matrices.shape[0] == N
|
||||||
if confidences is None:
|
if confidences is None:
|
||||||
conf = jnp.ones((N, P), dtype=jnp.float32)
|
conf = jnp.ones((N, P), dtype=jnp.float32)
|
||||||
else:
|
else:
|
||||||
conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
|
conf = confidences
|
||||||
|
|
||||||
# vectorize your one-point routine over P
|
# vectorize your one-point routine over P
|
||||||
vmap_triangulate = jax.vmap(
|
vmap_triangulate = jax.vmap(
|
||||||
@ -578,7 +579,7 @@ def triangulate_one_point_from_multiple_views_linear_time_weighted(
|
|||||||
|
|
||||||
# Ensure homogeneous coordinate is positive
|
# Ensure homogeneous coordinate is positive
|
||||||
point_3d_homo = jnp.where(
|
point_3d_homo = jnp.where(
|
||||||
point_3d_homo[3] < 0,
|
point_3d_homo[3] <= 0,
|
||||||
-point_3d_homo,
|
-point_3d_homo,
|
||||||
point_3d_homo,
|
point_3d_homo,
|
||||||
)
|
)
|
||||||
@ -679,6 +680,8 @@ def group_by_cluster_by_camera(
|
|||||||
eld = r[el.camera.id]
|
eld = r[el.camera.id]
|
||||||
preserved = max([eld, el], key=lambda x: x.timestamp)
|
preserved = max([eld, el], key=lambda x: x.timestamp)
|
||||||
r[el.camera.id] = preserved
|
r[el.camera.id] = preserved
|
||||||
|
else:
|
||||||
|
r[el.camera.id] = el
|
||||||
return pmap(r)
|
return pmap(r)
|
||||||
|
|
||||||
|
|
||||||
@ -714,7 +717,11 @@ class GlobalTrackingState:
|
|||||||
tracking = Tracking(
|
tracking = Tracking(
|
||||||
id=next_id,
|
id=next_id,
|
||||||
state=tracking_state,
|
state=tracking_state,
|
||||||
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
# velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||||
|
velocity_filter=LeastMeanSquareVelocityFilter(
|
||||||
|
historical_3d_poses=[kps_3d],
|
||||||
|
historical_timestamps=[latest_timestamp],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self._trackings[next_id] = tracking
|
self._trackings[next_id] = tracking
|
||||||
self._last_id = next_id
|
self._last_id = next_id
|
||||||
|
|||||||
Reference in New Issue
Block a user