1
0
forked from HQU-gxy/CVTH3PE

refactor: Enhance tracking state management and velocity filter integration

- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability.
- Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management.
- Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase.
- Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations.
- Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
2025-05-02 12:44:58 +08:00
parent 46b8518a10
commit c31cc4e7bf
2 changed files with 111 additions and 56 deletions

View File

@ -58,7 +58,12 @@ from app.camera import (
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
from app.tracking import (
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
@ -543,11 +548,14 @@ class GlobalTrackingState:
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
id=next_id,
tracking_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections=v(*cluster),
)
tracking = Tracking(
id=next_id,
state=tracking_state,
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
)
self._trackings[next_id] = tracking
@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity(
Combined affinity score
"""
camera = detection.camera
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
# Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN)
# Calculate 2D affinity
tracking_2d_projection = camera.project(tracking.keypoints)
tracking_2d_projection = camera.project(tracking.state.keypoints)
w, h = camera.params.image_size
distance_2d = calculate_distance_2d(
tracking_2d_projection,
@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax(
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.keypoints for trk in trackings]
[trk.state.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
# === Detection-side tensors ===
@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax(
# --- timestamps ----------
t0 = min(
chain(
(trk.last_active_timestamp for trk in trackings),
(trk.state.last_active_timestamp for trk in trackings),
(det.timestamp for det in camera_detections),
)
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
@ -1032,7 +1040,5 @@ display(affinities)
# %%
def update_tracking(tracking: Tracking, detection: Detection):
delta_t_ = detection.timestamp - tracking.last_active_timestamp
delta_t = max(delta_t_, DELTA_T_MIN)
return tracking
delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
raise NotImplementedError