forked from HQU-gxy/CVTH3PE
feat: Implement time-weighted triangulation for enhanced 3D point reconstruction
- Added two new functions, `triangulate_one_point_from_multiple_views_linear_time_weighted` and `triangulate_points_from_multiple_views_linear_time_weighted`, which perform triangulation with time-based weighting to improve the accuracy of 3D point estimation.
- Introduced a helper that groups detections by camera while keeping only the latest detection per camera, improving tracking state management.
- Updated the `update_tracking` function to use time-weighted triangulation, so tracking states are updated more robustly from new detections.
- Refactored `TrackingState` to store historical detections in a mapping keyed by camera, improving data organization and access.
This commit is contained in:
@ -13,9 +13,9 @@ from typing import (
|
|||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@ -23,9 +23,11 @@ from beartype import beartype
|
|||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from jax import Array
|
from jax import Array
|
||||||
from jaxtyping import Array, Float, Int, jaxtyped
|
from jaxtyping import Array, Float, Int, jaxtyped
|
||||||
from pyrsistent import PVector, v
|
from pyrsistent import PVector, v, PRecord, PMap
|
||||||
|
|
||||||
from app.camera import Detection
|
from app.camera import Detection, CameraID
|
||||||
|
|
||||||
|
TrackingID: TypeAlias = int
|
||||||
|
|
||||||
|
|
||||||
class TrackingPrediction(TypedDict):
|
class TrackingPrediction(TypedDict):
|
||||||
@ -440,7 +442,7 @@ class TrackingState:
|
|||||||
The last active timestamp of the tracking
|
The last active timestamp of the tracking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
historical_detections: PVector[Detection]
|
historical_detections_by_camera: PMap[CameraID, Detection]
|
||||||
"""
|
"""
|
||||||
Historical detections of the tracking.
|
Historical detections of the tracking.
|
||||||
|
|
||||||
@ -449,13 +451,13 @@ class TrackingState:
|
|||||||
|
|
||||||
|
|
||||||
class Tracking:
|
class Tracking:
|
||||||
id: int
|
id: TrackingID
|
||||||
state: TrackingState
|
state: TrackingState
|
||||||
velocity_filter: GenericVelocityFilter
|
velocity_filter: GenericVelocityFilter
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: int,
|
id: TrackingID,
|
||||||
state: TrackingState,
|
state: TrackingState,
|
||||||
velocity_filter: Optional[GenericVelocityFilter] = None,
|
velocity_filter: Optional[GenericVelocityFilter] = None,
|
||||||
):
|
):
|
||||||
@ -512,6 +514,15 @@ class Tracking:
|
|||||||
# pylint: disable-next=unsubscriptable-object
|
# pylint: disable-next=unsubscriptable-object
|
||||||
return self.velocity_filter.predict(timestamp)["keypoints"]
|
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||||
|
|
||||||
|
def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
    """
    Update the tracking with a new 3D pose.

    Args:
        new_3d_pose: the new 3D pose, annotated as a `(J, 3)` float array
        timestamp: the timestamp forwarded to the velocity filter together
            with the pose

    Note:
        equivalent to calling `velocity_filter.update(new_3d_pose, timestamp)`;
        this method does not assign `self.state`.
    """
    self.velocity_filter.update(new_3d_pose, timestamp)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def velocity(self) -> Float[Array, "J 3"]:
|
def velocity(self) -> Float[Array, "J 3"]:
|
||||||
"""
|
"""
|
||||||
@ -537,7 +548,7 @@ class AffinityResult:
|
|||||||
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
||||||
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
||||||
|
|
||||||
def tracking_detections(
|
def tracking_association(
|
||||||
self,
|
self,
|
||||||
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
239
playground.py
239
playground.py
@ -31,6 +31,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
|
Iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
import awkward as ak
|
import awkward as ak
|
||||||
@ -45,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
||||||
from pyrsistent import pvector, v
|
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from app.camera import (
|
from app.camera import (
|
||||||
Camera,
|
Camera,
|
||||||
@ -59,6 +61,7 @@ from app.camera import (
|
|||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
from app.tracking import (
|
from app.tracking import (
|
||||||
|
TrackingID,
|
||||||
AffinityResult,
|
AffinityResult,
|
||||||
LastDifferenceVelocityFilter,
|
LastDifferenceVelocityFilter,
|
||||||
Tracking,
|
Tracking,
|
||||||
@ -508,6 +511,142 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
return vmap_triangulate(proj_matrices, points, conf)
|
return vmap_triangulate(proj_matrices, points, conf)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate one point from multiple views with time-weighted linear least squares.

    Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
    with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2

    Each view contributes two rows to a homogeneous system `A @ X = 0`
    (`P[2] * x - P[0]` and `P[2] * y - P[1]`). Every row is scaled by the
    time-decay weight, divided by its own L2 norm and multiplied by the
    square root of the clipped confidence. The solution is the last right
    singular vector of `A`.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices sequence
        points: Shape (N, 2) point coordinates sequence
        delta_t: Time differences between current time and each observation (in seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N,) confidence values; clipped to [0.0, 1.0].
            When omitted every view gets a confidence weight of 1.

    Returns:
        point_3d: Shape (3,) triangulated 3D point

    Note:
        With a single view (N == 1) the system has only two rows, so the
        reduced SVD yields two right singular vectors and the last one is
        not a null-space solution. Callers should pass at least two views.
    """
    assert len(proj_matrices) == len(points)
    assert len(delta_t) == len(points)

    N = len(proj_matrices)

    # Prepare confidence weights (sqrt so that the squared residual is
    # weighted by the confidence itself).
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=jnp.float32)
    else:
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

    # Build the unweighted coefficient rows for all views at once.
    # rows[i, 0] = P_i[2] * x_i - P_i[0], rows[i, 1] = P_i[2] * y_i - P_i[1]
    x = points[:, 0:1]  # (N, 1)
    y = points[:, 1:2]  # (N, 1)
    third_rows = proj_matrices[:, 2]  # (N, 4)
    rows = jnp.stack(
        [
            third_rows * x - proj_matrices[:, 0],
            third_rows * y - proj_matrices[:, 1],
        ],
        axis=1,
    ).astype(jnp.float32)  # (N, 2, 4)

    # Normalization factor ||c^i^T||_2 per row. An all-zero row would give a
    # zero norm and a division by zero that poisons the SVD with NaN/inf;
    # dividing by 1 instead leaves such a row at zero.
    row_norms = jnp.linalg.norm(rows, axis=-1)  # (N, 2)
    safe_norms = jnp.where(row_norms > 0, row_norms, 1.0)

    # Time-decay weight: e^(-λ_t * Δt)
    time_weights = jnp.exp(-lambda_t * delta_t)  # (N,)

    # Combined weight: time_weight / row_norm * confidence
    weights = (time_weights[:, None] / safe_norms) * confi[:, None]  # (N, 2)

    # Interleave as [view0-x, view0-y, view1-x, view1-y, ...]
    A = (rows * weights[..., None]).reshape(N * 2, 4).astype(jnp.float32)

    # Solve using SVD
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)

    # Ensure homogeneous coordinate is positive
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,
        -point_3d_homo,
        point_3d_homo,
    )

    # Convert from homogeneous to Euclidean coordinates
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Triangulate P keypoints seen from N cameras, weighting each view by its age.

    The single-point routine
    `triangulate_one_point_from_multiple_views_linear_time_weighted` is
    batched over the keypoint axis with `jax.vmap`; the projection matrices,
    the time deltas and `lambda_t` are shared by every keypoint.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices, one per camera
        points: Shape (N, P, 2) 2D coordinates of P keypoints in each of the N cameras
        delta_t: Shape (N,) time difference between the current time and each camera's timestamp (seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N, P) per-camera, per-keypoint confidences;
            all ones are used when omitted

    Returns:
        points_3d: Shape (P, 3) triangulated 3D points
    """
    num_views, num_points, _ = points.shape
    assert (
        proj_matrices.shape[0] == num_views
    ), "Number of projection matrices must match number of cameras"
    assert (
        delta_t.shape[0] == num_views
    ), "Number of time deltas must match number of cameras"

    # Fall back to uniform confidences when the caller supplies none.
    per_view_conf = (
        confidences
        if confidences is not None
        else jnp.ones((num_views, num_points), dtype=jnp.float32)
    )

    # Batch over axis 1 (the P keypoints) of `points` and the confidences;
    # everything else is broadcast unchanged to every keypoint.
    batched_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(None, 1, None, None, 1),
        out_axes=0,  # keypoints become the leading axis of the result
    )
    return batched_triangulate(
        proj_matrices,
        points,
        delta_t,
        lambda_t,
        per_view_conf,
    )
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
|
||||||
@ -528,6 +667,21 @@ def triangle_from_cluster(
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
def group_by_cluster_by_camera(
    cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
    """
    Group the detections by camera, preserving the latest detection for each camera.

    Args:
        cluster: detections, possibly several per camera

    Returns:
        an immutable map from camera id to the detection with the greatest
        `timestamp` for that camera; on equal timestamps the detection that
        appears first in `cluster` is kept. Empty when `cluster` is empty.
    """
    latest: dict[CameraID, Detection] = {}
    for detection in cluster:
        camera_id = detection.camera.id
        current = latest.get(camera_id)
        # The first detection of a camera must be stored too; previously only
        # cameras already present were updated, so the map stayed empty.
        if current is None or detection.timestamp > current.timestamp:
            latest[camera_id] = detection
    return pmap(latest)
|
||||||
|
|
||||||
|
|
||||||
class GlobalTrackingState:
|
class GlobalTrackingState:
|
||||||
_last_id: int
|
_last_id: int
|
||||||
_trackings: dict[int, Tracking]
|
_trackings: dict[int, Tracking]
|
||||||
@ -546,12 +700,16 @@ class GlobalTrackingState:
|
|||||||
return shallow_copy(self._trackings)
|
return shallow_copy(self._trackings)
|
||||||
|
|
||||||
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||||
|
if len(cluster) < 2:
|
||||||
|
raise ValueError(
|
||||||
|
"cluster must contain at least 2 detections to form a tracking"
|
||||||
|
)
|
||||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||||
next_id = self._last_id + 1
|
next_id = self._last_id + 1
|
||||||
tracking_state = TrackingState(
|
tracking_state = TrackingState(
|
||||||
keypoints=kps_3d,
|
keypoints=kps_3d,
|
||||||
last_active_timestamp=latest_timestamp,
|
last_active_timestamp=latest_timestamp,
|
||||||
historical_detections=v(*cluster),
|
historical_detections_by_camera=group_by_cluster_by_camera(cluster),
|
||||||
)
|
)
|
||||||
tracking = Tracking(
|
tracking = Tracking(
|
||||||
id=next_id,
|
id=next_id,
|
||||||
@ -679,9 +837,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
Array of perpendicular distances for each keypoint
|
Array of perpendicular distances for each keypoint
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
|
predicted_pose = tracking.predict(delta_t)
|
||||||
# avoid division-by-zero / exploding affinities.
|
|
||||||
predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN))
|
|
||||||
|
|
||||||
# Back-project the 2D points to 3D space
|
# Back-project the 2D points to 3D space
|
||||||
# intersection with z=0 plane
|
# intersection with z=0 plane
|
||||||
@ -1039,6 +1195,73 @@ display(affinities)
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
def affinity_result_by_tracking(
    results: Iterable[AffinityResult],
) -> dict[TrackingID, list[Detection]]:
    """
    Collect the detections associated with each tracking, keyed by tracking ID.

    Every `(affinity, tracking, detection)` triple yielded by
    `AffinityResult.tracking_association()` contributes its detection to the
    list stored under `tracking.id`; the affinity value itself is ignored.
    """
    grouped: dict[TrackingID, list[Detection]] = defaultdict(list)
    associations = (
        (tracking.id, detection)
        for result in results
        for _, tracking, detection in result.tracking_association()
    )
    for tracking_id, detection in associations:
        grouped[tracking_id].append(detection)
    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def update_tracking(
    tracking: Tracking,
    detections: Sequence[Detection],
    max_delta_t: timedelta = timedelta(milliseconds=100),
    lambda_t: float = 10.0,
) -> None:
    """
    Update the tracking with a new set of detections.

    The new detections replace the stored detection of their camera, stored
    detections older than `max_delta_t` relative to the newest detection are
    dropped, and the remaining per-camera detections are triangulated with
    time-decay weights into the new 3D keypoints.

    Args:
        tracking: the tracking to update
        detections: the detections to update the tracking with; must not be empty
        max_delta_t: the maximum age (relative to the latest detection) a
            per-camera detection may have to take part in the triangulation
        lambda_t: the time penalty rate passed to the time-weighted triangulation

    Raises:
        ValueError: if `detections` is empty

    Note:
        the function would mutate the tracking object
    """
    if not detections:
        raise ValueError("detections must not be empty")

    latest_timestamp = max(d.timestamp for d in detections)

    # Merge: a new detection supersedes the stored one from the same camera.
    merged = thaw(tracking.state.historical_detections_by_camera)
    for detection in detections:
        merged[detection.camera.id] = detection

    # Drop stale views. Build a new dict instead of deleting while iterating
    # (which raises RuntimeError), and measure age as latest - observed so
    # that older detections have a positive age.
    recent = {
        camera_id: detection
        for camera_id, detection in merged.items()
        if latest_timestamp - detection.timestamp <= max_delta_t
    }
    new_detections = freeze(recent)
    new_detections_list = list(new_detections.values())
    # NOTE(review): if fewer than two views survive the pruning, the linear
    # triangulation below is underdetermined - confirm how callers want this
    # case handled.

    project_matrices = jnp.stack(
        [detection.camera.params.projection_matrix for detection in new_detections_list]
    )
    # Age of each observation in seconds relative to the newest one (>= 0),
    # matching the `t - t_i` term of the time-decay weight.
    delta_t = jnp.array(
        [
            (latest_timestamp - detection.timestamp).total_seconds()
            for detection in new_detections_list
        ]
    )
    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
    conf = jnp.stack([detection.confidences for detection in new_detections_list])
    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
        project_matrices, kps, delta_t, lambda_t, conf
    )
    new_state = TrackingState(
        keypoints=kps_3d,
        last_active_timestamp=latest_timestamp,
        historical_detections_by_camera=new_detections,
    )
    # Feed the velocity filter first, then publish the new state.
    tracking.update(kps_3d, latest_timestamp)
    tracking.state = new_state
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Group the detections matched in `affinities` by the ID of the tracking they
# were associated with, then update each of those trackings.
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
for tracking_id, detections in affinity_results_by_tracking.items():
    # `update_tracking` returns None and mutates the Tracking object in place.
    update_tracking(global_tracking_state.trackings[tracking_id], detections)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|||||||
Reference in New Issue
Block a user