165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
from enum import Enum
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
import numpy as np
|
|
from jaxtyping import Num, jaxtyped
|
|
from loguru import logger
|
|
from typeguard import typechecked
|
|
|
|
from app.typing import BoundingBoxFormat, NDArray
|
|
|
|
from .bboxes_tracker import BoxTracker, BoxTrackerConfig, BoxTracking
|
|
|
|
|
|
@jaxtyped(typechecker=typechecked)
def bounding_box_area(tracker: Num[NDArray, "N 4"], format: BoundingBoxFormat) -> float:
    """Return the mean area of the N boxes stored in ``tracker``.

    Args:
        tracker: an (N, 4) array of boxes (shape enforced by jaxtyping).
        format: box layout. With "xyxy" the columns are (x0, y0, x1, y1) and
            the area is (x1 - x0) * (y1 - y0). With "xywh" columns 2 and 3
            are width and height.

    Returns:
        The arithmetic mean of the per-box areas, as a Python float. Widths
        and heights are not clamped, so inverted boxes yield negative areas.
        An empty array gives NaN (numpy mean of an empty slice).

    Raises:
        ValueError: if ``format`` is neither "xyxy" nor "xywh".
    """
    if format == "xywh":
        widths, heights = tracker[:, 2], tracker[:, 3]
    elif format == "xyxy":
        widths = tracker[:, 2] - tracker[:, 0]
        heights = tracker[:, 3] - tracker[:, 1]
    else:
        raise ValueError(f"Unknown bounding box format: {format}")
    per_box_area = widths * heights
    return float(np.mean(per_box_area))
|
|
|
|
|
|
class TrackingIdType(Enum):
    """How the tracking id followed by ``SingleObjectTracker`` was chosen."""

    # Chosen through the explicit override id (``try_get_by_overridden_id``).
    Overridden = "overridden"
    # Chosen automatically: the kept selection, or the result of
    # ``find_suitable_tracking_id`` (``try_get_by_selected_id``).
    Selected = "selected"
    # NOTE(review): not produced anywhere in the visible code — confirm its use.
    General = "general"
|
|
|
|
|
|
# A (tracking id, how it was chosen) pair; this is the second argument passed to
# ``SingleObjectTracker._on_tracking_acquired`` callbacks.
TrackingId = Tuple[int, TrackingIdType]
|
|
|
|
|
|
def find_suitable_tracking_id(
    tracking: list[BoxTracking], format: BoundingBoxFormat
) -> Optional[int]:
    """Pick the id of the tracking whose boxes have the largest mean area.

    Args:
        tracking: candidate trackings; each must expose ``id`` and
            ``last_n_bounding_boxes`` (presumably its most recent boxes —
            confirm against ``BoxTracking``).
        format: box layout forwarded to ``bounding_box_area``.

    Returns:
        None for an empty list. The sole element's id for a single tracking,
        without computing any area. Otherwise the id at ``np.argmax`` of the
        mean areas, so the first candidate wins ties.
    """
    if not tracking:
        return None
    if len(tracking) == 1:
        return tracking[0].id

    mean_areas = [
        bounding_box_area(candidate.last_n_bounding_boxes, format)
        for candidate in tracking
    ]
    best_index = np.argmax(mean_areas)
    return tracking[best_index].id
|
|
|
|
|
|
class SingleObjectTracker:
    """Follow a single object among the confirmed trackings of a ``BoxTracker``.

    Feed detections with ``next_measurements``. Then resolve the followed
    tracking with ``try_get_by_overridden_id`` (an explicitly overridden id) or
    ``try_get_by_selected_id``. The latter keeps the current selection, or
    falls back to ``find_suitable_tracking_id`` (largest mean box area).
    """

    _tracker: BoxTracker
    # Id explicitly requested from outside; cleared once it is no longer found.
    _overridden_tracking_id: Optional[int] = None
    # Id currently being followed.
    _selected_tracking_id: Optional[int] = None
    # Box layout used when ranking candidate trackings.
    _bounding_box_format: BoundingBoxFormat
    # Typed as receiving a single BoxTracking.
    # NOTE(review): not invoked within this class — confirm signature with callers.
    _on_lost_tracking: Optional[Callable[[BoxTracking], None]] = None
    # (trackings, tracking_id) -> None; invoked whenever the followed id changes.
    _on_tracking_acquired: Optional[Callable[[list[BoxTracking], TrackingId], None]] = (
        None
    )

    def __init__(
        self,
        tracker_param: BoxTrackerConfig,
        bounding_box_format: BoundingBoxFormat = "xyxy",
    ):
        """Create the tracker.

        Args:
            tracker_param: configuration forwarded to ``BoxTracker``.
            bounding_box_format: box layout (e.g. "xyxy" or "xywh"), forwarded
                to ``BoxTracker`` and used to rank candidate trackings.
        """
        self._selected_tracking_id = None
        # Previously only declared, never assigned: try_get_by_selected_id
        # reads it and raised AttributeError.
        self._bounding_box_format = bounding_box_format
        self._tracker = BoxTracker(tracker_param, bounding_box_format)

    def reset(self):
        """Reset the underlying tracker and forget the selected id.

        The overridden id is kept because it is an explicit external request.
        ``try_get_by_overridden_id`` drops it once it no longer matches a
        confirmed tracking.
        """
        self._tracker.reset()
        # If the tracker reuses ids after a reset, a stale selection could
        # silently latch onto an unrelated object.
        self._selected_tracking_id = None

    @property
    def confirmed(self):
        """Confirmed trackings of the underlying ``BoxTracker``."""
        return self._tracker.confirmed_trackings

    @property
    def confirmed_trackings(self):
        """
        alias of `confirmed`
        """
        return self.confirmed

    def get_by_id(
        self, tracking_id: int, trackings: list[BoxTracking]
    ) -> Optional[BoxTracking]:
        """Return the first tracking in ``trackings`` whose ``id`` equals
        ``tracking_id``, or None if there is none."""
        assert tracking_id is not None
        return next((t for t in trackings if t.id == tracking_id), None)

    def try_get_by_overridden_id(self) -> Optional[BoxTracking]:
        """Look up the confirmed tracking matching the overridden id.

        On success, make the overridden id the selected id. If the selection
        changed, log it and fire ``_on_tracking_acquired`` with
        ``TrackingIdType.Overridden``. If the overridden id is not among the
        confirmed trackings, clear it.

        Returns:
            The matching tracking, or None when there is no override or it
            was not found.
        """
        overridden_id = self._overridden_tracking_id
        if overridden_id is None:
            return None

        # Snapshot once so lookup, log and callback all see the same list.
        confirmed = self.confirmed
        sel = self.get_by_id(overridden_id, confirmed)
        if sel is None:
            self._overridden_tracking_id = None
            logger.trace(
                "Overridden tracking id {} not found in {}",
                overridden_id,
                confirmed,
            )
        elif self._selected_tracking_id != overridden_id:
            # overridden_id is non-None, so this also covers "nothing selected".
            self._selected_tracking_id = overridden_id
            logger.info("Acquired tracking id {} by override", overridden_id)
            if self._on_tracking_acquired is not None:
                self._on_tracking_acquired(
                    confirmed, (overridden_id, TrackingIdType.Overridden)
                )
        return sel

    def try_get_by_selected_id(self) -> Optional[BoxTracking]:
        """Return the confirmed tracking for the selected id.

        If no id is selected, pick one with ``find_suitable_tracking_id``. When
        the selection changes, log it and fire ``_on_tracking_acquired`` with
        ``TrackingIdType.Selected``. If the selected id is no longer confirmed,
        clear the selection and log a warning.

        Returns:
            The selected tracking, or None if nothing suitable is confirmed.
        """
        confirmed = self.confirmed
        selected_id = self._selected_tracking_id
        if selected_id is None:
            selected_id = find_suitable_tracking_id(
                confirmed, self._bounding_box_format
            )
            if selected_id is None:
                return None

        sel = self.get_by_id(selected_id, confirmed)
        if sel is None:
            self._selected_tracking_id = None
            logger.warning(
                "Selected tracking id {} not found in {}", selected_id, confirmed
            )
        elif self._selected_tracking_id != selected_id:
            # selected_id is non-None, so this also covers "nothing selected".
            self._selected_tracking_id = selected_id
            logger.info("Acquired tracking id {}", selected_id)
            if self._on_tracking_acquired is not None:
                self._on_tracking_acquired(
                    confirmed, (selected_id, TrackingIdType.Selected)
                )
        return sel

    @property
    def bounding_box_format(self) -> BoundingBoxFormat:
        """Box layout reported by the underlying ``BoxTracker``."""
        return self._tracker.bounding_box_format

    def next_measurements(self, boxes: Num[NDArray, "N 4"]):
        """Feed an (N, 4) array of boxes to the tracker.

        Returns:
            The confirmed trackings after the update.
        """
        self._tracker.next_measurements(boxes)
        return self.confirmed
|