refactor: Update type hints and enhance affinity calculations in playground.py
- Changed function signatures to use `Sequence` instead of `list` for better type flexibility. - Introduced a new function `calculate_tracking_detection_affinity` to streamline the calculation of affinities between tracking and detection objects. - Refactored existing affinity calculation logic to improve clarity and performance, leveraging the new affinity function. - Removed commented-out code to clean up the implementation and enhance readability.
This commit is contained in:
@ -1,7 +1,7 @@
|
|||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, TypeAlias, TypedDict, Optional
|
from typing import Any, TypeAlias, TypedDict, Optional, Sequence
|
||||||
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
import jax
|
import jax
|
||||||
@ -463,7 +463,7 @@ class Detection:
|
|||||||
|
|
||||||
|
|
||||||
def classify_by_camera(
|
def classify_by_camera(
|
||||||
detections: list[Detection],
|
detections: Sequence[Detection],
|
||||||
) -> OrderedDict[CameraID, list[Detection]]:
|
) -> OrderedDict[CameraID, list[Detection]]:
|
||||||
"""
|
"""
|
||||||
Classify detections by camera
|
Classify detections by camera
|
||||||
@ -677,7 +677,7 @@ def compute_affinity_epipolar_constraint_with_pairs(
|
|||||||
|
|
||||||
|
|
||||||
def calculate_affinity_matrix_by_epipolar_constraint(
|
def calculate_affinity_matrix_by_epipolar_constraint(
|
||||||
detections: list[Detection] | dict[CameraID, list[Detection]],
|
detections: Sequence[Detection] | dict[CameraID, Sequence[Detection]],
|
||||||
alpha_2d: float,
|
alpha_2d: float,
|
||||||
) -> tuple[list[Detection], Num[Array, "N N"]]:
|
) -> tuple[list[Detection], Num[Array, "N N"]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
118
playground.py
118
playground.py
@ -178,7 +178,7 @@ def preprocess_keypoint_dataset(
|
|||||||
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
|
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
|
||||||
|
|
||||||
|
|
||||||
def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta):
|
def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
|
||||||
"""
|
"""
|
||||||
given a list of detection generators, return a generator that yields a batch of detections
|
given a list of detection generators, return a generator that yields a batch of detections
|
||||||
|
|
||||||
@ -347,7 +347,7 @@ with jnp.printoptions(precision=3, suppress=True):
|
|||||||
|
|
||||||
|
|
||||||
def clusters_to_detections(
|
def clusters_to_detections(
|
||||||
clusters: list[list[int]], sorted_detections: list[Detection]
|
clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
|
||||||
) -> list[list[Detection]]:
|
) -> list[list[Detection]]:
|
||||||
"""
|
"""
|
||||||
given a list of clusters (which is the indices of the detections in the sorted_detections list),
|
given a list of clusters (which is the indices of the detections in the sorted_detections list),
|
||||||
@ -473,8 +473,6 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class Tracking:
|
||||||
@ -502,7 +500,7 @@ class Tracking:
|
|||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def triangle_from_cluster(
|
def triangle_from_cluster(
|
||||||
cluster: list[Detection],
|
cluster: Sequence[Detection],
|
||||||
) -> tuple[Float[Array, "N 3"], datetime]:
|
) -> tuple[Float[Array, "N 3"], datetime]:
|
||||||
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
|
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
|
||||||
points = jnp.array([el.keypoints_undistorted for el in cluster])
|
points = jnp.array([el.keypoints_undistorted for el in cluster])
|
||||||
@ -516,14 +514,7 @@ def triangle_from_cluster(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# res = {
|
# %%
|
||||||
# "a": triangle_from_cluster(clusters_detections[0]).tolist(),
|
|
||||||
# "b": triangle_from_cluster(clusters_detections[1]).tolist(),
|
|
||||||
# }
|
|
||||||
# with open("samples/res.json", "wb") as f:
|
|
||||||
# f.write(orjson.dumps(res))
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalTrackingState:
|
class GlobalTrackingState:
|
||||||
_last_id: int
|
_last_id: int
|
||||||
_trackings: dict[int, Tracking]
|
_trackings: dict[int, Tracking]
|
||||||
@ -541,7 +532,7 @@ class GlobalTrackingState:
|
|||||||
def trackings(self) -> dict[int, Tracking]:
|
def trackings(self) -> dict[int, Tracking]:
|
||||||
return shallow_copy(self._trackings)
|
return shallow_copy(self._trackings)
|
||||||
|
|
||||||
def add_tracking(self, cluster: list[Detection]) -> Tracking:
|
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||||
next_id = self._last_id + 1
|
next_id = self._last_id + 1
|
||||||
tracking = Tracking(
|
tracking = Tracking(
|
||||||
@ -598,7 +589,7 @@ def calculate_affinity_2d(
|
|||||||
w_2d: float,
|
w_2d: float,
|
||||||
alpha_2d: float,
|
alpha_2d: float,
|
||||||
lambda_a: float,
|
lambda_a: float,
|
||||||
) -> float:
|
) -> Float[Array, "J"]:
|
||||||
"""
|
"""
|
||||||
Calculate the affinity between two detections based on the distances between their keypoints.
|
Calculate the affinity between two detections based on the distances between their keypoints.
|
||||||
|
|
||||||
@ -621,7 +612,7 @@ def calculate_affinity_2d(
|
|||||||
* (1 - distance_2d / (alpha_2d * delta_t_s))
|
* (1 - distance_2d / (alpha_2d * delta_t_s))
|
||||||
* jnp.exp(-lambda_a * delta_t_s)
|
* jnp.exp(-lambda_a * delta_t_s)
|
||||||
)
|
)
|
||||||
return jnp.sum(affinity_per_keypoint).item()
|
return affinity_per_keypoint
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -693,7 +684,7 @@ def calculate_affinity_3d(
|
|||||||
w_3d: float,
|
w_3d: float,
|
||||||
alpha_3d: float,
|
alpha_3d: float,
|
||||||
lambda_a: float,
|
lambda_a: float,
|
||||||
) -> float:
|
) -> Float[Array, "J"]:
|
||||||
"""
|
"""
|
||||||
Calculate 3D affinity score between a tracking and detection.
|
Calculate 3D affinity score between a tracking and detection.
|
||||||
|
|
||||||
@ -714,9 +705,7 @@ def calculate_affinity_3d(
|
|||||||
affinity_per_keypoint = (
|
affinity_per_keypoint = (
|
||||||
w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
|
w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
|
||||||
)
|
)
|
||||||
|
return affinity_per_keypoint
|
||||||
# Sum affinities across all keypoints
|
|
||||||
return jnp.sum(affinity_per_keypoint).item()
|
|
||||||
|
|
||||||
|
|
||||||
def predict_pose_3d(
|
def predict_pose_3d(
|
||||||
@ -731,6 +720,67 @@ def predict_pose_3d(
|
|||||||
return tracking.keypoints + tracking.velocity * delta_t_s
|
return tracking.keypoints + tracking.velocity * delta_t_s
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def calculate_tracking_detection_affinity(
|
||||||
|
tracking: Tracking,
|
||||||
|
detection: Detection,
|
||||||
|
w_2d: float,
|
||||||
|
alpha_2d: float,
|
||||||
|
w_3d: float,
|
||||||
|
alpha_3d: float,
|
||||||
|
lambda_a: float,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Calculate the affinity between a tracking and a detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tracking: The tracking object
|
||||||
|
detection: The detection object
|
||||||
|
w_2d: Weight for 2D affinity
|
||||||
|
alpha_2d: Normalization factor for 2D distance
|
||||||
|
w_3d: Weight for 3D affinity
|
||||||
|
alpha_3d: Normalization factor for 3D distance
|
||||||
|
lambda_a: Decay rate for time difference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined affinity score
|
||||||
|
"""
|
||||||
|
camera = detection.camera
|
||||||
|
delta_t = detection.timestamp - tracking.last_active_timestamp
|
||||||
|
|
||||||
|
# Calculate 2D affinity
|
||||||
|
tracking_2d_projection = camera.project(tracking.keypoints)
|
||||||
|
w, h = camera.params.image_size
|
||||||
|
distance_2d = calculate_distance_2d(
|
||||||
|
tracking_2d_projection,
|
||||||
|
detection.keypoints,
|
||||||
|
image_size=(w, h),
|
||||||
|
)
|
||||||
|
affinity_2d = calculate_affinity_2d(
|
||||||
|
distance_2d,
|
||||||
|
delta_t,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate 3D affinity
|
||||||
|
distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
||||||
|
detection, tracking, delta_t
|
||||||
|
)
|
||||||
|
affinity_3d = calculate_affinity_3d(
|
||||||
|
distances,
|
||||||
|
delta_t,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine affinities
|
||||||
|
total_affinity = affinity_2d + affinity_3d
|
||||||
|
return jnp.sum(total_affinity).item()
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# let's do cross-view association
|
# let's do cross-view association
|
||||||
W_2D = 1.0
|
W_2D = 1.0
|
||||||
@ -738,7 +788,6 @@ ALPHA_2D = 1.0
|
|||||||
LAMBDA_A = 0.1
|
LAMBDA_A = 0.1
|
||||||
W_3D = 1.0
|
W_3D = 1.0
|
||||||
ALPHA_3D = 1.0
|
ALPHA_3D = 1.0
|
||||||
LAMBDA_A = 0.1
|
|
||||||
|
|
||||||
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
||||||
unmatched_detections = shallow_copy(next_group)
|
unmatched_detections = shallow_copy(next_group)
|
||||||
@ -757,35 +806,16 @@ detection_by_camera = classify_by_camera(unmatched_detections)
|
|||||||
for i, tracking in enumerate(trackings):
|
for i, tracking in enumerate(trackings):
|
||||||
j = 0
|
j = 0
|
||||||
for c, detections in detection_by_camera.items():
|
for c, detections in detection_by_camera.items():
|
||||||
camera = next(iter(detections)).camera
|
|
||||||
# pixel space, unnormalized
|
|
||||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
|
||||||
for det in detections:
|
for det in detections:
|
||||||
delta_t = det.timestamp - tracking.last_active_timestamp
|
affinity_value = calculate_tracking_detection_affinity(
|
||||||
w, h = camera.params.image_size
|
tracking,
|
||||||
distance_2d = calculate_distance_2d(
|
det,
|
||||||
tracking_2d_projection,
|
|
||||||
det.keypoints,
|
|
||||||
image_size=(w, h),
|
|
||||||
)
|
|
||||||
affinity_2d = calculate_affinity_2d(
|
|
||||||
distance_2d,
|
|
||||||
delta_t,
|
|
||||||
w_2d=W_2D,
|
w_2d=W_2D,
|
||||||
alpha_2d=ALPHA_2D,
|
alpha_2d=ALPHA_2D,
|
||||||
lambda_a=LAMBDA_A,
|
|
||||||
)
|
|
||||||
distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|
||||||
det, tracking, delta_t
|
|
||||||
)
|
|
||||||
affinity_3d = calculate_affinity_3d(
|
|
||||||
distances,
|
|
||||||
delta_t,
|
|
||||||
w_3d=W_3D,
|
w_3d=W_3D,
|
||||||
alpha_3d=ALPHA_3D,
|
alpha_3d=ALPHA_3D,
|
||||||
lambda_a=LAMBDA_A,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
affinity_sum = affinity_2d + affinity_3d
|
affinity = affinity.at[i, j].set(affinity_value)
|
||||||
affinity = affinity.at[i, j].set(affinity_sum)
|
|
||||||
j += 1
|
j += 1
|
||||||
display(affinity)
|
display(affinity)
|
||||||
|
|||||||
Reference in New Issue
Block a user