init
This commit is contained in:
164
app/tracker/single_object_tracker.py
Normal file
164
app/tracker/single_object_tracker.py
Normal file
@ -0,0 +1,164 @@
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from jaxtyping import Num, jaxtyped
|
||||
from loguru import logger
|
||||
from typeguard import typechecked
|
||||
|
||||
from app.typing import BoundingBoxFormat, NDArray
|
||||
|
||||
from .bboxes_tracker import BoxTracker, BoxTrackerConfig, BoxTracking
|
||||
|
||||
|
||||
@jaxtyped(typechecker=typechecked)
|
||||
def bounding_box_area(tracker: Num[NDArray, "N 4"], format: BoundingBoxFormat) -> float:
|
||||
if format == "xyxy":
|
||||
return float(
|
||||
np.mean((tracker[:, 2] - tracker[:, 0]) * (tracker[:, 3] - tracker[:, 1]))
|
||||
)
|
||||
elif format == "xywh":
|
||||
return float(np.mean(tracker[:, 2] * tracker[:, 3]))
|
||||
else:
|
||||
raise ValueError(f"Unknown bounding box format: {format}")
|
||||
|
||||
|
||||
class TrackingIdType(Enum):
|
||||
Overridden = "overridden"
|
||||
Selected = "selected"
|
||||
General = "general"
|
||||
|
||||
|
||||
TrackingId = Tuple[int, TrackingIdType]
|
||||
|
||||
|
||||
def find_suitable_tracking_id(
|
||||
tracking: list[BoxTracking], format: BoundingBoxFormat
|
||||
) -> Optional[int]:
|
||||
if len(tracking) == 0:
|
||||
return None
|
||||
elif len(tracking) == 1:
|
||||
return tracking[0].id
|
||||
else:
|
||||
i = np.argmax(
|
||||
[
|
||||
bounding_box_area(tracker.last_n_bounding_boxes, format)
|
||||
for tracker in tracking
|
||||
]
|
||||
)
|
||||
return tracking[i].id
|
||||
|
||||
|
||||
class SingleObjectTracker:
|
||||
_tracker: BoxTracker
|
||||
_overridden_tracking_id: Optional[int] = None
|
||||
_selected_tracking_id: Optional[int] = None
|
||||
_bounding_box_format: BoundingBoxFormat
|
||||
_on_lost_tracking: Optional[Callable[[BoxTracking], None]] = None
|
||||
"""
|
||||
(tracking, tracking_id) -> None
|
||||
"""
|
||||
_on_tracking_acquired: Optional[Callable[[list[BoxTracking], TrackingId], None]] = (
|
||||
None
|
||||
)
|
||||
"""
|
||||
(trackings, tracking_id) -> None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tracker_param: BoxTrackerConfig,
|
||||
bounding_box_format: BoundingBoxFormat = "xyxy",
|
||||
):
|
||||
self._selected_tracking_id = None
|
||||
self._tracker = BoxTracker(tracker_param, bounding_box_format)
|
||||
|
||||
def reset(self):
|
||||
self._tracker.reset()
|
||||
|
||||
@property
|
||||
def confirmed(self):
|
||||
return self._tracker.confirmed_trackings
|
||||
|
||||
@property
|
||||
def confirmed_trackings(self):
|
||||
"""
|
||||
alias of `confirmed`
|
||||
"""
|
||||
return self.confirmed
|
||||
|
||||
def get_by_id(
|
||||
self, tracking_id: int, trackings: list[BoxTracking]
|
||||
) -> Optional[BoxTracking]:
|
||||
assert tracking_id is not None
|
||||
try:
|
||||
return next(filter(lambda x: x.id == tracking_id, trackings))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def try_get_by_overridden_id(self) -> Optional[BoxTracking]:
|
||||
"""
|
||||
If successfully get the tracking, mutate self._selected_tracking_id.
|
||||
Otherwise, set self._overridden_tracking_id to None.
|
||||
"""
|
||||
overridden_id = self._overridden_tracking_id
|
||||
if overridden_id is None:
|
||||
return None
|
||||
sel: Optional[BoxTracking] = self.get_by_id(overridden_id, self.confirmed)
|
||||
if sel is None:
|
||||
self._overridden_tracking_id = None
|
||||
logger.trace(
|
||||
"Overridden tracking id {} not found in {}",
|
||||
overridden_id,
|
||||
self.confirmed,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self._selected_tracking_id is None
|
||||
or self._selected_tracking_id != overridden_id
|
||||
):
|
||||
self._selected_tracking_id = overridden_id
|
||||
logger.info("Acquired tracking id {} by override", overridden_id)
|
||||
if self._on_tracking_acquired is not None:
|
||||
self._on_tracking_acquired(
|
||||
self.confirmed, (overridden_id, TrackingIdType.Overridden)
|
||||
)
|
||||
return sel
|
||||
|
||||
def try_get_by_selected_id(self) -> Optional[BoxTracking]:
|
||||
"""
|
||||
If no selected tracking, find the one with `find_suitable_tracking_id`.
|
||||
"""
|
||||
selected_id = self._selected_tracking_id
|
||||
if selected_id is None:
|
||||
selected_id = find_suitable_tracking_id(
|
||||
self.confirmed, self._bounding_box_format
|
||||
)
|
||||
if selected_id is None:
|
||||
return None
|
||||
sel: Optional[BoxTracking] = self.get_by_id(selected_id, self.confirmed)
|
||||
if sel is None:
|
||||
self._selected_tracking_id = None
|
||||
logger.warning(
|
||||
"Selected tracking id {} not found in {}", selected_id, self.confirmed
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self._selected_tracking_id is None
|
||||
or self._selected_tracking_id != selected_id
|
||||
):
|
||||
self._selected_tracking_id = selected_id
|
||||
logger.info("Acquired tracking id {}", selected_id)
|
||||
if self._on_tracking_acquired is not None:
|
||||
self._on_tracking_acquired(
|
||||
self.confirmed, (selected_id, TrackingIdType.Selected)
|
||||
)
|
||||
return sel
|
||||
|
||||
@property
|
||||
def bounding_box_format(self) -> BoundingBoxFormat:
|
||||
return self._tracker.bounding_box_format
|
||||
|
||||
def next_measurements(self, boxes: Num[NDArray, "N 4"]):
|
||||
self._tracker.next_measurements(boxes)
|
||||
return self.confirmed
|
||||
Reference in New Issue
Block a user