555 lines
18 KiB
Python
555 lines
18 KiB
Python
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from typing import (
|
|
Callable,
|
|
Generator,
|
|
Optional,
|
|
Tuple,
|
|
TypedDict,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from loguru import logger
|
|
import numpy as np
|
|
from jaxtyping import Float, Int, Num, jaxtyped
|
|
from pydantic import BaseModel
|
|
from scipy.optimize import linear_sum_assignment
|
|
from typeguard import typechecked
|
|
|
|
from app.typing import BoundingBoxFormat
|
|
from app.typing.constant import AREA_FILTER_THRESHOLD
|
|
|
|
|
|
class BoxTrackerConfig(BaseModel):
    """
    Tuning parameters for ``BoxTracker``.

    Every field has a default, so ``BoxTrackerConfig()`` (or
    ``BoxTrackerConfig.default()``) yields the stock configuration.
    """

    # Time step handed to the constant-velocity motion model and used to
    # derive the initial velocity of newly formed tracks.
    dt: float = 1.0
    # Tracks whose covariance trace exceeds this value are deleted.
    cov_threshold: float = 4.0
    # Mahalanobis gate when associating measurements with tentative tracks.
    tentative_mahalanobis_threshold: float = 10.0
    # Mahalanobis gate when associating measurements with confirmed tracks.
    confirm_mahalanobis_threshold: float = 10.0
    # Max Euclidean distance between box centres in two consecutive steps for
    # them to be paired into a new tentative track.
    forming_tracks_euclidean_threshold: float = 25.0
    # A tentative track is promoted once its survived_time_steps strictly
    # exceeds this value.
    survival_steps_threshold: int = 3
    # Length of the per-track sliding window of bounding boxes.
    max_preserved_history_bounding_boxes: int = 10

    @staticmethod
    def default() -> "BoxTrackerConfig":
        """
        Return a configuration with every field at its declared default.

        Delegates to the field defaults above instead of repeating the
        values, so the two can never drift apart.
        """
        return BoxTrackerConfig()
|
|
|
|
|
|
from . import (
|
|
CvModelGaussianState,
|
|
GaussianState,
|
|
LinearMeasurementModel,
|
|
LinearMotionNoInputModel,
|
|
NDArray,
|
|
PosterioriResult,
|
|
outer_distance,
|
|
predict,
|
|
update,
|
|
)
|
|
|
|
|
|
class TrackingState(Enum):
    """Which of the tracker's two track lists a track belongs to."""

    # Newly formed track, not yet matched often enough to be promoted.
    Tentative = 1
    # Track that has been promoted out of the tentative list.
    Confirmed = 2
|
|
|
|
|
|
# Plain-dict snapshot of a BoxTracking, as produced by BoxTracking.to_dict:
#   id           -> track identifier
#   bounding_box -> most recent bounding box of the track
#   state_x      -> the track's state mean (state.x)
#   state_P      -> the track's state covariance (state.P)
BoxTrackingDict = TypedDict(
    "BoxTrackingDict",
    {
        "id": int,
        "bounding_box": NDArray,
        "state_x": NDArray,
        "state_P": NDArray,
    },
)
|
|
|
|
|
|
@dataclass
class BoxTracking:
    """
    One tracked target: its Kalman state, bookkeeping counters and the
    recent bounding-box history.
    """

    id: int
    state: CvModelGaussianState
    survived_time_steps: int
    missed_time_steps: int
    # History of bounding boxes in a sliding window, with the latest one at
    # the end. The window size is determined by the
    # `max_preserved_history_bounding_boxes` parameter.
    last_n_bounding_boxes: Num[NDArray, "N 4"]

    @property
    def last_bounding_box(self) -> Num[NDArray, "4"]:
        """The newest entry of ``last_n_bounding_boxes`` (a single 4-vector)."""
        newest = cast(NDArray, self.last_n_bounding_boxes[-1])
        assert newest.shape == (4,)
        return newest

    def to_dict(self) -> BoxTrackingDict:
        """Snapshot the track as a plain dict (see ``BoxTrackingDict``)."""
        current_state = self.state
        return BoxTrackingDict(
            id=self.id,
            bounding_box=self.last_bounding_box,
            state_x=current_state.x,
            state_P=current_state.P,
        )
|
|
|
|
|
|
@dataclass()
class CreateTrackingEvent:
    """
    Emitted when a track enters a group: either a new tentative track is
    formed, or a tentative track is promoted to the confirmed group.
    """

    # Group the track was added to.
    group: TrackingState
    # Identifier of the track (same as tracking.id).
    id: int
    # The track object that was added.
    tracking: BoxTracking
|
|
|
|
|
|
@dataclass()
class RemoveTrackingEvent:
    """
    Emitted when a track is deleted from a group (the tracker deletes tracks
    whose covariance trace exceeds its threshold).
    """

    # Group the track was removed from.
    group: TrackingState
    # Identifier of the removed track (same as tracking.id).
    id: int
    # The removed track object.
    tracking: BoxTracking
|
|
|
|
|
|
@dataclass()
class MatchedTrackingEvent:
    """Emitted when a track in ``group`` is associated with a measurement."""

    # Group of the matched track.
    group: TrackingState
    # Identifier of the matched track.
    id: int
    # The bounding box whose centre was associated with the track.
    matched_bounding_box: Num[NDArray, "4"]
|
|
|
|
|
|
# Union of every event type that BoxTracker.next_measurements can return.
TrackingEvent = Union[CreateTrackingEvent, RemoveTrackingEvent, MatchedTrackingEvent]
|
|
|
|
|
|
def bounding_boxes_to_center(
    bounding_boxes: Num[NDArray, "N 4"], format: BoundingBoxFormat
) -> Num[NDArray, "N 2"]:
    """
    Return the centre point of every bounding box.

    Parameters
    ----------
    bounding_boxes : rows of 4 values, interpreted according to ``format``.
    format : ``"xyxy"`` (two corners) or ``"xywh"`` (top-left corner plus
        width/height).

    Raises
    ------
    ValueError
        If ``format`` is neither ``"xyxy"`` nor ``"xywh"``.
    """
    if format != "xyxy" and format != "xywh":
        raise ValueError(f"Unsupported bounding box format: {format}")
    leading = bounding_boxes[:, :2]
    trailing = bounding_boxes[:, 2:]
    if format == "xyxy":
        # midpoint of the two corners
        return (leading + trailing) / 2
    # top-left corner plus half the size
    return leading + (trailing / 2)
|
|
|
|
|
|
def bounding_box_to_center(
    bounding_box: Num[NDArray, "4"], format: BoundingBoxFormat
) -> Num[NDArray, "2"]:
    """
    Return the centre point of a single bounding box.

    Single-box counterpart of ``bounding_boxes_to_center``; same formats
    and same ``ValueError`` for an unsupported ``format``.
    """
    if format != "xyxy" and format != "xywh":
        raise ValueError(f"Unsupported bounding box format: {format}")
    leading = bounding_box[:2]
    trailing = bounding_box[2:]
    if format == "xyxy":
        # midpoint of the two corners
        return (leading + trailing) / 2
    # top-left corner plus half the size
    return leading + (trailing / 2)
|
|
|
|
|
|
def bounding_boxes_area(
    bounding_boxes: Num[NDArray, "N 4"], format: BoundingBoxFormat
) -> Num[NDArray, "N"]:
    """
    Return the area (width * height) of every bounding box.

    Raises
    ------
    ValueError
        If ``format`` is neither ``"xyxy"`` nor ``"xywh"``.
    """
    if format == "xyxy":
        widths = bounding_boxes[:, 2] - bounding_boxes[:, 0]
        heights = bounding_boxes[:, 3] - bounding_boxes[:, 1]
    elif format == "xywh":
        widths = bounding_boxes[:, 2]
        heights = bounding_boxes[:, 3]
    else:
        raise ValueError(f"Unsupported bounding box format: {format}")
    return widths * heights
|
|
|
|
|
|
class BoxTracker:
    """
    A simple GNN tracker, but for tracking targets with bounding boxes.

    One call to ``next_measurements`` performs, in order:

    1. drop bounding boxes whose area is ``<= AREA_FILTER_THRESHOLD``;
    2. predict confirmed and tentative tracks with a constant-velocity model;
    3. associate box centres with confirmed tracks, then the leftovers with
       tentative tracks (Hungarian assignment on Mahalanobis distance, with a
       per-group gate);
    4. promote tentative tracks that survived long enough;
    5. form new tentative tracks by pairing the remaining centres with the
       unmatched centres of the previous call;
    6. delete tracks whose covariance trace exceeds ``cov_threshold``.

    TODO: use score to help data association
    """

    # All state is per instance and initialised by __init__/reset (no shared
    # mutable class-level defaults).
    _last_measurements: NDArray
    _tentative_tracks: list[BoxTracking]
    _confirmed_tracks: list[BoxTracking]
    _last_id: int
    _params: BoxTrackerConfig
    _bounding_boxes_format: BoundingBoxFormat

    def __init__(
        self,
        params: BoxTrackerConfig,
        bounding_boxes_format: BoundingBoxFormat,
    ):
        self._params = params
        self._bounding_boxes_format = bounding_boxes_format
        self.reset()

    def reset(self):
        """Forget all tracks, pending measurements and restart ids at 0."""
        self._last_id = 0
        self._last_measurements = np.empty((0, 2), dtype=np.float32)
        self._tentative_tracks = []
        self._confirmed_tracks = []

    def _push_new_bounding_box(
        self, old_bbs: Num[NDArray, "N 4"], new_bb: Num[NDArray, "4"]
    ) -> Num[NDArray, "N 4"]:
        """
        Append ``new_bb`` to the history and keep only the newest entries.

        The window is ``max_preserved_history_bounding_boxes`` but never less
        than 1: a window of 0 would make ``bbs[-0:]`` keep everything, and
        ``BoxTracking.last_bounding_box`` needs at least one entry.
        """
        bbs = np.append(old_bbs, np.expand_dims(new_bb, axis=0), axis=0)
        window = max(1, self._params.max_preserved_history_bounding_boxes)
        if bbs.shape[0] > window:
            bbs = bbs[-window:]
        return bbs

    def _translate_bounding_box(
        self, bounding_box: Num[NDArray, "4"], delta_x, delta_y
    ) -> NDArray:
        """
        Shift ``bounding_box`` by ``(delta_x, delta_y)`` keeping its size.

        Raises ValueError for an unsupported bounding box format.
        """
        if self._bounding_boxes_format == "xyxy":
            x_0, y_0, x_1, y_1 = bounding_box
            return np.array(
                [x_0 + delta_x, y_0 + delta_y, x_1 + delta_x, y_1 + delta_y]
            )
        if self._bounding_boxes_format == "xywh":
            # (x_0, y_0) is the top-left corner: it moves by exactly the
            # same delta as the centre (no extra half-size offset).
            x_0, y_0, w, h = bounding_box
            return np.array([x_0 + delta_x, y_0 + delta_y, w, h])
        raise ValueError(
            f"Unsupported bounding box format: {self._bounding_boxes_format}"
        )

    def _predict(self, tracks: list[BoxTracking], dt: float = 1.0) -> list[BoxTracking]:
        """
        Return new BoxTracking objects predicted ``dt`` ahead.

        The predicted box is the last box translated by the change in centre
        between the last box and the predicted state; it is appended to the
        history. The input tracks are not modified.
        """
        # same model for every track; build it once
        motion = BoxTracker.motion_model(dt=dt)

        def _predict_one(track: BoxTracking) -> BoxTracking:
            new_st = predict(track.state, motion)
            o_cx, o_cy = bounding_box_to_center(
                track.last_bounding_box, self._bounding_boxes_format
            )
            # state layout as unpacked here: (cx, cy, vx, vy)
            n_cx, n_cy, _v_x, _v_y = new_st.x
            new_bb = self._translate_bounding_box(
                track.last_bounding_box, n_cx - o_cx, n_cy - o_cy
            )
            return BoxTracking(
                id=track.id,
                state=CvModelGaussianState.from_gaussian(new_st),
                survived_time_steps=track.survived_time_steps,
                missed_time_steps=track.missed_time_steps,
                last_n_bounding_boxes=self._push_new_bounding_box(
                    track.last_n_bounding_boxes, new_bb
                ),
            )

        return [_predict_one(track) for track in tracks]

    @jaxtyped(typechecker=typechecked)
    def _data_associate_and_update(
        self,
        select_array: TrackingState,
        measurements: Num[NDArray, "N 2"],
        bounding_boxes: Num[NDArray, "N 4"],
    ) -> Tuple[list[MatchedTrackingEvent], Num[NDArray, "M 2"], Num[NDArray, "M 4"]]:
        """
        Match one group of tracks with measurements and update matched tracks.

        Parameters
        ----------
        select_array
            Which track list to use (Tentative or Confirmed); also selects
            the Mahalanobis gate (tentative/confirm threshold).
        measurements
            (N, 2) box centres.
        bounding_boxes
            (N, 4) boxes, row-aligned with ``measurements``.

        Returns
        -------
        (events, unmatched_measurements, unmatched_bounding_boxes)
            One MatchedTrackingEvent per accepted pair, plus the rows of
            ``measurements``/``bounding_boxes`` that were not matched.

        Effect
        ------
        Finds the minimum-total-Mahalanobis assignment and rejects pairs above
        the gate. Matched tracks get the posterior state, ``survived_time_steps
        += 1`` and the measured box appended to their history; unmatched
        tracks get ``missed_time_steps += 1`` (statistics only — removal is
        driven by the covariance threshold). Track objects are mutated in
        place. Assumes the tracks have already been predicted.
        """
        evs: list[MatchedTrackingEvent] = []
        assert measurements.ndim == 2
        assert measurements.shape[1] == 2

        assert bounding_boxes.ndim == 2
        assert bounding_boxes.shape[1] == 4

        assert bounding_boxes.shape[0] == measurements.shape[0]

        if select_array == TrackingState.Tentative:
            tracks = self._tentative_tracks
            distance_threshold = self._params.tentative_mahalanobis_threshold
        elif select_array == TrackingState.Confirmed:
            tracks = self._confirmed_tracks
            distance_threshold = self._params.confirm_mahalanobis_threshold
        else:
            raise ValueError("Unexpected tracking state {}".format(select_array))

        if len(tracks) == 0:
            return evs, measurements, bounding_boxes

        measurement_model = BoxTracker.measurement_model()
        # posteriors[i][j] is track i updated with measurement j
        posteriors = [
            [update(measurement, track.state, measurement_model) for measurement in measurements]
            for track in tracks
        ]
        distances = np.array(
            [[post.mahalanobis_distance for post in per_track] for per_track in posteriors],
            dtype=np.float32,
        )
        row, col = linear_sum_assignment(distances)
        row = np.asarray(row)
        col = np.asarray(col)

        # gating: drop assigned pairs whose distance exceeds the threshold
        rejected = np.array(
            [posteriors[i][j].mahalanobis_distance > distance_threshold for i, j in zip(row, col)],
            dtype=bool,
        )
        row = row[~rejected]
        col = col[~rejected]

        # update matched tracks
        for i, j in zip(row, col):
            track = tracks[i]
            post = posteriors[i][j]
            track.state = CvModelGaussianState.from_gaussian(post.state)
            track.survived_time_steps += 1
            track.last_n_bounding_boxes = self._push_new_bounding_box(
                track.last_n_bounding_boxes, bounding_boxes[j]
            )
            evs.append(
                MatchedTrackingEvent(
                    group=select_array,
                    id=track.id,
                    matched_bounding_box=bounding_boxes[j],
                )
            )

        # missed tracks: statistics only; removal is by covariance threshold
        matched_track_indices = set(row.tolist())
        for i, track in enumerate(tracks):
            if i not in matched_track_indices:
                track.missed_time_steps += 1

        # remove measurements that have been matched
        left_measurements = np.delete(measurements, col, axis=0)
        left_bounding_boxes = np.delete(bounding_boxes, col, axis=0)
        return evs, left_measurements, left_bounding_boxes

    @jaxtyped(typechecker=typechecked)
    def _tracks_from_past_measurements(
        self,
        measurements: Num[NDArray, "N 2"],
        bounding_boxes: Num[NDArray, "N 4"],
        dt: float = 1.0,
        distance_threshold: float = 3.0,
    ):
        """
        Pair current measurements with the previous step's leftovers and
        create a tentative track for each pair.

        Pairs come from a minimum-total-Euclidean assignment; pairs farther
        apart than ``distance_threshold`` are rejected. A new track starts at
        the current centre with velocity ``(current - previous) / dt`` and
        identity covariance. Unpaired current measurements become the new
        ``_last_measurements``; unpaired previous ones are dropped.

        Returns the list of CreateTrackingEvent (group Tentative).

        Note
        ----
        mutate self._tentative_tracks, self._last_measurements and
        self._last_id
        """
        evs: list[CreateTrackingEvent] = []
        if self._last_measurements.shape[0] == 0:
            self._last_measurements = measurements
            return evs
        distances = outer_distance(self._last_measurements, measurements)
        row, col = linear_sum_assignment(distances)
        row = np.asarray(row)
        col = np.asarray(col)

        # gating: drop pairs whose centres are too far apart
        rejected = np.array(
            [distances[i, j] > distance_threshold for i, j in zip(row, col)],
            dtype=bool,
        )
        row = row[~rejected]
        col = col[~rejected]

        for i, j in zip(row, col):
            coord = measurements[j]
            vel = (coord - self._last_measurements[i]) / dt
            state = GaussianState(x=np.concatenate([coord, vel]), P=np.eye(4))
            track = BoxTracking(
                id=self._last_id,
                state=CvModelGaussianState.from_gaussian(state),
                survived_time_steps=0,
                missed_time_steps=0,
                last_n_bounding_boxes=np.expand_dims(bounding_boxes[j], axis=0),
            )
            self._last_id += 1
            self._tentative_tracks.append(track)
            evs.append(
                CreateTrackingEvent(
                    group=TrackingState.Tentative, id=track.id, tracking=track
                )
            )
        # update the last measurements with the unmatched measurements
        self._last_measurements = np.delete(measurements, col, axis=0)
        return evs

    def _transfer_tentative_to_confirmed(self, survival_steps_threshold: int = 3):
        """
        Promote tentative tracks with ``survived_time_steps`` strictly greater
        than ``survival_steps_threshold`` to the confirmed list.

        Returns one CreateTrackingEvent (group Confirmed) per promoted track.

        Note
        ----
        mutate self._tentative_tracks and self._confirmed_tracks in place.
        The tentative list is partitioned rather than popped from while
        iterating, which previously skipped the track after each promotion.
        """
        evs: list[CreateTrackingEvent] = []
        still_tentative: list[BoxTracking] = []
        for track in self._tentative_tracks:
            if track.survived_time_steps > survival_steps_threshold:
                self._confirmed_tracks.append(track)
                evs.append(
                    CreateTrackingEvent(
                        group=TrackingState.Confirmed, id=track.id, tracking=track
                    )
                )
            else:
                still_tentative.append(track)
        # slice assignment keeps the same list object (in-place mutation)
        self._tentative_tracks[:] = still_tentative
        return evs

    def _track_cov_deleter(
        self, track_to_use: TrackingState, cov_threshold: float = 4.0
    ):
        """
        Delete tracks whose covariance trace is greater than ``cov_threshold``.

        Parameters
        ----------
        track_to_use
            Which list (Tentative or Confirmed) to prune.
        cov_threshold
            The threshold of the covariance trace.

        Returns
        -------
        list[RemoveTrackingEvent]
            One event per deleted track, in list order.

        Raises
        ------
        ValueError
            If ``track_to_use`` is not a known TrackingState.

        Note
        ----
        mutate the selected list in place. Partitioning (instead of popping
        while iterating) ensures every over-threshold track is removed.
        """
        if track_to_use == TrackingState.Tentative:
            tracks = self._tentative_tracks
        elif track_to_use == TrackingState.Confirmed:
            tracks = self._confirmed_tracks
        else:
            raise ValueError("Unexpected tracking state {}".format(track_to_use))
        ret: list[RemoveTrackingEvent] = []
        kept: list[BoxTracking] = []
        for track in tracks:
            # https://numpy.org/doc/stable/reference/generated/numpy.trace.html
            if np.trace(track.state.P) > cov_threshold:
                ret.append(
                    RemoveTrackingEvent(group=track_to_use, id=track.id, tracking=track)
                )
            else:
                kept.append(track)
        tracks[:] = kept
        return ret

    def next_measurements(
        self,
        bounding_boxes: Num[NDArray, "N 4"],
    ) -> list[TrackingEvent]:
        """
        Advance the tracker by one step with this frame's bounding boxes.

        Returns all events of the step, in the order: confirmed matches,
        tentative matches, promotions, new tentative tracks, tentative
        deletions, confirmed deletions.
        """
        evs: list[TrackingEvent]
        areas = bounding_boxes_area(bounding_boxes, self._bounding_boxes_format)
        # 10 x 10 is too small for a normal bounding box
        # filter out
        # TODO: use area as gating threshold
        too_small = np.asarray(areas <= AREA_FILTER_THRESHOLD)
        if too_small.any():
            logger.trace(
                "too small bounding boxes; bboxes={}; areas={}",
                bounding_boxes,
                areas,
            )
            # boolean mask instead of np.delete(np.where(...)) (a tuple index)
            bounding_boxes = bounding_boxes[~too_small]

        measurements = bounding_boxes_to_center(
            bounding_boxes, self._bounding_boxes_format
        )
        self._confirmed_tracks = self._predict(self._confirmed_tracks, self._params.dt)
        self._tentative_tracks = self._predict(self._tentative_tracks, self._params.dt)
        c_evs, c_left_m, c_left_bb = (
            self._data_associate_and_update(  # pylint: disable=E1102
                TrackingState.Confirmed, measurements, bounding_boxes
            )
        )
        t_evs, t_left_m, t_left_bb = (
            self._data_associate_and_update(  # pylint: disable=E1102
                TrackingState.Tentative, c_left_m, c_left_bb
            )
        )
        create_c_evs = self._transfer_tentative_to_confirmed(
            self._params.survival_steps_threshold
        )
        # target initialize
        create_t_evs = self._tracks_from_past_measurements(  # pylint: disable=E1102
            t_left_m,
            t_left_bb,
            self._params.dt,
            self._params.forming_tracks_euclidean_threshold,
        )
        del_t_evs = self._track_cov_deleter(
            TrackingState.Tentative, self._params.cov_threshold
        )
        del_c_evs = self._track_cov_deleter(
            TrackingState.Confirmed, self._params.cov_threshold
        )
        evs = c_evs + t_evs + create_c_evs + create_t_evs + del_t_evs + del_c_evs
        return evs

    @property
    def confirmed_trackings(self):
        """The live list of confirmed tracks (not a copy)."""
        return self._confirmed_tracks

    @property
    def bounding_box_format(self):
        """The bounding box format this tracker was constructed with."""
        return self._bounding_boxes_format

    @staticmethod
    def motion_model(dt: float = 1, q: float = 0.05) -> LinearMotionNoInputModel:
        """
        A constant velocity motion model over state (cx, cy, vx, vy) with
        process noise ``q * I``.
        """
        # yapf: disable
        F = np.array([[1, 0, dt, 0],
                      [0, 1, 0, dt],
                      [0, 0, 1, 0],
                      [0, 0, 0, 1]])
        # yapf: enable
        Q = q * np.eye(4)
        return LinearMotionNoInputModel(F=F, Q=Q)

    @staticmethod
    def measurement_model(r: float = 0.75) -> LinearMeasurementModel:
        """Observe the position (cx, cy) of the state with noise ``r * I``."""
        # yapf: disable
        H = np.array([[1, 0, 0, 0],
                      [0, 1, 0, 0]])
        # yapf: enable
        R = r * np.eye(2)
        return LinearMeasurementModel(H=H, R=R)