1
0
forked from HQU-gxy/CVTH3PE

single peopele detect and tracking

This commit is contained in:
2025-06-13 16:02:15 +08:00
parent eb9738cb02
commit 492b4fba04
8 changed files with 6734 additions and 1874 deletions

View File

@ -31,13 +31,13 @@ from typing import (
TypeVar,
cast,
overload,
Iterable,
)
import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
@ -46,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from pyrsistent import v, pvector
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
from collections import defaultdict
from app.camera import (
Camera,
@ -59,17 +60,21 @@ from app.camera import (
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import AffinityResult, Tracking
from app.tracking import (
TrackingID,
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
# %%
CAMERA_PATH = Path(
"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params"
)
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(CAMERA_PATH / "camera_params.parquet")
DELTA_T_MIN = timedelta(milliseconds=10)
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # type: ignore
DELTA_T_MIN = timedelta(milliseconds=1)
display(AK_CAMERA_DATASET)
@ -104,13 +109,6 @@ class ExternalCameraParams(TypedDict):
# %%
# %%
DATASET_PATH = Path(
"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/detect_result/segement_1"
)
def read_dataset_by_port(port: int) -> ak.Array:
P = DATASET_PATH / f"{port}.parquet"
return ak.from_parquet(P)
@ -119,7 +117,6 @@ def read_dataset_by_port(port: int) -> ak.Array:
KEYPOINT_DATASET = {
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}
display(KEYPOINT_DATASET)
# %%
@ -194,8 +191,6 @@ def preprocess_keypoint_dataset(
)
# %%
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
@ -338,31 +333,13 @@ def homogeneous_to_euclidean(
# %%
FPS = 24
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5603 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5603], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5603][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5604 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5604], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5604][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5605 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5605], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5605][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5606 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5606], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5606][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5607 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5607], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5607][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5608 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5608], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5608][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5609 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5609], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5609][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
display(1 / FPS)
sync_gen = sync_batch_gen(
[
image_gen_5601,
# image_gen_5602,
# image_gen_5603,
image_gen_5604,
image_gen_5605,
image_gen_5606,
# image_gen_5607,
image_gen_5608,
image_gen_5609,
],
timedelta(seconds=1 / FPS),
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)
# %%
@ -375,7 +352,7 @@ display(sorted_detections)
display(
list(
map(
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id, "keypoint":x.keypoints.shape},
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
sorted_detections,
)
)
@ -443,7 +420,6 @@ for el in clusters_detections[0]:
p = plt.imshow(im)
display(p)
# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
@ -535,6 +511,142 @@ def triangulate_points_from_multiple_views_linear(
return vmap_triangulate(proj_matrices, points, conf)
# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Triangulate one point from multiple views with time-weighted linear least squares.
Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2
Args:
proj_matrices: Shape (N, 3, 4) projection matrices sequence
points: Shape (N, 2) point coordinates sequence
delta_t: Time differences between current time and each observation (in seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N,) confidence values in range [0.0, 1.0]
Returns:
point_3d: Shape (3,) triangulated 3D point
"""
assert len(proj_matrices) == len(points)
assert len(delta_t) == len(points)
N = len(proj_matrices)
# Prepare confidence weights
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
# First build the coefficient matrix without weights
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
# Then apply the time-based and confidence weights
for i in range(N):
# Calculate time-decay weight: e^(-λ_t * Δt)
time_weight = jnp.exp(-lambda_t * delta_t[i])
# Calculate normalization factor: ||c^i^T||_2
row_norm_1 = jnp.linalg.norm(A[2 * i])
row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
# Apply combined weight: time_weight / row_norm * confidence
w1 = (time_weight / row_norm_1) * confi[i]
w2 = (time_weight / row_norm_2) * confi[i]
A = A.at[2 * i].mul(w1)
A = A.at[2 * i + 1].mul(w2)
# Solve using SVD
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# Ensure homogeneous coordinate is positive
point_3d_homo = jnp.where(
point_3d_homo[3] < 0,
-point_3d_homo,
point_3d_homo,
)
# Convert from homogeneous to Euclidean coordinates
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Vectorized version that triangulates P points from N camera views with time-weighting.
This function uses JAX's vmap to efficiently triangulate multiple points in parallel.
Args:
proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
points: Shape (N, P, 2) 2D points for P keypoints across N cameras
delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N, P) confidence values for each point in each camera
Returns:
points_3d: Shape (P, 3) triangulated 3D points
"""
N, P, _ = points.shape
assert (
proj_matrices.shape[0] == N
), "Number of projection matrices must match number of cameras"
assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
if confidences is None:
# Create uniform confidences if none provided
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = confidences
# Define the vmapped version of the single-point function
# We map over the second dimension (P points) of the input arrays
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear_time_weighted,
in_axes=(
None,
1,
None,
None,
1,
), # proj_matrices and delta_t static, map over points
out_axes=0, # Output has first dimension corresponding to points
)
# For each point p, extract the 2D coordinates from all cameras and triangulate
return vmap_triangulate(
proj_matrices, # (N, 3, 4) - static across points
points, # (N, P, 2) - map over dim 1 (P)
delta_t, # (N,) - static across points
lambda_t, # scalar - static
conf, # (N, P) - map over dim 1 (P)
)
# %%
@ -555,6 +667,21 @@ def triangle_from_cluster(
# %%
def group_by_cluster_by_camera(
cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
"""
group the detections by camera, and preserve the latest detection for each camera
"""
r: dict[CameraID, Detection] = {}
for el in cluster:
if el.camera.id in r:
eld = r[el.camera.id]
preserved = max([eld, el], key=lambda x: x.timestamp)
r[el.camera.id] = preserved
return pmap(r)
class GlobalTrackingState:
_last_id: int
_trackings: dict[int, Tracking]
@ -573,13 +700,21 @@ class GlobalTrackingState:
return shallow_copy(self._trackings)
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
if len(cluster) < 2:
raise ValueError(
"cluster must contain at least 2 detections to form a tracking"
)
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
id=next_id,
tracking_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections=v(*cluster),
historical_detections_by_camera=group_by_cluster_by_camera(cluster),
)
tracking = Tracking(
id=next_id,
state=tracking_state,
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
)
self._trackings[next_id] = tracking
self._last_id = next_id
@ -702,11 +837,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
Array of perpendicular distances for each keypoint
"""
camera = detection.camera
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
# avoid division-by-zero / exploding affinities.
delta_t = max(delta_t, DELTA_T_MIN)
delta_t_s = delta_t.total_seconds()
predicted_pose = tracking.predict(delta_t_s)
predicted_pose = tracking.predict(delta_t)
# Back-project the 2D points to 3D space
# intersection with z=0 plane
@ -786,12 +917,12 @@ def calculate_tracking_detection_affinity(
Combined affinity score
"""
camera = detection.camera
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
# Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN)
# Calculate 2D affinity
tracking_2d_projection = camera.project(tracking.keypoints)
tracking_2d_projection = camera.project(tracking.state.keypoints)
w, h = camera.params.image_size
distance_2d = calculate_distance_2d(
tracking_2d_projection,
@ -871,7 +1002,7 @@ def calculate_camera_affinity_matrix_jax(
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.keypoints for trk in trackings]
[trk.state.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
# === Detection-side tensors ===
@ -888,12 +1019,12 @@ def calculate_camera_affinity_matrix_jax(
# --- timestamps ----------
t0 = min(
chain(
(trk.last_active_timestamp for trk in trackings),
(trk.state.last_active_timestamp for trk in trackings),
(det.timestamp for det in camera_detections),
)
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
@ -1064,8 +1195,82 @@ display(affinities)
# %%
def update_tracking(tracking: Tracking, detection: Detection):
delta_t_ = detection.timestamp - tracking.last_active_timestamp
delta_t = max(delta_t_, DELTA_T_MIN)
def affinity_result_by_tracking(
results: Iterable[AffinityResult],
min_affinity: float = 0.0,
) -> dict[TrackingID, list[Detection]]:
"""
Group affinity results by target ID.
return tracking
Args:
results: the affinity results to group
min_affinity: the minimum affinity to consider
Returns:
a dictionary mapping tracking IDs to a list of detections
"""
res: dict[TrackingID, list[Detection]] = defaultdict(list)
for affinity_result in results:
for affinity, t, d in affinity_result.tracking_association():
if affinity < min_affinity:
continue
res[t.id].append(d)
return res
def update_tracking(
tracking: Tracking,
detections: Sequence[Detection],
max_delta_t: timedelta = timedelta(milliseconds=100),
lambda_t: float = 10.0,
) -> None:
"""
update the tracking with a new set of detections
Args:
tracking: the tracking to update
detections: the detections to update the tracking with
max_delta_t: the maximum time difference between the last active timestamp and the latest detection
lambda_t: the lambda value for the time difference
Note:
the function would mutate the tracking object
"""
last_active_timestamp = tracking.state.last_active_timestamp
latest_timestamp = max(d.timestamp for d in detections)
d = thaw(tracking.state.historical_detections_by_camera)
for detection in detections:
d[detection.camera.id] = detection
for camera_id, detection in d.items():
if detection.timestamp - latest_timestamp > max_delta_t:
del d[camera_id]
new_detections = freeze(d)
new_detections_list = list(new_detections.values())
project_matrices = jnp.stack(
[detection.camera.params.projection_matrix for detection in new_detections_list]
)
delta_t = jnp.array(
[
detection.timestamp.timestamp() - last_active_timestamp.timestamp()
for detection in new_detections_list
]
)
kps = jnp.stack([detection.keypoints for detection in new_detections_list])
conf = jnp.stack([detection.confidences for detection in new_detections_list])
kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
project_matrices, kps, delta_t, lambda_t, conf
)
new_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections_by_camera=new_detections,
)
tracking.update(kps_3d, latest_timestamp)
tracking.state = new_state
# %%
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
for tracking_id, detections in affinity_results_by_tracking.items():
update_tracking(global_tracking_state.trackings[tracking_id], detections)
# %%