feat: Implement time-weighted triangulation for enhanced 3D point reconstruction

- Added two new functions, `triangulate_one_point_from_multiple_views_linear_time_weighted` and `triangulate_points_from_multiple_views_linear_time_weighted`, which perform triangulation with time-based weighting to improve the accuracy of 3D point estimation (see the usage sketch after this list).
- Introduced `group_by_cluster_by_camera` to group detections by camera while preserving the latest detection per camera, improving tracking state management.
- Updated the `update_tracking` function to incorporate time-weighted triangulation, allowing more robust updates to tracking states from new detections.
- Refactored `TrackingState` to store historical detections in a per-camera mapping, improving data organization and access.
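
A minimal usage sketch of the new vectorized entry point, with synthetic shapes (N = 3 cameras, P = 17 keypoints); everything here except the function itself is made up for illustration:

```python
import jax
import jax.numpy as jnp

# Synthetic inputs; real projection matrices come from calibrated cameras.
proj_matrices = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 4))
points_2d = jax.random.normal(jax.random.PRNGKey(1), (3, 17, 2))
delta_t = jnp.array([0.0, 0.033, 0.066])  # seconds since each camera's observation

kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices, points_2d, delta_t, lambda_t=10.0
)
assert kps_3d.shape == (17, 3)
```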
2025-05-03 17:17:47 +08:00
parent 20b2cf59f2
commit 1f8d70803f
2 changed files with 249 additions and 15 deletions

File 1 of 2:

@@ -13,9 +13,9 @@ from typing import (
     TypeAlias,
     TypedDict,
     TypeVar,
+    Union,
     cast,
     overload,
-    Union,
 )

 import jax.numpy as jnp
@@ -23,9 +23,11 @@ from beartype import beartype
 from beartype.typing import Mapping, Sequence
 from jax import Array
 from jaxtyping import Array, Float, Int, jaxtyped
-from pyrsistent import PVector, v
+from pyrsistent import PVector, v, PRecord, PMap

-from app.camera import Detection
+from app.camera import Detection, CameraID
+
+TrackingID: TypeAlias = int


 class TrackingPrediction(TypedDict):
@@ -440,7 +442,7 @@ class TrackingState:
     The last active timestamp of the tracking
     """

-    historical_detections: PVector[Detection]
+    historical_detections_by_camera: PMap[CameraID, Detection]
     """
     Historical detections of the tracking.
@@ -449,13 +451,13 @@
 class Tracking:
-    id: int
+    id: TrackingID
     state: TrackingState
     velocity_filter: GenericVelocityFilter

     def __init__(
         self,
-        id: int,
+        id: TrackingID,
         state: TrackingState,
         velocity_filter: Optional[GenericVelocityFilter] = None,
     ):
@@ -512,6 +514,15 @@ class Tracking:
         # pylint: disable-next=unsubscriptable-object
         return self.velocity_filter.predict(timestamp)["keypoints"]

+    def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
+        """
+        Update the tracking with a new 3D pose.
+
+        Note:
+            equivalent to calling `velocity_filter.update(new_3d_pose, timestamp)`
+        """
+        self.velocity_filter.update(new_3d_pose, timestamp)
+
     @property
     def velocity(self) -> Float[Array, "J 3"]:
         """
@@ -537,7 +548,7 @@ class AffinityResult:
     indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
     indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

-    def tracking_detections(
+    def tracking_association(
         self,
     ) -> Generator[tuple[float, Tracking, Detection], None, None]:
         """

File 2 of 2:

@@ -31,6 +31,7 @@ from typing import (
     TypeVar,
     cast,
     overload,
+    Iterable,
 )

 import awkward as ak
@@ -45,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from optax.assignment import hungarian_algorithm as linear_sum_assignment
-from pyrsistent import pvector, v
+from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
+from collections import defaultdict

 from app.camera import (
     Camera,
@@ -59,6 +61,7 @@ from app.camera import (
 )
 from app.solver._old import GLPKSolver
 from app.tracking import (
+    TrackingID,
     AffinityResult,
     LastDifferenceVelocityFilter,
     Tracking,
@@ -508,6 +511,142 @@ def triangulate_points_from_multiple_views_linear(
     return vmap_triangulate(proj_matrices, points, conf)


+# %%
+@jaxtyped(typechecker=beartype)
+def triangulate_one_point_from_multiple_views_linear_time_weighted(
+    proj_matrices: Float[Array, "N 3 4"],
+    points: Num[Array, "N 2"],
+    delta_t: Num[Array, "N"],
+    lambda_t: float = 10.0,
+    confidences: Optional[Float[Array, "N"]] = None,
+) -> Float[Array, "3"]:
+    """
+    Triangulate one point from multiple views with time-weighted linear least squares.
+
+    Implements the incremental reconstruction method from "Cross-View Tracking for
+    Multi-Human 3D Pose" with weighting formula: w_i = exp(-λ_t(t - t_i)) / ||c^i^T||_2
+
+    Args:
+        proj_matrices: Shape (N, 3, 4) projection matrices sequence
+        points: Shape (N, 2) point coordinates sequence
+        delta_t: Time differences between current time and each observation (in seconds)
+        lambda_t: Time penalty rate (higher values decrease influence of older observations)
+        confidences: Shape (N,) confidence values in range [0.0, 1.0]
+
+    Returns:
+        point_3d: Shape (3,) triangulated 3D point
+    """
+    assert len(proj_matrices) == len(points)
+    assert len(delta_t) == len(points)
+    N = len(proj_matrices)
+
+    # Prepare confidence weights
+    confi: Float[Array, "N"]
+    if confidences is None:
+        confi = jnp.ones(N, dtype=np.float32)
+    else:
+        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
+
+    A = jnp.zeros((N * 2, 4), dtype=np.float32)
+    # First build the coefficient matrix without weights
+    for i in range(N):
+        x, y = points[i]
+        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
+        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
+
+    # Then apply the time-based and confidence weights
+    for i in range(N):
+        # Calculate time-decay weight: e^(-λ_t * Δt)
+        time_weight = jnp.exp(-lambda_t * delta_t[i])
+        # Calculate normalization factor: ||c^i^T||_2
+        row_norm_1 = jnp.linalg.norm(A[2 * i])
+        row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
+        # Apply combined weight: time_weight / row_norm * confidence
+        w1 = (time_weight / row_norm_1) * confi[i]
+        w2 = (time_weight / row_norm_2) * confi[i]
+        A = A.at[2 * i].mul(w1)
+        A = A.at[2 * i + 1].mul(w2)
+
+    # Solve using SVD
+    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
+    point_3d_homo = vh[-1]  # shape (4,)
+
+    # Ensure homogeneous coordinate is positive
+    point_3d_homo = jnp.where(
+        point_3d_homo[3] < 0,
+        -point_3d_homo,
+        point_3d_homo,
+    )
+
+    # Convert from homogeneous to Euclidean coordinates
+    point_3d = point_3d_homo[:3] / point_3d_homo[3]
+    return point_3d
+
+
+@jaxtyped(typechecker=beartype)
+def triangulate_points_from_multiple_views_linear_time_weighted(
+    proj_matrices: Float[Array, "N 3 4"],
+    points: Num[Array, "N P 2"],
+    delta_t: Num[Array, "N"],
+    lambda_t: float = 10.0,
+    confidences: Optional[Float[Array, "N P"]] = None,
+) -> Float[Array, "P 3"]:
+    """
+    Vectorized version that triangulates P points from N camera views with time-weighting.
+
+    This function uses JAX's vmap to efficiently triangulate multiple points in parallel.
+
+    Args:
+        proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
+        points: Shape (N, P, 2) 2D points for P keypoints across N cameras
+        delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
+        lambda_t: Time penalty rate (higher values decrease influence of older observations)
+        confidences: Shape (N, P) confidence values for each point in each camera
+
+    Returns:
+        points_3d: Shape (P, 3) triangulated 3D points
+    """
+    N, P, _ = points.shape
+    assert (
+        proj_matrices.shape[0] == N
+    ), "Number of projection matrices must match number of cameras"
+    assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
+
+    if confidences is None:
+        # Create uniform confidences if none provided
+        conf = jnp.ones((N, P), dtype=jnp.float32)
+    else:
+        conf = confidences
+
+    # Define the vmapped version of the single-point function
+    # We map over the second dimension (P points) of the input arrays
+    vmap_triangulate = jax.vmap(
+        triangulate_one_point_from_multiple_views_linear_time_weighted,
+        in_axes=(
+            None,
+            1,
+            None,
+            None,
+            1,
+        ),  # proj_matrices and delta_t static, map over points
+        out_axes=0,  # Output has first dimension corresponding to points
+    )
+
+    # For each point p, extract the 2D coordinates from all cameras and triangulate
+    return vmap_triangulate(
+        proj_matrices,  # (N, 3, 4) - static across points
+        points,  # (N, P, 2) - map over dim 1 (P)
+        delta_t,  # (N,) - static across points
+        lambda_t,  # scalar - static
+        conf,  # (N, P) - map over dim 1 (P)
+    )
+
+
 # %%
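
As a quick numeric sanity check of the decay term above (a sketch, not part of the diff): with the default `lambda_t = 10.0`, an observation 50 ms old keeps about 61% of its weight, while one 200 ms old keeps about 14%.

```python
import jax.numpy as jnp

# Time-decay term w = exp(-lambda_t * dt) at the default lambda_t = 10.0
lambda_t = 10.0
dt = jnp.array([0.0, 0.05, 0.2])  # seconds since each observation
print(jnp.exp(-lambda_t * dt))  # -> [1.0, ~0.61, ~0.14]
```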
@@ -528,6 +667,21 @@ def triangle_from_cluster(


 # %%
+def group_by_cluster_by_camera(
+    cluster: Sequence[Detection],
+) -> PMap[CameraID, Detection]:
+    """
+    Group the detections by camera, preserving the latest detection for each camera.
+    """
+    r: dict[CameraID, Detection] = {}
+    for el in cluster:
+        if el.camera.id in r:
+            # Keep whichever detection from this camera is more recent
+            r[el.camera.id] = max(r[el.camera.id], el, key=lambda x: x.timestamp)
+        else:
+            # A camera seen for the first time is recorded directly
+            r[el.camera.id] = el
+    return pmap(r)
+
+
 class GlobalTrackingState:
     _last_id: int
     _trackings: dict[int, Tracking]
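
To illustrate the keep-the-latest semantics of `group_by_cluster_by_camera`, here is a self-contained sketch of the same grouping logic with stand-in types (the real `Detection` and `CameraID` live in `app.camera`; `FakeCamera` and `FakeDetection` are purely illustrative):

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class FakeCamera:
    id: int


@dataclass
class FakeDetection:
    camera: FakeCamera
    timestamp: datetime


t0 = datetime(2025, 5, 3, 17, 0, 0)
cam = FakeCamera(id=1)
older = FakeDetection(camera=cam, timestamp=t0)
newer = FakeDetection(camera=cam, timestamp=t0 + timedelta(milliseconds=40))

# Same rule as group_by_cluster_by_camera: the later timestamp wins per camera
grouped: dict[int, FakeDetection] = {}
for el in [older, newer]:
    if el.camera.id in grouped:
        grouped[el.camera.id] = max(grouped[el.camera.id], el, key=lambda d: d.timestamp)
    else:
        grouped[el.camera.id] = el

assert grouped[1] is newer
```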
@@ -546,12 +700,16 @@ class GlobalTrackingState:
         return shallow_copy(self._trackings)

     def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
+        if len(cluster) < 2:
+            raise ValueError(
+                "cluster must contain at least 2 detections to form a tracking"
+            )
         kps_3d, latest_timestamp = triangle_from_cluster(cluster)
         next_id = self._last_id + 1
         tracking_state = TrackingState(
             keypoints=kps_3d,
             last_active_timestamp=latest_timestamp,
-            historical_detections=v(*cluster),
+            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
         )
         tracking = Tracking(
             id=next_id,
@@ -679,9 +837,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
-    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
-    # avoid division-by-zero / exploding affinities.
-    predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN))
+    predicted_pose = tracking.predict(delta_t)

     # Back-project the 2D points to 3D space
     # intersection with z=0 plane
@@ -1039,6 +1195,73 @@ display(affinities)

 # %%
-def update_tracking(tracking: Tracking, detection: Detection):
-    delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
-    raise NotImplementedError
+def affinity_result_by_tracking(
+    results: Iterable[AffinityResult],
+) -> dict[TrackingID, list[Detection]]:
+    """
+    Group affinity results by tracking ID.
+    """
+    res: dict[TrackingID, list[Detection]] = defaultdict(list)
+    for affinity_result in results:
+        for _affinity, t, d in affinity_result.tracking_association():
+            res[t.id].append(d)
+    return res
+
+
+def update_tracking(
+    tracking: Tracking,
+    detections: Sequence[Detection],
+    max_delta_t: timedelta = timedelta(milliseconds=100),
+    lambda_t: float = 10.0,
+) -> None:
+    """
+    Update the tracking with a new set of detections.
+
+    Args:
+        tracking: the tracking to update
+        detections: the detections to update the tracking with
+        max_delta_t: the maximum age of a retained detection relative to the latest detection
+        lambda_t: the time penalty rate for the time-weighted triangulation
+
+    Note:
+        the function mutates the tracking object
+    """
+    latest_timestamp = max(d.timestamp for d in detections)
+
+    # Merge the new detections into the per-camera history
+    d = thaw(tracking.state.historical_detections_by_camera)
+    for detection in detections:
+        d[detection.camera.id] = detection
+    # Drop detections that are too stale relative to the latest one;
+    # iterate over a copy so entries can be deleted while looping
+    for camera_id, detection in list(d.items()):
+        if latest_timestamp - detection.timestamp > max_delta_t:
+            del d[camera_id]
+    new_detections = freeze(d)
+    new_detections_list = list(new_detections.values())
+
+    project_matrices = jnp.stack(
+        [detection.camera.params.projection_matrix for detection in new_detections_list]
+    )
+    # Age of each retained detection relative to the latest timestamp (seconds);
+    # older observations receive exponentially smaller weight
+    delta_t = jnp.array(
+        [
+            latest_timestamp.timestamp() - detection.timestamp.timestamp()
+            for detection in new_detections_list
+        ]
+    )
+    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
+    conf = jnp.stack([detection.confidences for detection in new_detections_list])
+    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
+        project_matrices, kps, delta_t, lambda_t, conf
+    )
+    new_state = TrackingState(
+        keypoints=kps_3d,
+        last_active_timestamp=latest_timestamp,
+        historical_detections_by_camera=new_detections,
+    )
+    tracking.update(kps_3d, latest_timestamp)
+    tracking.state = new_state
+
+
+# %%
+affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
+for tracking_id, detections in affinity_results_by_tracking.items():
+    update_tracking(global_tracking_state.trackings[tracking_id], detections)
+
 # %%
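
For intuition about the retention window in `update_tracking` (a sketch with plain timedeltas, independent of the `Detection` type): with the default `max_delta_t` of 100 ms, any per-camera detection more than 100 ms older than the newest detection is dropped before triangulation, and the survivors are down-weighted by `exp(-lambda_t * age)`.

```python
from datetime import timedelta

# Hypothetical per-camera detection ages relative to the newest detection
ages = {
    "cam_a": timedelta(milliseconds=0),
    "cam_b": timedelta(milliseconds=40),
    "cam_c": timedelta(milliseconds=120),  # exceeds the 100 ms window, dropped
}
max_delta_t = timedelta(milliseconds=100)
kept = {cam: age for cam, age in ages.items() if age <= max_delta_t}
assert set(kept) == {"cam_a", "cam_b"}
```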