forked from HQU-gxy/CVTH3PE
refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability. - Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management. - Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase. - Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations. - Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
@ -58,7 +58,12 @@ from app.camera import (
|
||||
classify_by_camera,
|
||||
)
|
||||
from app.solver._old import GLPKSolver
|
||||
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
|
||||
from app.tracking import (
|
||||
AffinityResult,
|
||||
LastDifferenceVelocityFilter,
|
||||
Tracking,
|
||||
TrackingState,
|
||||
)
|
||||
from app.visualize.whole_body import visualize_whole_body
|
||||
|
||||
NDArray: TypeAlias = np.ndarray
|
||||
@ -543,11 +548,14 @@ class GlobalTrackingState:
|
||||
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||
next_id = self._last_id + 1
|
||||
tracking = Tracking(
|
||||
id=next_id,
|
||||
tracking_state = TrackingState(
|
||||
keypoints=kps_3d,
|
||||
last_active_timestamp=latest_timestamp,
|
||||
historical_detections=v(*cluster),
|
||||
)
|
||||
tracking = Tracking(
|
||||
id=next_id,
|
||||
state=tracking_state,
|
||||
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||
)
|
||||
self._trackings[next_id] = tracking
|
||||
@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity(
|
||||
Combined affinity score
|
||||
"""
|
||||
camera = detection.camera
|
||||
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
|
||||
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
|
||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
||||
|
||||
# Calculate 2D affinity
|
||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
||||
tracking_2d_projection = camera.project(tracking.state.keypoints)
|
||||
w, h = camera.params.image_size
|
||||
distance_2d = calculate_distance_2d(
|
||||
tracking_2d_projection,
|
||||
@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax(
|
||||
|
||||
# === Tracking-side tensors ===
|
||||
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
||||
[trk.keypoints for trk in trackings]
|
||||
[trk.state.keypoints for trk in trackings]
|
||||
) # (T, J, 3)
|
||||
J = kps3d_trk.shape[1]
|
||||
# === Detection-side tensors ===
|
||||
@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax(
|
||||
# --- timestamps ----------
|
||||
t0 = min(
|
||||
chain(
|
||||
(trk.last_active_timestamp for trk in trackings),
|
||||
(trk.state.last_active_timestamp for trk in trackings),
|
||||
(det.timestamp for det in camera_detections),
|
||||
)
|
||||
).timestamp() # common origin (float)
|
||||
ts_trk = jnp.array(
|
||||
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
||||
)
|
||||
ts_det = jnp.array(
|
||||
@ -1032,7 +1040,5 @@ display(affinities)
|
||||
|
||||
# %%
|
||||
def update_tracking(tracking: Tracking, detection: Detection):
|
||||
delta_t_ = detection.timestamp - tracking.last_active_timestamp
|
||||
delta_t = max(delta_t_, DELTA_T_MIN)
|
||||
|
||||
return tracking
|
||||
delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user