forked from HQU-gxy/CVTH3PE

9 Commits

20 changed files with 13005 additions and 1738 deletions

.gitignore (vendored)

@@ -10,3 +10,5 @@ wheels/
.venv
.hypothesis
samples
*.jpg
*.parquet


@@ -1,4 +1,5 @@
```bash
jupytext --to py:percent <script>.ipynb
```
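For reference, jupytext can also convert the percent-format script back into a notebook, so the pairing is round-trippable. A minimal sketch, assuming the paired files are named `play.py`/`play.ipynb` as elsewhere in this changeset:

```bash
# notebook -> percent-format script
jupytext --to py:percent play.ipynb
# percent-format script -> notebook
jupytext --to ipynb play.py
```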


@@ -1,6 +1,7 @@
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
import string
from typing import Any, TypeAlias, TypedDict, Optional, Sequence
from beartype import beartype
@@ -522,10 +523,14 @@ def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array,
raise ValueError(f"Invalid shape for points: {points.shape}")
import awkward as ak
@jaxtyped(typechecker=beartype)
def point_line_distance(
points: Num[Array, "N 3"] | Num[Array, "N 2"],
line: Num[Array, "N 3"],
description: str,
eps: float = 1e-9,
):
"""
@@ -544,6 +549,12 @@ def point_line_distance(
"""
numerator = abs(line[:, 0] * points[:, 0] + line[:, 1] * points[:, 1] + line[:, 2])
denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1])
# line_data = {"a": line[:, 0], "b": line[:, 1], "c": line[:, 2]}
# line_x_y = {"x": points[:, 0], "y": points[:, 1]}
# ak.to_parquet(
# line_data, f"/home/admin/Code/CVTH3PE/line_a_b_c_{description}.parquet"
# )
# ak.to_parquet(line_x_y, f"/home/admin/Code/CVTH3PE/line_x_y_{description}.parquet")
return numerator / (denominator + eps)
@@ -571,7 +582,7 @@ def left_to_right_epipolar_distance(
"""
F_t = fundamental_matrix.transpose()
line1_in_2 = jnp.matmul(left, F_t)
return point_line_distance(right, line1_in_2, "left_to_right")
@jaxtyped(typechecker=beartype)
@@ -597,7 +608,7 @@ def right_to_left_epipolar_distance(
$$x^{\\prime T}Fx = 0$$
"""
line2_in_1 = jnp.matmul(right, fundamental_matrix)
return point_line_distance(left, line2_in_1, "right_to_left")


def distance_between_epipolar_lines(

app/tracking.py

@@ -1,14 +1,20 @@
import warnings
import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
@@ -18,18 +24,428 @@ from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v, PRecord, PMap
from app.camera import Detection, CameraID
TrackingID: TypeAlias = int
class TrackingPrediction(TypedDict):
velocity: Optional[Float[Array, "J 3"]]
keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
"""
a filter interface for tracking velocity estimation
"""
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
predict the velocity and the keypoints location
Args:
timestamp: timestamp of the prediction
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the filter state with new measurements
Args:
keypoints: new measurements
timestamp: timestamp of the update
"""
... # pylint: disable=unnecessary-ellipsis
def get(self) -> TrackingPrediction:
"""
get the current state of the filter state
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
"""
a dummy velocity filter that does nothing
"""
_keypoints_shape: tuple[int, ...]
def __init__(self, keypoints: Float[Array, "J 3"]):
self._keypoints_shape = keypoints.shape
def predict(self, timestamp: datetime) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
def get(self) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a naive velocity filter that uses the last difference of keypoints
"""
_last_timestamp: datetime
_last_keypoints: Float[Array, "J 3"]
_last_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def predict(self, timestamp: datetime) -> TrackingPrediction:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if delta_t_s <= 0:
warnings.warn(
"delta_t={}; last={}; current={}".format(
delta_t_s, self._last_timestamp, timestamp
)
)
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
if delta_t_s <= 0:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if delta_t_s <= 0:
pass
else:
self._last_timestamp = timestamp
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
self._last_keypoints = keypoints
def get(self) -> TrackingPrediction:
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
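For orientation, a minimal sketch of driving these filters through the `GenericVelocityFilter` protocol; the 17-joint shape and the timestamps are made-up values, and the classes above are assumed to be in scope:

```python
from datetime import datetime, timedelta

import jax.numpy as jnp

t0 = datetime(2024, 4, 2, 12, 0, 0)
kps = jnp.zeros((17, 3))  # J = 17 joints, all at the origin

f = LastDifferenceVelocityFilter(kps, t0)
# a measurement 40 ms later, shifted 0.1 m along x
f.update(kps + jnp.array([0.1, 0.0, 0.0]), t0 + timedelta(milliseconds=40))
# extrapolate another 40 ms ahead with the linear motion model
pred = f.predict(t0 + timedelta(milliseconds=80))
# pred["velocity"] is ~2.5 m/s along x for every joint,
# pred["keypoints"] is ~0.2 m along x
```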
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
a velocity filter that uses the least mean square method to estimate the velocity
"""
_historical_3d_poses: deque[Float[Array, "J 3"]]
_historical_timestamps: deque[datetime]
_velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(
self,
historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
max_samples: int = 10,
):
assert len(historical_3d_poses) == len(historical_timestamps)
temp = zip(historical_3d_poses, historical_timestamps)
temp_sorted = sorted(temp, key=lambda x: x[1])
self._historical_3d_poses = deque(
map(lambda x: x[0], temp_sorted), maxlen=max_samples
)
self._historical_timestamps = deque(
map(lambda x: x[1], temp_sorted), maxlen=max_samples
)
self._max_samples = max_samples
if len(self._historical_3d_poses) < 2:
self._velocity = None
else:
# timestamps must be numeric seconds relative to the oldest sample;
# jnp.array cannot consume datetime objects directly
t_0 = self._historical_timestamps[0]
rel_t = [(t - t_0).total_seconds() for t in self._historical_timestamps]
self._update(
jnp.array(self._historical_3d_poses),
jnp.array(rel_t),
)
def predict(self, timestamp: datetime) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available for prediction")
# use the latest historical detection
latest_3d_pose = self._historical_3d_poses[-1]
latest_timestamp = self._historical_timestamps[-1]
delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=latest_3d_pose,
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_3d_pose
)
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
update measurements with least mean square method
"""
if keypoints.shape[0] < 2:
raise ValueError("Not enough measurements to estimate velocity")
# Using least squares to fit a linear model for each joint and dimension
# X = timestamps, y = keypoints
# For each joint and each dimension, we solve for velocity
n_samples = timestamps.shape[0]
n_joints = keypoints.shape[1]
# Create design matrix for linear regression
# [t, 1] for each timestamp
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
# Reshape keypoints to solve for all joints and dimensions at once
# From [N, J, 3] to [N, J*3]
keypoints_reshaped = keypoints.reshape(n_samples, -1)
# Use JAX's lstsq to solve the least squares problem
# This is more numerically stable than manually computing pseudoinverse
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
# Coefficients shape is [2, J*3]
# First row: velocities, Second row: intercepts
velocities = coefficients[0].reshape(n_joints, 3)
# Update velocity
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
# deque would manage the maxlen automatically
self._historical_3d_poses.append(keypoints)
self._historical_timestamps.append(timestamp)
t_0 = self._historical_timestamps[0]
all_keypoints = jnp.array(self._historical_3d_poses)
def timestamp_to_seconds(timestamp: datetime) -> float:
assert t_0 <= timestamp
return (timestamp - t_0).total_seconds()
# timestamps relative to t_0 (the oldest detection timestamp)
all_timestamps = jnp.array(
list(map(timestamp_to_seconds, self._historical_timestamps))
)
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available")
latest_3d_pose = self._historical_3d_poses[-1]
if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else:
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
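As a reading aid: per joint j and axis d, `_update` fits the linear model below by least squares over the buffered samples (via `jnp.linalg.lstsq` with design matrix [t, 1]); the slope is the velocity estimate:

```latex
x_{j,d}(t) \approx v_{j,d}\,t + b_{j,d},
\qquad
(v_{j,d}, b_{j,d}) = \arg\min_{v,b} \sum_{i=1}^{N} \bigl(v\,t_i + b - x_{j,d}(t_i)\bigr)^2
```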
class OneEuroFilter(GenericVelocityFilter):
"""
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
based on movement speed to reduce jitter during slow movements while maintaining
responsiveness during fast movements.
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
"""
_x_filtered: Float[Array, "J 3"]
_dx_filtered: Optional[Float[Array, "J 3"]] = None
_last_timestamp: datetime
_min_cutoff: float
_beta: float
_d_cutoff: float
def __init__(
self,
keypoints: Float[Array, "J 3"],
timestamp: datetime,
min_cutoff: float = 1.0,
beta: float = 0.0,
d_cutoff: float = 1.0,
):
"""
Initialize the One Euro Filter.
Args:
keypoints: Initial keypoints positions
timestamp: Initial timestamp
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
beta: Speed coefficient (higher = less lag during fast movements)
d_cutoff: Cutoff frequency for the derivative filter
"""
self._last_timestamp = timestamp
# Filter parameters
self._min_cutoff = min_cutoff
self._beta = beta
self._d_cutoff = d_cutoff
# Filter state
self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate
@overload
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
@overload
def _smoothing_factor(
self, cutoff: Float[Array, "J"], dt: float
) -> Float[Array, "J"]: ...
@jaxtyped(typechecker=beartype)
def _smoothing_factor(
self, cutoff: Union[float, Float[Array, "J"]], dt: float
) -> Union[float, Float[Array, "J"]]:
"""Calculate the smoothing factor for the low-pass filter."""
r = 2 * jnp.pi * cutoff * dt
return r / (r + 1)
@jaxtyped(typechecker=beartype)
def _exponential_smoothing(
self,
a: Union[float, Float[Array, "J"]],
x: Float[Array, "J 3"],
x_prev: Float[Array, "J 3"],
) -> Float[Array, "J 3"]:
"""Apply exponential smoothing to the input."""
return a * x + (1 - a) * x_prev
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
Predict keypoints position at a given timestamp.
Args:
timestamp: Timestamp for prediction
Returns:
TrackingPrediction with velocity and keypoints
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if self._dx_filtered is None:
return TrackingPrediction(
velocity=None,
keypoints=self._x_filtered,
)
else:
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=predicted_keypoints,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
Update the filter with new measurements.
Args:
keypoints: New keypoint measurements
timestamp: Timestamp of the measurements
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0:
raise ValueError(
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
)
dx = (keypoints - self._x_filtered) / dt
# Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
dx, axis=-1, keepdims=True
)
# Apply low-pass filter to velocity
a_d = self._smoothing_factor(self._d_cutoff, dt)
self._dx_filtered = self._exponential_smoothing(
a_d,
dx,
(
jnp.zeros_like(keypoints)
if self._dx_filtered is None
else self._dx_filtered
),
)
# Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered
)
# Update timestamp
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
"""
Get the current state of the filter.
Returns:
TrackingPrediction with velocity and keypoints
"""
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=self._x_filtered,
)
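For reference, the update above follows the standard 1€ filter equations from the linked reference: with time step Δt and speed estimate ẋ̂, the adaptive cutoff and smoothing factor are

```latex
f_c = f_{c,\min} + \beta\,\lVert \dot{\hat{x}} \rVert,
\qquad
\alpha(f_c, \Delta t) = \frac{2\pi f_c\,\Delta t}{2\pi f_c\,\Delta t + 1},
\qquad
\hat{x}_t = \alpha x_t + (1-\alpha)\hat{x}_{t-1}
```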
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
"""
immutable state of a tracking
"""

keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
@@ -41,50 +457,97 @@ class Tracking:
The last active timestamp of the tracking
"""

historical_detections_by_camera: PMap[CameraID, Detection]
""" """
Historical detections of the tracking. Historical detections of the tracking.
Used for 3D re-triangulation Used for 3D re-triangulation
""" """

class Tracking:
id: TrackingID
state: TrackingState
velocity_filter: GenericVelocityFilter

def __init__(
self,
id: TrackingID,
state: TrackingState,
velocity_filter: Optional[GenericVelocityFilter] = None,
):
self.id = id
self.state = state
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
@overload
def predict(self, time: float) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time in seconds to predict the keypoints
Returns:
the predicted keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time delta to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: datetime) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the timestamp to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
def predict(
self,
time: float | timedelta | datetime,
) -> Float[Array, "J 3"]:
if isinstance(time, timedelta):
timestamp = self.state.last_active_timestamp + time
elif isinstance(time, datetime):
timestamp = time
else:
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.predict(timestamp)["keypoints"]

def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the tracking with a new 3D pose

Note:
equivalent to calling `velocity_filter.update(new_3d_pose, timestamp)`
"""
self.velocity_filter.update(new_3d_pose, timestamp)
@property
def velocity(self) -> Float[Array, "J 3"]:
"""
The velocity of the tracking for each keypoint
"""
# pylint: disable-next=unsubscriptable-object
if (vel := self.velocity_filter.get()["velocity"]) is None:
return jnp.zeros_like(self.state.keypoints)
else:
return vel
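Illustrating the three `predict` overloads with a hypothetical tracking `trk` (built as in `GlobalTrackingState.add_tracking` later in this diff); all three calls resolve to the same instant:

```python
from datetime import timedelta

kps_a = trk.predict(0.5)                          # float: seconds after last_active_timestamp
kps_b = trk.predict(timedelta(milliseconds=500))  # timedelta: same offset
kps_c = trk.predict(trk.state.last_active_timestamp + timedelta(seconds=0.5))  # absolute datetime
```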
@jaxtyped(typechecker=beartype)
@@ -97,10 +560,10 @@ class AffinityResult:
matrix: Float[Array, "T D"]
trackings: Sequence[Tracking]
detections: Sequence[Detection]
indices_T: Int[Array, "A"] # pylint: disable=invalid-name
indices_D: Int[Array, "A"] # pylint: disable=invalid-name

def tracking_association(
self,
) -> Generator[tuple[float, Tracking, Detection], None, None]:
"""

filter_object_by_box.ipynb (new file, 1452 lines)

File diff suppressed because one or more lines are too long

filter_object_by_box.py (new file, 208 lines)

@@ -0,0 +1,208 @@
import numpy as np
import cv2
from typing import (
TypeAlias,
TypedDict,
)
from jaxtyping import Array, Num
from shapely.geometry import Polygon
NDArray: TypeAlias = np.ndarray
# Triangles (triples of vertex keys) covering each face of the 3D box
box_triangles_list = [
["4", "6", "7"],
["4", "5", "6"],
["2", "5", "6"],
["1", "2", "5"],
["1", "2", "3"],
["0", "1", "3"],
["0", "3", "7"],
["0", "4", "7"],
["2", "6", "7"],
["2", "3", "7"],
["0", "4", "5"],
["0", "1", "5"],
]
class Camera_Params(TypedDict):
rvec: Num[NDArray, "3"]
tvec: Num[NDArray, "3"]
camera_matrix: Num[Array, "3 3"]
dist: Num[Array, "N"]
width: int
height: int
class KeypointDataset(TypedDict):
frame_index: int
boxes: Num[NDArray, "N 4"]
kps: Num[NDArray, "N J 2"]
kps_scores: Num[NDArray, "N J"]
# Reproject a 3D point into 2D pixel coordinates using the camera intrinsics and extrinsics
def reprojet_3d_to_2d(point_3d, camera_param):
point_2d, _ = cv2.projectPoints(
objectPoints=point_3d,
rvec=np.array(camera_param.params.Rt[:3, :3]),
tvec=np.array(camera_param.params.Rt[:3, 3]),
cameraMatrix=np.array(camera_param.params.K),
distCoeffs=np.array(camera_param.params.dist_coeffs),
)
point_2d = point_2d.reshape(-1).astype(int)
return point_2d
# Compute the eight 3D vertices of the box
def calculaterCubeVersices(position, dimensions):
[cx, cy, cz] = position
[width, height, depth] = dimensions
halfWidth = width / 2
halfHeight = height / 2
halfDepth = depth / 2
return [
[cx - halfWidth, cy - halfHeight, cz - halfDepth],
[cx + halfWidth, cy - halfHeight, cz - halfDepth],
[cx + halfWidth, cy + halfHeight, cz - halfDepth],
[cx - halfWidth, cy + halfHeight, cz - halfDepth],
[cx - halfWidth, cy - halfHeight, cz + halfDepth],
[cx + halfWidth, cy - halfHeight, cz + halfDepth],
[cx + halfWidth, cy + halfHeight, cz + halfDepth],
[cx - halfWidth, cy + halfHeight, cz + halfDepth],
]
# Build the box's 3D vertex dictionary, keyed "0".."7"
def calculater_box_3d_points():
# box origin position, offset relative to the hexahedron center
box_ori_potision = [0.205 + 0.2, 0.205 + 0.50, -0.205 - 0.45]
# box dimensions in meters (width, height, depth)
box_geometry = [0.65, 1.8, 1]
filter_box_points_3d = calculaterCubeVersices(box_ori_potision, box_geometry)
filter_box_points_3d = {
str(index): element for index, element in enumerate(filter_box_points_3d)
}
return filter_box_points_3d
# Reproject the box's 3D vertices into 2D for the given camera
def calculater_box_2d_points(filter_box_points_3d, camera_param):
box_points_2d = dict()
for element_index, elment_point_3d in enumerate(filter_box_points_3d.values()):
box_points_2d[str(element_index)] = reprojet_3d_to_2d(
np.array(elment_point_3d), camera_param
).tolist()
return box_points_2d
# Gather the 2D triangle vertices for every face of the box
def calculater_box_common_scope(box_points_2d):
box_triangles_all_points = []
# iterate over the triangles
for i in range(len(box_triangles_list)):
# collect the 2D coordinates of one triangle
single_triangles = []
for element_key in box_triangles_list[i]:
single_triangles.append(box_points_2d[element_key])
box_triangles_all_points.append(single_triangles)
return box_triangles_all_points
def calculate_triangle_union(triangles):
"""
计算多个三角形的并集区域
参数:
triangles: 包含多个三角形的列表,每个三角形由三个点的坐标组成
返回:
union_area: 并集区域的面积
union_polygon: 表示并集区域的多边形对象
"""
# build the list of polygon objects
polygons = [Polygon(tri) for tri in triangles]
# compute the union
union_polygon = polygons[0]
for polygon in polygons[1:]:
union_polygon = union_polygon.union(polygon)
# compute the union area
union_area = union_polygon.area
return union_area, union_polygon
# Ray-casting test: is the point inside the box's 2D reprojected region?
def point_in_polygon(p, polygon):
x, y = p
n = len(polygon)
intersections = 0
on_boundary = False
for i in range(n):
xi, yi = polygon[i]
xj, yj = polygon[(i + 1) % n] # close the polygon
# check whether the point coincides with a vertex
if (x == xi and y == yi) or (x == xj and y == yj):
on_boundary = True
break
# check whether the point lies on an edge (non-vertex case)
if (min(xi, xj) <= x <= max(xi, xj)) and (min(yi, yj) <= y <= max(yi, yj)):
cross = (x - xi) * (yj - yi) - (y - yi) * (xj - xi)
if cross == 0:
on_boundary = True
break
# count ray/edge intersections (non-horizontal edges)
if (yi > y) != (yj > y):
slope = (xj - xi) / (yj - yi) if (yj - yi) != 0 else float("inf")
x_intersect = xi + (y - yi) * slope
if x <= x_intersect:
intersections += 1
if on_boundary:
return False
return intersections % 2 == 1 # odd count means inside (returns True)
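A quick sanity check of the ray-casting test, with a hypothetical 2×2 square:

```python
square = [(0, 0), (2, 0), (2, 2), (0, 2)]

print(point_in_polygon((1, 1), square))  # True: strictly inside
print(point_in_polygon((3, 1), square))  # False: outside
print(point_in_polygon((2, 1), square))  # False: boundary points are rejected
```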
# Extract the boundary coordinates of the union region
def get_contours(union_polygon):
if union_polygon.geom_type == "Polygon":
# single polygon
x, y = union_polygon.exterior.xy
contours = [(list(x)[i], list(y)[i]) for i in range(len(x))]
contours = np.array(contours, np.int32)
return contours
# Filter keypoints whose centers fall inside the box's 2D reprojected region
def filter_kps_in_contours(kps, contours) -> bool:
# 4 5 16 17
keypoint_index: list[list[int]] = [[4, 5], [16, 17]]
centers = []
for element_keypoint in keypoint_index:
x1, y1 = kps[element_keypoint[0]]
x2, y2 = kps[element_keypoint[1]]
centers.append([(x1 + x2) / 2, (y1 + y2) / 2])
if point_in_polygon(centers[0], contours) and point_in_polygon(
centers[1], contours
):
return True
else:
return False
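Putting the helpers together, end-to-end use of this module might look like the sketch below; `camera_param` (an object exposing `.params.Rt`, `.params.K`, and `.params.dist_coeffs`, as assumed by `reprojet_3d_to_2d`) and the `(J, 2)` keypoint array `kps` are hypothetical inputs:

```python
box_points_3d = calculater_box_3d_points()  # 3D box vertices, keyed "0".."7"
box_points_2d = calculater_box_2d_points(box_points_3d, camera_param)  # reproject to 2D
triangles = calculater_box_common_scope(box_points_2d)  # 2D triangles covering the faces
_, union_polygon = calculate_triangle_union(triangles)  # union of the face triangles
contours = get_contours(union_polygon)  # boundary of the union region
keep = filter_kps_in_contours(kps, contours)  # True if both keypoint centers are inside
```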


@@ -0,0 +1,282 @@
[
{
"kps": [
419.0,
154.0
],
"kps_scores": 1.0,
"index": 0
},
{
"kps": [
419.0521240234375,
154.07498168945312
],
"kps_scores": 1.0,
"index": 1
},
{
"kps": [
418.5992736816406,
154.3507080078125
],
"kps_scores": 1.0,
"index": 2
},
{
"kps": [
417.0777893066406,
154.17327880859375
],
"kps_scores": 1.0,
"index": 3
},
{
"kps": [
416.8981628417969,
154.15330505371094
],
"kps_scores": 1.0,
"index": 4
},
{
"kps": [
415.1317443847656,
153.68324279785156
],
"kps_scores": 1.0,
"index": 5
},
{
"kps": [
413.2596130371094,
153.39761352539062
],
"kps_scores": 1.0,
"index": 6
},
{
"kps": [
412.7089538574219,
153.3645782470703
],
"kps_scores": 1.0,
"index": 7
},
{
"kps": [
409.3253173828125,
152.9347686767578
],
"kps_scores": 1.0,
"index": 8
},
{
"kps": [
404.74853515625,
152.21153259277344
],
"kps_scores": 1.0,
"index": 9
},
{
"kps": [
404.3977355957031,
152.19647216796875
],
"kps_scores": 1.0,
"index": 10
},
{
"kps": [
396.53131103515625,
152.09912109375
],
"kps_scores": 1.0,
"index": 11
},
{
"kps": [
393.76605224609375,
151.91282653808594
],
"kps_scores": 1.0,
"index": 12
},
{
"kps": [
393.28106689453125,
151.76124572753906
],
"kps_scores": 1.0,
"index": 13
},
{
"kps": [
383.2342834472656,
152.3790740966797
],
"kps_scores": 1.0,
"index": 14
},
{
"kps": [
379.7545471191406,
152.79055786132812
],
"kps_scores": 1.0,
"index": 15
},
{
"kps": [
379.8231506347656,
152.8155975341797
],
"kps_scores": 1.0,
"index": 16
},
{
"kps": [
370.0028076171875,
155.16213989257812
],
"kps_scores": 1.0,
"index": 17
},
{
"kps": [
366.5267639160156,
155.72059631347656
],
"kps_scores": 1.0,
"index": 18
},
{
"kps": [
366.69610595703125,
156.3056182861328
],
"kps_scores": 1.0,
"index": 19
},
{
"kps": [
359.8770751953125,
158.69798278808594
],
"kps_scores": 1.0,
"index": 20
},
{
"kps": [
356.67681884765625,
160.0414581298828
],
"kps_scores": 1.0,
"index": 21
},
{
"kps": [
348.1063232421875,
163.32858276367188
],
"kps_scores": 1.0,
"index": 22
},
{
"kps": [
343.6862487792969,
165.0043182373047
],
"kps_scores": 1.0,
"index": 23
},
{
"kps": [
339.2411804199219,
167.18580627441406
],
"kps_scores": 1.0,
"index": 24
},
{
"kps": [
330.0,
170.0
],
"kps_scores": 0.0,
"index": 25
},
{
"kps": [
322.0425720214844,
174.9293975830078
],
"kps_scores": 1.0,
"index": 26
},
{
"kps": [
310.0,
176.0
],
"kps_scores": 0.0,
"index": 27
},
{
"kps": [
305.0433349609375,
178.03123474121094
],
"kps_scores": 1.0,
"index": 28
},
{
"kps": [
293.71295166015625,
183.8294219970703
],
"kps_scores": 1.0,
"index": 29
},
{
"kps": [
291.28656005859375,
184.33445739746094
],
"kps_scores": 1.0,
"index": 30
},
{
"kps": [
281.0,
190.0
],
"kps_scores": 0.0,
"index": 31
},
{
"kps": [
272.0,
200.0
],
"kps_scores": 0.0,
"index": 32
},
{
"kps": [
261.0457763671875,
211.67132568359375
],
"kps_scores": 1.0,
"index": 33
},
{
"kps": [
239.03567504882812,
248.68519592285156
],
"kps_scores": 1.0,
"index": 34
}
]


@@ -0,0 +1,282 @@
[
{
"kps": [
474.0,
215.00003051757812
],
"kps_scores": 1.0,
"index": 0
},
{
"kps": [
474.0710754394531,
215.04542541503906
],
"kps_scores": 1.0,
"index": 1
},
{
"kps": [
476.81365966796875,
215.0387420654297
],
"kps_scores": 1.0,
"index": 2
},
{
"kps": [
479.3288269042969,
214.4371795654297
],
"kps_scores": 1.0,
"index": 3
},
{
"kps": [
479.3817443847656,
214.49256896972656
],
"kps_scores": 1.0,
"index": 4
},
{
"kps": [
483.0047302246094,
213.85231018066406
],
"kps_scores": 1.0,
"index": 5
},
{
"kps": [
484.1208801269531,
213.64219665527344
],
"kps_scores": 1.0,
"index": 6
},
{
"kps": [
484.140869140625,
213.63470458984375
],
"kps_scores": 1.0,
"index": 7
},
{
"kps": [
487.458251953125,
213.45497131347656
],
"kps_scores": 1.0,
"index": 8
},
{
"kps": [
488.8343505859375,
213.4651336669922
],
"kps_scores": 1.0,
"index": 9
},
{
"kps": [
488.899658203125,
213.48526000976562
],
"kps_scores": 1.0,
"index": 10
},
{
"kps": [
493.831787109375,
214.70533752441406
],
"kps_scores": 1.0,
"index": 11
},
{
"kps": [
495.60980224609375,
215.26271057128906
],
"kps_scores": 1.0,
"index": 12
},
{
"kps": [
495.5881042480469,
215.2436065673828
],
"kps_scores": 1.0,
"index": 13
},
{
"kps": [
502.015380859375,
217.81201171875
],
"kps_scores": 1.0,
"index": 14
},
{
"kps": [
504.2356262207031,
218.78392028808594
],
"kps_scores": 1.0,
"index": 15
},
{
"kps": [
504.2625427246094,
218.81021118164062
],
"kps_scores": 1.0,
"index": 16
},
{
"kps": [
511.97552490234375,
222.26150512695312
],
"kps_scores": 1.0,
"index": 17
},
{
"kps": [
514.9180908203125,
224.3387908935547
],
"kps_scores": 1.0,
"index": 18
},
{
"kps": [
514.7620239257812,
224.2892608642578
],
"kps_scores": 1.0,
"index": 19
},
{
"kps": [
524.9593505859375,
230.30003356933594
],
"kps_scores": 1.0,
"index": 20
},
{
"kps": [
528.3402709960938,
232.76568603515625
],
"kps_scores": 1.0,
"index": 21
},
{
"kps": [
528.371826171875,
232.73399353027344
],
"kps_scores": 1.0,
"index": 22
},
{
"kps": [
538.7906494140625,
240.9889678955078
],
"kps_scores": 1.0,
"index": 23
},
{
"kps": [
538.7630004882812,
241.00299072265625
],
"kps_scores": 1.0,
"index": 24
},
{
"kps": [
550.0248413085938,
248.24708557128906
],
"kps_scores": 1.0,
"index": 25
},
{
"kps": [
554.3512573242188,
250.6501922607422
],
"kps_scores": 1.0,
"index": 26
},
{
"kps": [
554.0921020507812,
250.47769165039062
],
"kps_scores": 1.0,
"index": 27
},
{
"kps": [
567.93212890625,
266.1629943847656
],
"kps_scores": 1.0,
"index": 28
},
{
"kps": [
571.8528442382812,
273.5104675292969
],
"kps_scores": 1.0,
"index": 29
},
{
"kps": [
571.9888305664062,
273.5711669921875
],
"kps_scores": 1.0,
"index": 30
},
{
"kps": [
586.6533203125,
309.09576416015625
],
"kps_scores": 1.0,
"index": 31
},
{
"kps": [
591.8392944335938,
325.38385009765625
],
"kps_scores": 1.0,
"index": 32
},
{
"kps": [
592.3212280273438,
325.2934265136719
],
"kps_scores": 1.0,
"index": 33
},
{
"kps": [
603.3639526367188,
362.4980773925781
],
"kps_scores": 1.0,
"index": 34
}
]


@@ -0,0 +1,282 @@
[
{
"kps": [
461.0,
164.0
],
"kps_scores": 1.0,
"index": 0
},
{
"kps": [
460.9234619140625,
164.2275390625
],
"kps_scores": 1.0,
"index": 1
},
{
"kps": [
460.93524169921875,
164.19480895996094
],
"kps_scores": 1.0,
"index": 2
},
{
"kps": [
460.4592590332031,
164.14320373535156
],
"kps_scores": 1.0,
"index": 3
},
{
"kps": [
459.9245910644531,
164.054931640625
],
"kps_scores": 1.0,
"index": 4
},
{
"kps": [
459.8656921386719,
164.08154296875
],
"kps_scores": 1.0,
"index": 5
},
{
"kps": [
456.9087219238281,
163.1707305908203
],
"kps_scores": 1.0,
"index": 6
},
{
"kps": [
455.7566223144531,
162.69784545898438
],
"kps_scores": 1.0,
"index": 7
},
{
"kps": [
455.740478515625,
162.74818420410156
],
"kps_scores": 1.0,
"index": 8
},
{
"kps": [
449.8667907714844,
161.95462036132812
],
"kps_scores": 1.0,
"index": 9
},
{
"kps": [
447.55975341796875,
162.12559509277344
],
"kps_scores": 1.0,
"index": 10
},
{
"kps": [
447.5325012207031,
162.12460327148438
],
"kps_scores": 1.0,
"index": 11
},
{
"kps": [
439.9998474121094,
162.59873962402344
],
"kps_scores": 1.0,
"index": 12
},
{
"kps": [
437.3090515136719,
162.88577270507812
],
"kps_scores": 1.0,
"index": 13
},
{
"kps": [
437.2088623046875,
162.84994506835938
],
"kps_scores": 1.0,
"index": 14
},
{
"kps": [
429.199951171875,
164.5860595703125
],
"kps_scores": 1.0,
"index": 15
},
{
"kps": [
429.32745361328125,
164.66001892089844
],
"kps_scores": 1.0,
"index": 16
},
{
"kps": [
424.8293762207031,
166.40106201171875
],
"kps_scores": 1.0,
"index": 17
},
{
"kps": [
419.6496887207031,
168.80294799804688
],
"kps_scores": 1.0,
"index": 18
},
{
"kps": [
419.6795349121094,
168.93418884277344
],
"kps_scores": 1.0,
"index": 19
},
{
"kps": [
414.8919677734375,
172.65428161621094
],
"kps_scores": 1.0,
"index": 20
},
{
"kps": [
410.0992431640625,
175.77218627929688
],
"kps_scores": 1.0,
"index": 21
},
{
"kps": [
410.0442810058594,
175.911376953125
],
"kps_scores": 1.0,
"index": 22
},
{
"kps": [
400.20159912109375,
184.33380126953125
],
"kps_scores": 1.0,
"index": 23
},
{
"kps": [
396.4606628417969,
186.7172088623047
],
"kps_scores": 1.0,
"index": 24
},
{
"kps": [
396.3185119628906,
186.76808166503906
],
"kps_scores": 1.0,
"index": 25
},
{
"kps": [
382.623291015625,
192.941650390625
],
"kps_scores": 1.0,
"index": 26
},
{
"kps": [
376.8236999511719,
195.2269744873047
],
"kps_scores": 1.0,
"index": 27
},
{
"kps": [
376.66937255859375,
195.1109161376953
],
"kps_scores": 1.0,
"index": 28
},
{
"kps": [
362.7231750488281,
209.30923461914062
],
"kps_scores": 1.0,
"index": 29
},
{
"kps": [
355.9901123046875,
216.26303100585938
],
"kps_scores": 1.0,
"index": 30
},
{
"kps": [
356.3956298828125,
216.3310546875
],
"kps_scores": 1.0,
"index": 31
},
{
"kps": [
343.6780090332031,
235.2663116455078
],
"kps_scores": 1.0,
"index": 32
},
{
"kps": [
332.50238037109375,
261.8990783691406
],
"kps_scores": 1.0,
"index": 33
},
{
"kps": [
332.8721923828125,
261.7060546875
],
"kps_scores": 1.0,
"index": 34
}
]

File diff suppressed because one or more lines are too long

play.ipynb (new file, 3268 lines)

File diff suppressed because it is too large

play.py

@@ -31,13 +31,13 @@ from typing import (
TypeVar,
cast,
overload,
Iterable,
)
import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
@@ -46,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
from collections import defaultdict
from app.camera import (
Camera,
@@ -59,25 +60,31 @@ from app.camera import (
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import (
TrackingID,
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02" # dataset root path
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # load the camera parameters from parquet
DELTA_T_MIN = timedelta(milliseconds=1) # minimum time interval, now 1 ms
display(AK_CAMERA_DATASET) # show the camera parameters
# %%
class Resolution(TypedDict): # image resolution
width: int
height: int


class Intrinsic(TypedDict): # camera intrinsic parameters
camera_matrix: Num[Array, "3 3"]
"""
K
@@ -88,12 +95,12 @@ class Intrinsic(TypedDict):
"""


class Extrinsic(TypedDict): # camera extrinsic parameters
rvec: Num[NDArray, "3"]
tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict): # external camera parameters
name: str
port: int
intrinsic: Intrinsic
@@ -102,93 +109,93 @@ class ExternalCameraParams(TypedDict):

# %%
def read_dataset_by_port(port: int) -> ak.Array: # read a dataset by camera port number
P = DATASET_PATH / f"{port}.parquet" # dataset file path
return ak.from_parquet(P) # load it from parquet


KEYPOINT_DATASET = { # keypoint datasets keyed by port number
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}
# %%
class KeypointDataset(TypedDict): # one frame of keypoint detections
frame_index: int # frame index
boxes: Num[NDArray, "N 4"] # bounding boxes: N boxes, 4 coordinates each
kps: Num[NDArray, "N J 2"] # keypoints: N objects, J keypoints each, 2D coordinates
kps_scores: Num[NDArray, "N J"] # keypoint scores: N objects, J scores each
@jaxtyped(typechecker=beartype) # runtime-check array shapes against the annotations
def to_transformation_matrix( # combine a rotation vector and a translation vector into a 4x4 transform
rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
res = np.eye(4) # start from a 4x4 identity matrix
res[:3, :3] = R.from_rotvec(rvec).as_matrix() # rotation vector -> 3x3 rotation block
res[:3, 3] = tvec # translation in the last column
return res
@jaxtyped(typechecker=beartype)
def undistort_points( # undistort image points
points: Num[NDArray, "M 2"], # M points, (x, y) each
camera_matrix: Num[NDArray, "3 3"], # 3x3 camera intrinsic matrix
dist_coeffs: Num[NDArray, "N"], # N distortion coefficients
) -> Num[NDArray, "M 2"]: # the M undistorted points
K = camera_matrix
dist = dist_coeffs
res = undistortPoints(points, K, dist, P=K) # type: ignore # OpenCV point undistortion
return res.reshape(-1, 2) # reshape to M x 2
def from_camera_params(camera: ExternalCameraParams) -> Camera: # convert external camera parameters into a Camera object
rt = jnp.array(
to_transformation_matrix( # combine rvec and tvec into a homogeneous transform
ak.to_numpy(camera["extrinsic"]["rvec"]), # convert to NumPy arrays
ak.to_numpy(camera["extrinsic"]["tvec"]),
)
)
K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3) # intrinsic matrix, reshaped to 3x3
dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"]) # distortion coefficients
image_size = jnp.array( # image width and height as a JAX array
(camera["resolution"]["width"], camera["resolution"]["height"])
)
return Camera(
id=camera["name"],
params=CameraParams( # bundle all camera parameters
K=K, # intrinsic matrix
Rt=rt, # extrinsic matrix (homogeneous transform)
dist_coeffs=dist_coeffs, # distortion coefficients
image_size=image_size, # image resolution
),
)
def preprocess_keypoint_dataset( # turn a KeypointDataset sequence into a stream of Detection objects
dataset: Sequence[KeypointDataset], # input keypoint dataset sequence
camera: Camera, # camera parameters
fps: float, # frame rate (frames per second)
start_timestamp: datetime, # starting timestamp
) -> Generator[Detection, None, None]: # yields Detection objects
frame_interval_s = 1 / fps # seconds per frame
for el in dataset:
frame_index = el["frame_index"] # current frame index
timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
for kp, kp_score in zip(el["kps"], el["kps_scores"]):
yield Detection(
keypoints=jnp.array(kp), # keypoint coordinates
confidences=jnp.array(kp_score), # keypoint confidences
camera=camera, # camera parameters
timestamp=timestamp, # timestamp
)
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None] # type alias


# Synchronize multiple asynchronous detection streams by timestamp, yielding time-aligned batches.
def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta): # gens: detection generators; diff: maximum timestamp difference for two detections to share a batch
"""
given a list of detection generators, return a generator that yields a batch of detections
@@ -196,13 +203,13 @@ def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
gens: list of detection generators
diff: maximum timestamp difference between detections to consider them part of the same batch
"""
N = len(gens) # number of generators
last_batch_timestamp: Optional[datetime] = None # timestamp of the current batch
next_batch_timestamp: Optional[datetime] = None # timestamp of the next batch
current_batch: list[Detection] = [] # detections in the current batch
next_batch: list[Detection] = [] # detections in the next batch
paused: list[bool] = [False] * N # whether each generator is paused
finished: list[bool] = [False] * N # whether each generator is exhausted

def reset_paused():
"""
@@ -216,56 +223,56 @@ def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
EPS = 1e-6
# a small epsilon to avoid floating point precision issues
diff_esp = diff - timedelta(seconds=EPS) # guards against misclassification from tiny floating-point time differences

while True:
for i, gen in enumerate(gens):
try:
if finished[i] or paused[i]:
continue
val = next(gen) # fetch the next detection
if last_batch_timestamp is None: # timestamp comparison and batch assignment
last_batch_timestamp = val.timestamp
current_batch.append(val) # start the first batch
else:
if abs(val.timestamp - last_batch_timestamp) >= diff_esp:
next_batch.append(val) # over the threshold: goes to the next batch
if next_batch_timestamp is None:
next_batch_timestamp = val.timestamp
paused[i] = True # pause this generator until the batch switches
if all(paused):
yield current_batch # every generator paused: emit the current batch
current_batch = next_batch
next_batch = []
last_batch_timestamp = next_batch_timestamp
next_batch_timestamp = None
reset_paused() # clear the paused flags
else:
current_batch.append(val) # within the threshold: join the current batch
except StopIteration:
finished[i] = True
paused[i] = True
if all(finished):
if len(current_batch) > 0:
# All generators exhausted, flush remaining batch and exit
yield current_batch # emit the final batch
break
# %%
@overload
def to_projection_matrix( # combine a 4x4 transform and a 3x3 intrinsic matrix into a 3x4 projection matrix
transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...


@overload
def to_projection_matrix(
transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...


@jaxtyped(typechecker=beartype)
def to_projection_matrix( # compute the projection matrix (jitted below via jax.jit)
transformation_matrix: Num[Any, "4 4"],
camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
@@ -276,28 +283,29 @@ to_projection_matrix_jit = jax.jit(to_projection_matrix)
@jaxtyped(typechecker=beartype)
def dlt( # Direct Linear Transformation
H1: Num[NDArray, "3 4"], # projection matrix of the first camera (3x4)
H2: Num[NDArray, "3 4"], # projection matrix of the second camera (3x4)
p1: Num[NDArray, "2"], # image of the 3D point in the first camera (u1, v1)
p2: Num[NDArray, "2"], # image of the 3D point in the second camera (u2, v2)
) -> Num[NDArray, "3"]: # the 3D point (X, Y, Z)
"""
Direct Linear Transformation
"""
A = [ # build the matrix A
p1[1] * H1[2, :] - H1[1, :], # row 1: v1*H1[2,:] - H1[1,:]
H1[0, :] - p1[0] * H1[2, :], # row 2: H1[0,:] - u1*H1[2,:]
p2[1] * H2[2, :] - H2[1, :], # row 3: v2*H2[2,:] - H2[1,:]
H2[0, :] - p2[0] * H2[2, :], # row 4: H2[0,:] - u2*H2[2,:]
]
A = np.array(A).reshape((4, 4)) # as a 4x4 matrix

# solve the overdetermined system
B = A.transpose() @ A # A^T A (a 4x4 matrix)

from scipy import linalg

U, s, Vh = linalg.svd(B, full_matrices=False) # SVD
return Vh[3, 0:3] / Vh[3, 3] # take the last row and dehomogenize
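As a reading aid: with H_k^(r) denoting row r of projection matrix H_k, the rows above assemble the homogeneous system A X = 0; the triangulated point is the right singular vector for the smallest singular value of A (equivalently, the smallest eigenvector of A^T A), then dehomogenized:

```latex
A =
\begin{pmatrix}
v_1 H_1^{(3)} - H_1^{(2)} \\
H_1^{(1)} - u_1 H_1^{(3)} \\
v_2 H_2^{(3)} - H_2^{(2)} \\
H_2^{(1)} - u_2 H_2^{(3)}
\end{pmatrix},
\qquad
\hat{X} = \arg\min_{\lVert X \rVert = 1} \lVert A X \rVert^2
```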
@overload
@@ -309,7 +317,7 @@ def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ..
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean( # convert homogeneous coordinates to Euclidean coordinates
points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
"""
@@ -324,25 +332,31 @@ def homogeneous_to_euclidean(
return points[..., :-1] / points[..., -1:]
# %% Build keypoint-detection generators for the three cameras and synchronize them into time-aligned batches.
FPS = 24 # frame rate: 24 frames per second

# detection generators for the three cameras (ports 5600, 5601, 5602 are assumed to correspond to three different cameras)
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
display(1 / FPS) # per-frame interval, about 0.0417 s

# synchronize the three generators; detections within 1/FPS s of each other share a batch
sync_gen = sync_batch_gen(
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)
# %% Compute the affinity matrix between detections from different cameras via the epipolar constraint.
# Inputs: next(sync_gen) — one batch of detections (several cameras at nearby timestamps);
# alpha_2d=2000 — weight of the 2D-distance term, balancing the epipolar constraint against other cues (appearance, motion).
# Outputs: sorted_detections — the detections in sorted order;
# affinity_matrix — matrix[i][j] is the affinity between detections i and j (larger means more likely the same target).
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
next(sync_gen), alpha_2d=2000
)
display(sorted_detections)
# %% Visualize the key data of the multi-camera tracker: detection timestamps and the affinity matrix.
display( # show each sorted detection as a dict of timestamp and camera ID
list(
map(
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
@@ -350,12 +364,12 @@ display(
)
)
)

with jnp.printoptions(precision=3, suppress=True): # 3-digit precision, no scientific notation
display(affinity_matrix)
# %% Cluster detections that plausibly belong to the same target, based on the affinity matrix.
def clusters_to_detections( # map cluster indices back to Detection objects
clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
"""
@@ -372,17 +386,17 @@ def clusters_to_detections(
return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver() # initialize the GLPK linear-programming solver
aff_np = np.asarray(affinity_matrix).astype(np.float64) # affinity matrix as a float64 NumPy array
clusters, sol_matrix = solver.solve(aff_np) # solve the clustering problem
display(clusters)
display(sol_matrix)
# %% Two helpers for nested data structures.
T = TypeVar("T")


def flatten_values( # flatten all sequence values of a mapping into one list
d: Mapping[Any, Sequence[T]],
) -> list[T]:
"""
@@ -391,7 +405,7 @@ def flatten_values(
return [v for vs in d.values() for v in vs]


def flatten_values_len( # count the elements across all sequence values of a mapping
d: Mapping[Any, Sequence[T]],
) -> int:
"""
@@ -401,19 +415,22 @@ def flatten_values_len(
return val
# %% Project one target's keypoints, as seen from different cameras, onto a single image to eyeball the association accuracy.
WIDTH = 2560
HEIGHT = 1440

# convert the clustering result into lists of Detection objects
clusters_detections = clusters_to_detections(clusters, sorted_detections)

# create a blank (black) image
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)

# draw every detection in the first cluster (one target seen from several cameras)
for el in clusters_detections[0]:
im = visualize_whole_body(np.asarray(el.keypoints), im)

# show the result
p = plt.imshow(im)
display(p)

# %% Same as above for the second cluster, i.e. the second individual.
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
@@ -422,9 +439,9 @@ p_prime = plt.imshow(im_prime)
display(p_prime)
# %% Triangulate 3D points from multi-view image points.
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear( # single-point triangulation
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
confidences: Optional[Float[Array, "N"]] = None,
@@ -472,7 +489,7 @@ def triangulate_one_point_from_multiple_views_linear(


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear( # batched triangulation
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
confidences: Optional[Float[Array, "N P"]] = None,
@@ -504,9 +521,143 @@ def triangulate_points_from_multiple_views_linear(
return vmap_triangulate(proj_matrices, points, conf)
# %% The next two functions implement time-weighted multi-view triangulation.
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted( # single-point variant
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Triangulate one point from multiple views with time-weighted linear least squares.
Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2
Args:
proj_matrices: Shape (N, 3, 4) projection matrices sequence
points: Shape (N, 2) point coordinates sequence
delta_t: Time differences between current time and each observation (in seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N,) confidence values in range [0.0, 1.0]
Returns:
point_3d: Shape (3,) triangulated 3D point
"""
assert len(proj_matrices) == len(points)
assert len(delta_t) == len(points)
N = len(proj_matrices)
# Prepare confidence weights
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
# First build the coefficient matrix without weights
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
# Then apply the time-based and confidence weights
for i in range(N):
# Calculate time-decay weight: e^(-λ_t * Δt)
time_weight = jnp.exp(-lambda_t * delta_t[i])
# Calculate normalization factor: ||c^i^T||_2
row_norm_1 = jnp.linalg.norm(A[2 * i])
row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
# Apply combined weight: time_weight / row_norm * confidence
w1 = (time_weight / row_norm_1) * confi[i]
w2 = (time_weight / row_norm_2) * confi[i]
A = A.at[2 * i].mul(w1)
A = A.at[2 * i + 1].mul(w2)
# Solve using SVD
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# Ensure homogeneous coordinate is positive
point_3d_homo = jnp.where(
point_3d_homo[3] < 0,
-point_3d_homo,
point_3d_homo,
)
# Convert from homogeneous to Euclidean coordinates
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted( # batched variant
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Vectorized version that triangulates P points from N camera views with time-weighting.
This function uses JAX's vmap to efficiently triangulate multiple points in parallel.
Args:
proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
points: Shape (N, P, 2) 2D points for P keypoints across N cameras
delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N, P) confidence values for each point in each camera
Returns:
points_3d: Shape (P, 3) triangulated 3D points
"""
N, P, _ = points.shape
assert (
proj_matrices.shape[0] == N
), "Number of projection matrices must match number of cameras"
assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
if confidences is None:
# Create uniform confidences if none provided
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = confidences
# Define the vmapped version of the single-point function
# We map over the second dimension (P points) of the input arrays
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(
            None,  # proj_matrices: shared across points
            1,  # points: map over the P axis
            None,  # delta_t: shared across points
            None,  # lambda_t: scalar, shared
            1,  # confidences: map over the P axis
        ),
out_axes=0, # Output has first dimension corresponding to points
)
# For each point p, extract the 2D coordinates from all cameras and triangulate
return vmap_triangulate(
proj_matrices, # (N, 3, 4) - static across points
points, # (N, P, 2) - map over dim 1 (P)
delta_t, # (N,) - static across points
lambda_t, # scalar - static
conf, # (N, P) - map over dim 1 (P)
)
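
# %%
# Illustrative batch sketch: the vmapped variant triangulates P keypoints at
# once. These 2D observations reuse the made-up matrices _P0/_P1 above and are
# consistent with three points on the z = 2 plane.
_demo_points_2d = jnp.stack(
    [
        jnp.array([[0.25, 0.0], [0.3, 0.1], [0.2, -0.1]]),  # camera 0: (P, 2)
        jnp.array([[0.0, 0.0], [0.05, 0.1], [-0.05, -0.1]]),  # camera 1: (P, 2)
    ]
)  # shape (N=2, P=3, 2)
_demo_points_3d = triangulate_points_from_multiple_views_linear_time_weighted(
    jnp.stack([_P0, _P1]),
    _demo_points_2d,
    jnp.array([0.0, 0.05]),
)
assert _demo_points_3d.shape == (3, 3)  # (P, 3); rows ≈ (0.5, 0, 2), (0.6, 0.2, 2), (0.4, -0.2, 2)
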
# %%  # Triangulate 3D keypoints from one cluster of detections and return the cluster's latest timestamp.
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],

@ -523,8 +674,23 @@ def triangle_from_cluster(
    )

# %%  # Core logic of the multi-target tracker: create and manage the global tracking state from clustered detections.
+def group_by_cluster_by_camera(  # group detections by camera
+    cluster: Sequence[Detection],
+) -> PMap[CameraID, Detection]:
+    """
+    Group the detections by camera, preserving the latest detection for each camera.
+    """
+    r: dict[CameraID, Detection] = {}
+    for el in cluster:
+        if el.camera.id in r:
+            eld = r[el.camera.id]
+            el = max([eld, el], key=lambda x: x.timestamp)
+        # always record the (possibly newer) detection for this camera
+        r[el.camera.id] = el
+    return pmap(r)
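
A minimal stand-alone sketch of the keep-latest-per-key pattern used above, with plain tuples standing in for `Detection` objects (the `cam`/`ts` fields are hypothetical stand-ins for `Detection.camera.id` and `Detection.timestamp`):

```python
demo = [("c1", 1.0), ("c2", 2.0), ("c1", 3.0)]  # (camera id, timestamp) pairs
latest: dict[str, float] = {}
for cam, ts in demo:
    # keep only the newest timestamp seen for each camera
    if cam not in latest or ts > latest[cam]:
        latest[cam] = ts
assert latest == {"c1": 3.0, "c2": 2.0}
```
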
class GlobalTrackingState:  # global tracking-state container
    _last_id: int
    _trackings: dict[int, Tracking]

@ -541,14 +707,22 @@ class GlobalTrackingState:
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:  # create a new tracking record for a cluster
+        if len(cluster) < 2:
+            raise ValueError(
+                "cluster must contain at least 2 detections to form a tracking"
+            )
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
-        tracking = Tracking(
-            id=next_id,
+        tracking_state = TrackingState(
            keypoints=kps_3d,
            last_active_timestamp=latest_timestamp,
-            historical_detections=v(*cluster),
+            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
+        )
+        tracking = Tracking(
+            id=next_id,
+            state=tracking_state,
+            velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
@ -560,14 +734,14 @@ for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)

# %%  # Fetch the next batch of time-aligned detections from the sync generator and visualize it.
next_group = next(sync_gen)  # pull the next batch of detections from the sync generator
display(next_group)  # render the batch in a Jupyter environment

# %%  # Core affinity computations for cross-camera association.
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(  # normalized 2D distance
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),

@ -596,7 +770,7 @@ def calculate_distance_2d(
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(  # 2D affinity score
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,

@ -629,7 +803,7 @@ def calculate_affinity_2d(
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(  # perpendicular distance from a point to a ray
    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
    """

@ -652,6 +826,7 @@ def perpendicular_distance_point_to_line_two_points(
@jaxtyped(typechecker=beartype)
+# Ray-casting distance for multi-camera 3D reconstruction: measures how well a camera's 2D keypoints match a 3D tracking.
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,

@ -671,11 +846,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        Array of perpendicular distances for each keypoint
    """
    camera = detection.camera
-    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
-    # avoid division-by-zero / exploding affinities.
-    delta_t = max(delta_t, DELTA_T_MIN)
-    delta_t_s = delta_t.total_seconds()
-    predicted_pose = tracking.predict(delta_t_s)
+    predicted_pose = tracking.predict(delta_t)

    # Back-project the 2D points to 3D space
    # intersection with z=0 plane
@ -699,7 +870,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(  # 3D affinity score
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,

@ -730,7 +901,7 @@ def calculate_affinity_3d(
@beartype
def calculate_tracking_detection_affinity(  # combined 2D + 3D affinity pipeline
    tracking: Tracking,
    detection: Detection,
    w_2d: float,

@ -755,12 +926,12 @@ def calculate_tracking_detection_affinity(
        Combined affinity score
    """
    camera = detection.camera
-    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
    # Clamp delta_t to avoid division-by-zero / exploding affinity.
    delta_t = max(delta_t_raw, DELTA_T_MIN)

    # Calculate 2D affinity
-    tracking_2d_projection = camera.project(tracking.keypoints)
+    tracking_2d_projection = camera.project(tracking.state.keypoints)
    w, h = camera.params.image_size
    distance_2d = calculate_distance_2d(
        tracking_2d_projection,

@ -792,9 +963,9 @@ def calculate_tracking_detection_affinity(
    return jnp.sum(total_affinity).item()

# %%  # Efficient affinity-matrix computation: the bridge between existing trackings and incoming detections.
@beartype
def calculate_camera_affinity_matrix_jax(  # per-camera affinity matrix
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    w_2d: float,

@ -840,7 +1011,7 @@ def calculate_camera_affinity_matrix_jax(
    # === Tracking-side tensors ===
    kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
-        [trk.keypoints for trk in trackings]
+        [trk.state.keypoints for trk in trackings]
    )  # (T, J, 3)
    J = kps3d_trk.shape[1]

    # === Detection-side tensors ===

@ -857,12 +1028,12 @@
    # --- timestamps ----------
    t0 = min(
        chain(
-            (trk.last_active_timestamp for trk in trackings),
+            (trk.state.last_active_timestamp for trk in trackings),
            (det.timestamp for det in camera_detections),
        )
    ).timestamp()  # common origin (float)
    ts_trk = jnp.array(
-        [trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
+        [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
        dtype=jnp.float32,  # now small, ms-scale fits in fp32
    )
    ts_det = jnp.array(

@ -956,7 +1127,7 @@ def calculate_camera_affinity_matrix_jax(
@beartype
def calculate_affinity_matrix(  # multi-camera affinity matrix
    trackings: Sequence[Tracking],
    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
    w_2d: float,

@ -1008,7 +1179,7 @@ def calculate_affinity_matrix(
    return res

# %%  # Cross-view association pipeline.
# let's do cross-view association
W_2D = 1.0
ALPHA_2D = 1.0

@ -1032,9 +1203,83 @@ affinities = calculate_affinity_matrix(
display(affinities)
# %%  # Two functions implementing association-result aggregation and track updating.
-def update_tracking(tracking: Tracking, detection: Detection):
-    delta_t_ = detection.timestamp - tracking.last_active_timestamp
-    delta_t = max(delta_t_, DELTA_T_MIN)
-    return tracking
+def affinity_result_by_tracking(  # aggregate association results
+    results: Iterable[AffinityResult],
+    min_affinity: float = 0.0,
+) -> dict[TrackingID, list[Detection]]:
+    """
+    Group affinity results by target ID.
+
+    Args:
+        results: the affinity results to group
+        min_affinity: the minimum affinity to consider
+
+    Returns:
+        a dictionary mapping tracking IDs to a list of detections
+    """
+    res: dict[TrackingID, list[Detection]] = defaultdict(list)
+    for affinity_result in results:
+        for affinity, t, d in affinity_result.tracking_association():
+            if affinity < min_affinity:
+                continue
+            res[t.id].append(d)
+    return res
+
+
+def update_tracking(  # track-update flow
+    tracking: Tracking,
+    detections: Sequence[Detection],
+    max_delta_t: timedelta = timedelta(milliseconds=100),
+    lambda_t: float = 10.0,
+) -> None:
+    """
+    Update the tracking with a new set of detections.
+
+    Args:
+        tracking: the tracking to update
+        detections: the detections to update the tracking with
+        max_delta_t: the maximum age a per-camera detection may have, relative to the latest detection, before it is dropped
+        lambda_t: the time-decay rate used for the triangulation weighting
+
+    Note:
+        this function mutates the tracking object
+    """
+    latest_timestamp = max(d.timestamp for d in detections)
+    d = thaw(tracking.state.historical_detections_by_camera)
+    for detection in detections:
+        d[detection.camera.id] = detection
+    # iterate over a snapshot so stale entries can be deleted during the loop
+    for camera_id, detection in list(d.items()):
+        # drop detections that lag the newest observation by more than max_delta_t
+        if latest_timestamp - detection.timestamp > max_delta_t:
+            del d[camera_id]
+    new_detections = freeze(d)
+    new_detections_list = list(new_detections.values())
+    project_matrices = jnp.stack(
+        [detection.camera.params.projection_matrix for detection in new_detections_list]
+    )
+    # Δt_i = t_latest - t_i, so older observations get smaller weights,
+    # matching w_i = exp(-λ_t (t - t_i)) from the triangulation docstring
+    delta_t = jnp.array(
+        [
+            latest_timestamp.timestamp() - detection.timestamp.timestamp()
+            for detection in new_detections_list
+        ]
+    )
+    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
+    conf = jnp.stack([detection.confidences for detection in new_detections_list])
+    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
+        project_matrices, kps, delta_t, lambda_t, conf
+    )
+    new_state = TrackingState(
+        keypoints=kps_3d,
+        last_active_timestamp=latest_timestamp,
+        historical_detections_by_camera=new_detections,
+    )
+    tracking.update(kps_3d, latest_timestamp)
+    tracking.state = new_state

# %%  # Core track-update loop of the multi-target tracker.
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())  # 1. aggregate matched detections from all cameras by tracking ID
for tracking_id, detections in affinity_results_by_tracking.items():  # 2. update each tracking with its matched detections
    update_tracking(global_tracking_state.trackings[tracking_id], detections)
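
For intuition about the default `lambda_t = 10.0` used above, a quick numeric check (illustrative values only):

```python
import math

# A camera lagging 50 ms behind the newest observation keeps
# exp(-10 * 0.05) ≈ 61% of a fresh camera's triangulation weight.
w = math.exp(-10.0 * 0.05)
assert abs(w - 0.6065) < 1e-3
```

At the edge of the 100 ms `max_delta_t` window the weight bottoms out at exp(-1) ≈ 37%, so stale cameras still contribute, just weakly.
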
# %%

plot_epipolar_lines.ipynb: new file, 406 lines (diff suppressed because one or more lines are too long)

pyproject.toml: View File

@ -14,6 +14,7 @@ dependencies = [
"jaxtyping>=0.2.38", "jaxtyping>=0.2.38",
"jupytext>=1.17.0", "jupytext>=1.17.0",
"matplotlib>=3.10.1", "matplotlib>=3.10.1",
"more-itertools>=10.7.0",
"opencv-python-headless>=4.11.0.86", "opencv-python-headless>=4.11.0.86",
"optax>=0.2.4", "optax>=0.2.4",
"orjson>=3.10.15", "orjson>=3.10.15",
@ -23,6 +24,7 @@ dependencies = [
"pyrsistent>=0.20.0", "pyrsistent>=0.20.0",
"pytest>=8.3.5", "pytest>=8.3.5",
"scipy>=1.15.2", "scipy>=1.15.2",
"shapely>=2.1.1",
"torch>=2.6.0", "torch>=2.6.0",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"typeguard>=4.4.2", "typeguard>=4.4.2",

(file diff suppressed because it is too large)

rebuild_by_epipolar_line.py: new file, 1062 lines (diff suppressed because it is too large)

(file diff suppressed because it is too large)

smooth_3d_kps.ipynb: new file, 122 lines. View File

@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"id": "0d48b7eb",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "dfd27584",
"metadata": {},
"outputs": [],
"source": [
"KPS_PATH = Path(\"samples/WeiHua_03.json\")\n",
"with open(KPS_PATH, \"r\") as file:\n",
" data = json.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "360f9c50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'index:1, shape: (33, 133, 3)'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'index:2, shape: (662, 133, 3)'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for item_object_index in data.keys():\n",
" item_object = np.array(data[item_object_index])\n",
" display(f'index:{item_object_index}, shape: {item_object.shape}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 对data['2']的662帧3d关键点数据进行滑动窗口平滑处理\n",
"object_points = np.array(data['2']) # shape: (662, 133, 3)\n",
"window_size = 5\n",
"kernel = np.ones(window_size) / window_size\n",
"# 对每个关键点的每个坐标轴分别做滑动平均\n",
"smoothed_points = np.zeros_like(object_points)\n",
"# 遍历133个关节\n",
"for kp_idx in range(object_points.shape[1]):\n",
" # 遍历每个关节的空间三维坐标点\n",
" for axis in range(3):\n",
" # 对第i帧的滑动平滑方式 smoothed[i] = (point[i-2] + point[i-1] + point[i] + point[i+1] + point[i+2]) / 5\n",
" smoothed_points[:, kp_idx, axis] = np.convolve(object_points[:, kp_idx, axis], kernel, mode='same')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "24c6c0c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'smoothed_points shape: (662, 133, 3)'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(f'smoothed_points shape: {smoothed_points.shape}')\n",
"with open(\"samples/smoothed_3d_kps.json\", \"w\") as file:\n",
" json.dump({'1':smoothed_points.tolist()}, file)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cvth3pe",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
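
The double loop in the smoothing cell above runs 133 × 3 separate convolutions. Below is a sketch of an equivalent vectorized version, assuming the same `(frames, joints, 3)` layout; `scipy.ndimage.uniform_filter1d` with `mode='constant'` (zero padding) matches `np.convolve(..., mode='same')` at the sequence boundaries for an odd window size:

```python
import numpy as np
from scipy.ndimage import uniform_filter1d

object_points = np.random.rand(662, 133, 3)  # stand-in for the notebook's data['2']

# One call filters along the frame axis for every joint and coordinate at once.
smoothed = uniform_filter1d(object_points, size=5, axis=0, mode="constant", cval=0.0)

# Spot-check against the notebook's per-series np.convolve approach.
reference = np.convolve(object_points[:, 0, 0], np.ones(5) / 5, mode="same")
assert np.allclose(smoothed[:, 0, 0], reference)
```
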

View File (unnamed notebook, 193 lines)

@ -0,0 +1,193 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "11cc2345",
"metadata": {},
"outputs": [],
"source": [
"import awkward as ak\n",
"import numpy as np\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "84348d97",
"metadata": {},
"outputs": [],
"source": [
"CAMERA_INDEX ={\n",
" 2:\"5602\",\n",
" 4:\"5604\",\n",
"}\n",
"index = 4\n",
"CAMERA_PATH = Path(\"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params\")\n",
"camera_data = ak.from_parquet(CAMERA_PATH / CAMERA_INDEX[index]/ \"extrinsic.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1d771740",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>[{rvec: [[-2.26], [0.0669], [-2.15]], tvec: [[0.166], ...]},\n",
" {rvec: [[2.07], [0.144], [2.21]], tvec: [[0.143], ...]},\n",
" {rvec: [[2.09], [0.0872], [2.25]], tvec: [[0.141], ...]},\n",
" {rvec: [[2.16], [0.172], [2.09]], tvec: [[0.162], ...]},\n",
" {rvec: [[2.15], [0.18], [2.09]], tvec: [[0.162], ...]},\n",
" {rvec: [[-2.22], [0.117], [-2.14]], tvec: [[0.162], ...]},\n",
" {rvec: [[2.18], [0.176], [2.08]], tvec: [[0.166], ...]},\n",
" {rvec: [[2.18], [0.176], [2.08]], tvec: [[0.166], ...]},\n",
" {rvec: [[-2.26], [0.116], [-2.1]], tvec: [[0.17], ...]},\n",
" {rvec: [[-2.26], [0.124], [-2.09]], tvec: [[0.171], ...]},\n",
" ...,\n",
" {rvec: [[-2.2], [0.0998], [-2.17]], tvec: [[0.158], ...]},\n",
" {rvec: [[-2.2], [0.0998], [-2.17]], tvec: [[0.158], ...]},\n",
" {rvec: [[2.12], [0.151], [2.16]], tvec: [[0.152], ...]},\n",
" {rvec: [[-2.3], [0.0733], [-2.1]], tvec: [[0.175], ...]},\n",
" {rvec: [[2.1], [0.16], [2.17]], tvec: [[0.149], ...]},\n",
" {rvec: [[2.1], [0.191], [2.13]], tvec: [[0.153], ...]},\n",
" {rvec: [[2.11], [0.196], [2.12]], tvec: [[0.154], ...]},\n",
" {rvec: [[2.19], [0.171], [2.08]], tvec: [[0.166], ...]},\n",
" {rvec: [[2.24], [0.0604], [2.12]], tvec: [[0.166], ...]}]\n",
"---------------------------------------------------------------------------\n",
"backend: cpu\n",
"nbytes: 10.1 kB\n",
"type: 90 * {\n",
" rvec: var * var * float64,\n",
" tvec: var * var * float64\n",
"}</pre>"
],
"text/plain": [
"<Array [{rvec: [...], tvec: [...]}, ..., {...}] type='90 * {rvec: var * var...'>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(camera_data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "59fde11b",
"metadata": {},
"outputs": [],
"source": [
"data = []\n",
"for element in camera_data:\n",
" rvec = element[\"rvec\"]\n",
" if rvec[0]<0:\n",
" data.append({\"rvec\": rvec, \"tvec\": element[\"tvec\"]})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "4792cbc4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<pyarrow._parquet.FileMetaData object at 0x7799cbf62d40>\n",
" created_by: parquet-cpp-arrow version 19.0.1\n",
" num_columns: 2\n",
" num_rows: 30\n",
" num_row_groups: 1\n",
" format_version: 2.6\n",
" serialized_size: 0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ak.to_parquet(ak.from_iter(data),\"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params/5604/re_extrinsic.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8225ee33",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre>[{rvec: [[-2.26], [0.0669], [-2.15]], tvec: [[0.166], ...]},\n",
" {rvec: [[-2.22], [0.117], [-2.14]], tvec: [[0.162], ...]},\n",
" {rvec: [[-2.26], [0.116], [-2.1]], tvec: [[0.17], ...]},\n",
" {rvec: [[-2.26], [0.124], [-2.09]], tvec: [[0.171], ...]},\n",
" {rvec: [[-2.24], [0.133], [-2.11]], tvec: [[0.167], ...]},\n",
" {rvec: [[-2.22], [0.0556], [-2.2]], tvec: [[0.158], ...]},\n",
" {rvec: [[-2.27], [0.119], [-2.09]], tvec: [[0.172], ...]},\n",
" {rvec: [[-2.34], [0.0663], [-2.06]], tvec: [[0.181], ...]},\n",
" {rvec: [[-2.21], [0.117], [-2.15]], tvec: [[0.161], ...]},\n",
" {rvec: [[-2.33], [0.0731], [-2.08]], tvec: [[0.179], ...]},\n",
" ...,\n",
" {rvec: [[-2.23], [0.106], [-2.13]], tvec: [[0.166], ...]},\n",
" {rvec: [[-2.21], [0.054], [-2.2]], tvec: [[0.157], ...]},\n",
" {rvec: [[-2.19], [0.0169], [-2.25]], tvec: [[0.151], ...]},\n",
" {rvec: [[-2.2], [0.0719], [-2.19]], tvec: [[0.157], ...]},\n",
" {rvec: [[-2.22], [0.0726], [-2.18]], tvec: [[0.161], ...]},\n",
" {rvec: [[-2.2], [0.0742], [-2.19]], tvec: [[0.158], ...]},\n",
" {rvec: [[-2.2], [0.0998], [-2.17]], tvec: [[0.158], ...]},\n",
" {rvec: [[-2.2], [0.0998], [-2.17]], tvec: [[0.158], ...]},\n",
" {rvec: [[-2.3], [0.0733], [-2.1]], tvec: [[0.175], ...]}]\n",
"---------------------------------------------------------------------------\n",
"backend: cpu\n",
"nbytes: 3.4 kB\n",
"type: 30 * {\n",
" rvec: var * var * float64,\n",
" tvec: var * var * float64\n",
"}</pre>"
],
"text/plain": [
"<Array [{rvec: [...], tvec: [...]}, ..., {...}] type='30 * {rvec: var * var...'>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"temp_data = ak.from_parquet(\"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params/5604/re_extrinsic.parquet\")\n",
"display(temp_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
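
Context for the `rvec[0] < 0` filter in the cell above: a Rodrigues vector is not unique; the axis-angle pairs (a, θ) and (-a, 2π - θ) encode the same rotation, so per-frame solvers can flip between the two branches, and averaging across branches would be meaningless. A hedged sketch (fabricated numbers, chosen to resemble the values in the parquet dump) verifying the equivalence with OpenCV:

```python
import cv2
import numpy as np

# Hypothetical axis-angle pair near the magnitudes seen in the extrinsic samples.
axis = np.array([-0.72, 0.02, -0.69])
axis /= np.linalg.norm(axis)
theta = 3.12

rvec_a = (theta * axis).reshape(3, 1)  # the "negative branch" form
rvec_b = ((2 * np.pi - theta) * -axis).reshape(3, 1)  # the flipped form

R_a, _ = cv2.Rodrigues(rvec_a)
R_b, _ = cv2.Rodrigues(rvec_b)
assert np.allclose(R_a, R_b, atol=1e-9)  # both encode the same rotation matrix
```

Filtering to a single branch before any averaging or export keeps the representation consistent.
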

uv.lock: generated, 3745 lines (diff suppressed because it is too large)