refactor: Update type hints and enhance affinity calculations in playground.py

- Changed function signatures to use `Sequence` instead of `list` for better type flexibility.
- Introduced a new function `calculate_tracking_detection_affinity` to streamline the calculation of affinities between tracking and detection objects.
- Refactored existing affinity calculation logic to improve clarity and performance, leveraging the new affinity function.
- Removed commented-out code to clean up the implementation and enhance readability.
This commit is contained in:
2025-04-27 17:45:31 +08:00
parent 4f48c78cfb
commit 41e0141bde
2 changed files with 77 additions and 47 deletions

View File

@ -1,7 +1,7 @@
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, TypeAlias, TypedDict, Optional
from typing import Any, TypeAlias, TypedDict, Optional, Sequence
from beartype import beartype
import jax
@ -463,7 +463,7 @@ class Detection:
def classify_by_camera(
detections: list[Detection],
detections: Sequence[Detection],
) -> OrderedDict[CameraID, list[Detection]]:
"""
Classify detections by camera
@ -677,7 +677,7 @@ def compute_affinity_epipolar_constraint_with_pairs(
def calculate_affinity_matrix_by_epipolar_constraint(
detections: list[Detection] | dict[CameraID, list[Detection]],
detections: Sequence[Detection] | dict[CameraID, Sequence[Detection]],
alpha_2d: float,
) -> tuple[list[Detection], Num[Array, "N N"]]:
"""

View File

@ -178,7 +178,7 @@ def preprocess_keypoint_dataset(
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta):
def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
"""
given a list of detection generators, return a generator that yields a batch of detections
@ -347,7 +347,7 @@ with jnp.printoptions(precision=3, suppress=True):
def clusters_to_detections(
clusters: list[list[int]], sorted_detections: list[Detection]
clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
"""
given a list of clusters (which is the indices of the detections in the sorted_detections list),
@ -473,8 +473,6 @@ def triangulate_points_from_multiple_views_linear(
# %%
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
@ -502,7 +500,7 @@ class Tracking:
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
cluster: list[Detection],
cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
points = jnp.array([el.keypoints_undistorted for el in cluster])
@ -516,14 +514,7 @@ def triangle_from_cluster(
)
# res = {
# "a": triangle_from_cluster(clusters_detections[0]).tolist(),
# "b": triangle_from_cluster(clusters_detections[1]).tolist(),
# }
# with open("samples/res.json", "wb") as f:
# f.write(orjson.dumps(res))
# %%
class GlobalTrackingState:
_last_id: int
_trackings: dict[int, Tracking]
@ -541,7 +532,7 @@ class GlobalTrackingState:
def trackings(self) -> dict[int, Tracking]:
return shallow_copy(self._trackings)
def add_tracking(self, cluster: list[Detection]) -> Tracking:
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
@ -598,7 +589,7 @@ def calculate_affinity_2d(
w_2d: float,
alpha_2d: float,
lambda_a: float,
) -> float:
) -> Float[Array, "J"]:
"""
Calculate the affinity between two detections based on the distances between their keypoints.
@ -621,7 +612,7 @@ def calculate_affinity_2d(
* (1 - distance_2d / (alpha_2d * delta_t_s))
* jnp.exp(-lambda_a * delta_t_s)
)
return jnp.sum(affinity_per_keypoint).item()
return affinity_per_keypoint
@jaxtyped(typechecker=beartype)
@ -693,7 +684,7 @@ def calculate_affinity_3d(
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> float:
) -> Float[Array, "J"]:
"""
Calculate 3D affinity score between a tracking and detection.
@ -714,9 +705,7 @@ def calculate_affinity_3d(
affinity_per_keypoint = (
w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
)
# Sum affinities across all keypoints
return jnp.sum(affinity_per_keypoint).item()
return affinity_per_keypoint
def predict_pose_3d(
@ -731,6 +720,67 @@ def predict_pose_3d(
return tracking.keypoints + tracking.velocity * delta_t_s
@beartype
def calculate_tracking_detection_affinity(
    tracking: Tracking,
    detection: Detection,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Score how well a detection matches an existing tracking by combining
    a 2D (image-plane) and a 3D (ray-casting) affinity term.

    Args:
        tracking: The tracking object
        detection: The detection object
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        Combined affinity score (per-keypoint 2D + 3D affinities summed
        into a single scalar)
    """
    cam = detection.camera
    elapsed = detection.timestamp - tracking.last_active_timestamp

    # 2D term: project the tracked 3D keypoints into this camera's image
    # plane and score the image-size-normalized pixel distance.
    width, height = cam.params.image_size
    pixel_distance = calculate_distance_2d(
        cam.project(tracking.keypoints),
        detection.keypoints,
        image_size=(width, height),
    )
    score_2d = calculate_affinity_2d(
        pixel_distance,
        elapsed,
        w_2d=w_2d,
        alpha_2d=alpha_2d,
        lambda_a=lambda_a,
    )

    # 3D term: perpendicular distances from the detection's back-projected
    # rays to the tracked keypoints.
    ray_distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        detection, tracking, elapsed
    )
    score_3d = calculate_affinity_3d(
        ray_distances,
        elapsed,
        w_3d=w_3d,
        alpha_3d=alpha_3d,
        lambda_a=lambda_a,
    )

    # Summing the element-wise combination equals summing each term
    # separately, so this matches (sum of 2D) + (sum of 3D).
    return jnp.sum(score_2d + score_3d).item()
# %%
# let's do cross-view association
W_2D = 1.0
@ -738,7 +788,6 @@ ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0
LAMBDA_A = 0.1
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
@ -757,35 +806,16 @@ detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
j = 0
for c, detections in detection_by_camera.items():
camera = next(iter(detections)).camera
# pixel space, unnormalized
tracking_2d_projection = camera.project(tracking.keypoints)
for det in detections:
delta_t = det.timestamp - tracking.last_active_timestamp
w, h = camera.params.image_size
distance_2d = calculate_distance_2d(
tracking_2d_projection,
det.keypoints,
image_size=(w, h),
)
affinity_2d = calculate_affinity_2d(
distance_2d,
delta_t,
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
lambda_a=LAMBDA_A,
)
distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
det, tracking, delta_t
)
affinity_3d = calculate_affinity_3d(
distances,
delta_t,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
affinity_sum = affinity_2d + affinity_3d
affinity = affinity.at[i, j].set(affinity_sum)
affinity = affinity.at[i, j].set(affinity_value)
j += 1
display(affinity)