1
0
forked from HQU-gxy/CVTH3PE

refactor: Revamp LeastMeanSquareVelocityFilter to utilize historical 3D poses

- Replaced the historical-detections getter callback with bounded deques (`maxlen=max_samples`) that store the historical 3D poses and timestamps, enhancing performance and memory efficiency.
- Updated the constructor to accept historical data directly, ensuring proper initialization and sorting of poses and timestamps.
- Refined the `predict` and `update` methods to work with the new data structure, improving clarity and functionality.
- Enhanced error handling to ensure robustness when no historical data is available for predictions.
This commit is contained in:
2025-05-03 14:31:59 +08:00
parent c31cc4e7bf
commit d2c1c8d624

View File

@@ -1,4 +1,5 @@
import weakref import weakref
from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from itertools import chain from itertools import chain
@@ -145,73 +146,55 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
a velocity filter that uses the least mean square method to estimate the velocity a velocity filter that uses the least mean square method to estimate the velocity
""" """
_get_historical_detections: Callable[[], Sequence[Detection]] _historical_3d_poses: deque[Float[Array, "J 3"]]
""" _historical_timestamps: deque[datetime]
get the current historical detections, assuming the detections are sorted by
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
the newest detection)
"""
_velocity: Optional[Float[Array, "J 3"]] = None _velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): def __init__(
self._get_historical_detections = get_historical_detections self,
self._velocity = None historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
@staticmethod max_samples: int = 10,
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter": ):
""" assert len(historical_3d_poses) == len(historical_timestamps)
create a LeastMeanSquareVelocityFilter from a Tracking object temp = zip(historical_3d_poses, historical_timestamps)
temp_sorted = sorted(temp, key=lambda x: x[1])
Note that this function is using a weak reference to the tracking object, self._historical_3d_poses = deque(
so that the tracking object can be garbage collected if there are no other map(lambda x: x[0], temp_sorted), maxlen=max_samples
references to it. )
""" self._historical_timestamps = deque(
# Create a weak reference to avoid circular references map(lambda x: x[1], temp_sorted), maxlen=max_samples
# https://docs.python.org/3/library/weakref.html )
tracking_ref = weakref.ref(tracking) self._max_samples = max_samples
if len(self._historical_3d_poses) < 2:
# Create a getter function that uses the weak reference self._velocity = None
def get_historical_detections() -> Sequence[Detection]:
tr = tracking_ref()
if tr is None:
return [] # Return empty list if tracking has been garbage collected
return tr.state.historical_detections
velocity = tracking.velocity_filter.get()["velocity"]
if velocity is None:
return LeastMeanSquareVelocityFilter(
get_historical_detections=get_historical_detections
)
else: else:
f = LeastMeanSquareVelocityFilter( self._update(
get_historical_detections=get_historical_detections jnp.array(self._historical_3d_poses),
jnp.array(self._historical_timestamps),
) )
# pylint: disable-next=protected-access
f._velocity = velocity
return f
def predict(self, timestamp: datetime) -> TrackingPrediction: def predict(self, timestamp: datetime) -> TrackingPrediction:
historical_detections = self._get_historical_detections() if not self._historical_3d_poses:
if not historical_detections: raise ValueError("No historical 3D poses available for prediction")
raise ValueError("No historical detections available for prediction")
# Use the latest historical detection # use the latest historical detection
latest_detection = historical_detections[-1] latest_3d_pose = self._historical_3d_poses[-1]
latest_keypoints = latest_detection.keypoints latest_timestamp = self._historical_timestamps[-1]
latest_timestamp = latest_detection.timestamp
delta_t_s = (timestamp - latest_timestamp).total_seconds() delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None: if self._velocity is None:
return TrackingPrediction( return TrackingPrediction(
velocity=None, velocity=None,
keypoints=latest_keypoints, keypoints=latest_3d_pose,
) )
else: else:
# Linear motion model: ẋt = xt' + Vt' · (t - t') # Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction( return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_keypoints velocity=self._velocity, keypoints=predicted_3d_pose
) )
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@@ -253,47 +236,37 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
self._velocity = velocities self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
historical_detections = self._get_historical_detections() last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
if not historical_detections: # deque would manage the maxlen automatically
self._velocity = jnp.zeros_like(keypoints) self._historical_3d_poses.append(keypoints)
return self._historical_timestamps.append(timestamp)
t_0 = min(d.timestamp for d in historical_detections) t_0 = self._historical_timestamps[0]
all_keypoints = jnp.array(self._historical_3d_poses)
all_keypoints = jnp.array( def timestamp_to_seconds(timestamp: datetime) -> float:
list(chain((d.keypoints for d in historical_detections), (keypoints,))) assert t_0 <= timestamp
) return (timestamp - t_0).total_seconds()
# Timestamps relative to t_0 (the oldest detection timestamp) # timestamps relative to t_0 (the oldest detection timestamp)
all_timestamps = jnp.array( all_timestamps = jnp.array(
list( map(timestamp_to_seconds, self._historical_timestamps)
chain(
(
(d.timestamp - t_0).total_seconds()
for d in historical_detections
),
((timestamp - t_0).total_seconds(),),
)
)
) )
self._update(all_keypoints, all_timestamps) self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction: def get(self) -> TrackingPrediction:
historical_detections = self._get_historical_detections() if not self._historical_3d_poses:
if not historical_detections: raise ValueError("No historical 3D poses available")
raise ValueError("No historical detections available")
latest_detection = historical_detections[-1] latest_3d_pose = self._historical_3d_poses[-1]
latest_keypoints = latest_detection.keypoints
if self._velocity is None: if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_keypoints) return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else: else:
return TrackingPrediction( return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
velocity=self._velocity, keypoints=latest_keypoints
)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@@ -393,7 +366,7 @@ class Tracking:
""" """
# pylint: disable-next=unsubscriptable-object # pylint: disable-next=unsubscriptable-object
if (vel := self.velocity_filter.get()["velocity"]) is None: if (vel := self.velocity_filter.get()["velocity"]) is None:
raise ValueError("Velocity is not available") return jnp.zeros_like(self.state.keypoints)
else: else:
return vel return vel