diff --git a/app/camera/__init__.py b/app/camera/__init__.py
index 9aa25f5..b583566 100644
--- a/app/camera/__init__.py
+++ b/app/camera/__init__.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, TypeAlias, TypedDict, Optional
+from typing import Any, TypeAlias, TypedDict, Optional, Sequence
 
 from beartype import beartype
 import jax
@@ -463,7 +463,7 @@ class Detection:
 
 
 def classify_by_camera(
-    detections: list[Detection],
+    detections: Sequence[Detection],
 ) -> OrderedDict[CameraID, list[Detection]]:
     """
     Classify detections by camera
@@ -677,7 +677,7 @@ def compute_affinity_epipolar_constraint_with_pairs(
 
 
 def calculate_affinity_matrix_by_epipolar_constraint(
-    detections: list[Detection] | dict[CameraID, list[Detection]],
+    detections: Sequence[Detection] | dict[CameraID, Sequence[Detection]],
     alpha_2d: float,
 ) -> tuple[list[Detection], Num[Array, "N N"]]:
     """
diff --git a/playground.py b/playground.py
index 5bfc764..33fa20d 100644
--- a/playground.py
+++ b/playground.py
@@ -178,7 +178,7 @@ def preprocess_keypoint_dataset(
 DetectionGenerator: TypeAlias = Generator[Detection, None, None]
 
 
-def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta):
+def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
     """
     given a list of detection generators, return a generator that yields
     a batch of detections
@@ -347,7 +347,7 @@ with jnp.printoptions(precision=3, suppress=True):
 
 
 def clusters_to_detections(
-    clusters: list[list[int]], sorted_detections: list[Detection]
+    clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
 ) -> list[list[Detection]]:
     """
     given a list of clusters (which is the indices of the detections in the sorted_detections list),
@@ -473,8 +473,6 @@ def triangulate_points_from_multiple_views_linear(
 
 
 # %%
-
-
 @jaxtyped(typechecker=beartype)
 @dataclass(frozen=True)
 class Tracking:
@@ -502,7 +500,7 @@ class Tracking:
 
 @jaxtyped(typechecker=beartype)
 def triangle_from_cluster(
-    cluster: list[Detection],
+    cluster: Sequence[Detection],
 ) -> tuple[Float[Array, "N 3"], datetime]:
     proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
     points = jnp.array([el.keypoints_undistorted for el in cluster])
@@ -516,14 +514,7 @@ def triangle_from_cluster(
     )
 
 
-# res = {
-#     "a": triangle_from_cluster(clusters_detections[0]).tolist(),
-#     "b": triangle_from_cluster(clusters_detections[1]).tolist(),
-# }
-# with open("samples/res.json", "wb") as f:
-#     f.write(orjson.dumps(res))
-
-
+# %%
 class GlobalTrackingState:
     _last_id: int
     _trackings: dict[int, Tracking]
@@ -541,7 +532,7 @@ class GlobalTrackingState:
     def trackings(self) -> dict[int, Tracking]:
         return shallow_copy(self._trackings)
 
-    def add_tracking(self, cluster: list[Detection]) -> Tracking:
+    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
         kps_3d, latest_timestamp = triangle_from_cluster(cluster)
         next_id = self._last_id + 1
         tracking = Tracking(
@@ -598,7 +589,7 @@ def calculate_affinity_2d(
     w_2d: float,
     alpha_2d: float,
     lambda_a: float,
-) -> float:
+) -> Float[Array, "J"]:
     """
     Calculate the affinity between two detections based on the distances between their keypoints.
 
@@ -621,7 +612,7 @@ def calculate_affinity_2d(
         * (1 - distance_2d / (alpha_2d * delta_t_s))
         * jnp.exp(-lambda_a * delta_t_s)
     )
-    return jnp.sum(affinity_per_keypoint).item()
+    return affinity_per_keypoint
 
 
 @jaxtyped(typechecker=beartype)
@@ -693,7 +684,7 @@ def calculate_affinity_3d(
     w_3d: float,
     alpha_3d: float,
     lambda_a: float,
-) -> float:
+) -> Float[Array, "J"]:
     """
     Calculate 3D affinity score between a tracking and detection.
 
@@ -714,9 +705,7 @@ def calculate_affinity_3d(
     affinity_per_keypoint = (
         w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
     )
-
-    # Sum affinities across all keypoints
-    return jnp.sum(affinity_per_keypoint).item()
+    return affinity_per_keypoint
 
 
 def predict_pose_3d(
@@ -731,6 +720,67 @@ def predict_pose_3d(
     return tracking.keypoints + tracking.velocity * delta_t_s
 
 
+@beartype
+def calculate_tracking_detection_affinity(
+    tracking: Tracking,
+    detection: Detection,
+    w_2d: float,
+    alpha_2d: float,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> float:
+    """
+    Calculate the affinity between a tracking and a detection.
+
+    Args:
+        tracking: The tracking object
+        detection: The detection object
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        Combined affinity score
+    """
+    camera = detection.camera
+    delta_t = detection.timestamp - tracking.last_active_timestamp
+
+    # Calculate 2D affinity
+    tracking_2d_projection = camera.project(tracking.keypoints)
+    w, h = camera.params.image_size
+    distance_2d = calculate_distance_2d(
+        tracking_2d_projection,
+        detection.keypoints,
+        image_size=(w, h),
+    )
+    affinity_2d = calculate_affinity_2d(
+        distance_2d,
+        delta_t,
+        w_2d=w_2d,
+        alpha_2d=alpha_2d,
+        lambda_a=lambda_a,
+    )
+
+    # Calculate 3D affinity
+    distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
+        detection, tracking, delta_t
+    )
+    affinity_3d = calculate_affinity_3d(
+        distances,
+        delta_t,
+        w_3d=w_3d,
+        alpha_3d=alpha_3d,
+        lambda_a=lambda_a,
+    )
+
+    # Combine affinities
+    total_affinity = affinity_2d + affinity_3d
+    return jnp.sum(total_affinity).item()
+
+
 # %%
 # let's do cross-view association
 W_2D = 1.0
@@ -738,7 +788,6 @@ ALPHA_2D = 1.0
 LAMBDA_A = 0.1
 W_3D = 1.0
 ALPHA_3D = 1.0
-LAMBDA_A = 0.1
 
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
@@ -757,35 +806,16 @@ detection_by_camera = classify_by_camera(unmatched_detections)
 for i, tracking in enumerate(trackings):
     j = 0
     for c, detections in detection_by_camera.items():
-        camera = next(iter(detections)).camera
-        # pixel space, unnormalized
-        tracking_2d_projection = camera.project(tracking.keypoints)
         for det in detections:
-            delta_t = det.timestamp - tracking.last_active_timestamp
-            w, h = camera.params.image_size
-            distance_2d = calculate_distance_2d(
-                tracking_2d_projection,
-                det.keypoints,
-                image_size=(w, h),
-            )
-            affinity_2d = calculate_affinity_2d(
-                distance_2d,
-                delta_t,
+            affinity_value = calculate_tracking_detection_affinity(
+                tracking,
+                det,
                 w_2d=W_2D,
                 alpha_2d=ALPHA_2D,
-                lambda_a=LAMBDA_A,
-            )
-            distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
-                det, tracking, delta_t
-            )
-            affinity_3d = calculate_affinity_3d(
-                distances,
-                delta_t,
                 w_3d=W_3D,
                 alpha_3d=ALPHA_3D,
                 lambda_a=LAMBDA_A,
             )
-            affinity_sum = affinity_2d + affinity_3d
-            affinity = affinity.at[i, j].set(affinity_sum)
+            affinity = affinity.at[i, j].set(affinity_value)
             j += 1
 display(affinity)
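
Note: with this patch, calculate_affinity_2d and calculate_affinity_3d return per-keypoint arrays (Float[Array, "J"]) instead of pre-summed floats, and the single jnp.sum(...).item() now happens in calculate_tracking_detection_affinity. A minimal sketch of why the combined score is unchanged, using made-up per-keypoint values rather than project data:

import jax.numpy as jnp

# Hypothetical per-keypoint affinities (shape "J"); stand-ins for the arrays
# returned by calculate_affinity_2d / calculate_affinity_3d after this patch.
affinity_2d = jnp.array([0.2, 0.5, 0.1])
affinity_3d = jnp.array([0.3, 0.0, 0.4])

# Old behaviour: each function summed internally, and the caller added two floats.
old_score = jnp.sum(affinity_2d).item() + jnp.sum(affinity_3d).item()

# New behaviour: arrays are added per keypoint, then summed once inside
# calculate_tracking_detection_affinity before .item().
new_score = jnp.sum(affinity_2d + affinity_3d).item()

assert abs(old_score - new_score) < 1e-6  # same combined affinity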