CVTH3PE/playground.py
crosstyan 5b5ccbd92b feat: Enhance playground.py with new tracking and affinity calculation functionalities
- Introduced new functions for calculating 2D distances and affinities between detections, improving tracking capabilities.
- Added a `Tracking` dataclass with detailed docstrings for better clarity on its attributes.
- Refactored code to utilize `shallow_copy` for handling detections and improved organization of imports.
- Enhanced the cross-view association logic to accommodate the new functionalities, ensuring better integration with existing tracking systems.
2025-04-27 16:03:05 +08:00

# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.0
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---
# %%
from copy import copy as shallow_copy
from copy import deepcopy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
Any,
Generator,
Optional,
Sequence,
TypeAlias,
TypedDict,
cast,
overload,
)
import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R
from app.camera import (
Camera,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
display(AK_CAMERA_DATASET)
# %%
class Resolution(TypedDict):
width: int
height: int
class Intrinsic(TypedDict):
camera_matrix: Num[Array, "3 3"]
"""
K
"""
distortion_coefficients: Num[Array, "N"]
"""
distortion coefficients; usually 5
"""
class Extrinsic(TypedDict):
rvec: Num[NDArray, "3"]
tvec: Num[NDArray, "3"]
class ExternalCameraParams(TypedDict):
name: str
port: int
intrinsic: Intrinsic
extrinsic: Extrinsic
resolution: Resolution
# %%
def read_dataset_by_port(port: int) -> ak.Array:
P = DATASET_PATH / f"{port}.parquet"
return ak.from_parquet(P)
KEYPOINT_DATASET = {
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}
# %%
class KeypointDataset(TypedDict):
frame_index: int
boxes: Num[NDArray, "N 4"]
kps: Num[NDArray, "N J 2"]
kps_scores: Num[NDArray, "N J"]
@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
res = np.eye(4)
res[:3, :3] = R.from_rotvec(rvec).as_matrix()
res[:3, 3] = tvec
return res
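# Quick sanity check with synthetic values: a zero rotation vector and the
# translation (1, 2, 3) should give an identity rotation block with the
# translation in the last column.
display(to_transformation_matrix(np.zeros(3), np.array([1.0, 2.0, 3.0])))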
@jaxtyped(typechecker=beartype)
def undistort_points(
points: Num[NDArray, "M 2"],
camera_matrix: Num[NDArray, "3 3"],
dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
K = camera_matrix
dist = dist_coeffs
res = undistortPoints(points, K, dist, P=K) # type: ignore
return res.reshape(-1, 2)
def from_camera_params(camera: ExternalCameraParams) -> Camera:
rt = jnp.array(
to_transformation_matrix(
ak.to_numpy(camera["extrinsic"]["rvec"]),
ak.to_numpy(camera["extrinsic"]["tvec"]),
)
)
K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
image_size = jnp.array(
(camera["resolution"]["width"], camera["resolution"]["height"])
)
return Camera(
id=camera["name"],
params=CameraParams(
K=K,
Rt=rt,
dist_coeffs=dist_coeffs,
image_size=image_size,
),
)
def preprocess_keypoint_dataset(
dataset: Sequence[KeypointDataset],
camera: Camera,
fps: float,
start_timestamp: datetime,
) -> Generator[Detection, None, None]:
frame_interval_s = 1 / fps
for el in dataset:
frame_index = el["frame_index"]
timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
for kp, kp_score in zip(el["kps"], el["kps_scores"]):
yield Detection(
keypoints=jnp.array(kp),
confidences=jnp.array(kp_score),
camera=camera,
timestamp=timestamp,
)
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta):
"""
given a list of detection generators, return a generator that yields batches of detections;
a detection joins the current batch while its timestamp stays within `diff` of the batch's first timestamp
Args:
gens: list of detection generators
diff: maximum timestamp difference between detections to consider them part of the same batch
"""
N = len(gens)
last_batch_timestamp: Optional[datetime] = None
next_batch_timestamp: Optional[datetime] = None
current_batch: list[Detection] = []
next_batch: list[Detection] = []
paused: list[bool] = [False] * N
finished: list[bool] = [False] * N
def reset_paused():
"""
reset paused list based on finished list
"""
for i in range(N):
if not finished[i]:
paused[i] = False
else:
paused[i] = True
EPS = 1e-6
# a small epsilon to avoid floating point precision issues
diff_eps = diff - timedelta(seconds=EPS)
while True:
for i, gen in enumerate(gens):
try:
if finished[i] or paused[i]:
continue
val = next(gen)
if last_batch_timestamp is None:
last_batch_timestamp = val.timestamp
current_batch.append(val)
else:
if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
next_batch.append(val)
if next_batch_timestamp is None:
next_batch_timestamp = val.timestamp
paused[i] = True
if all(paused):
yield current_batch
current_batch = next_batch
next_batch = []
last_batch_timestamp = next_batch_timestamp
next_batch_timestamp = None
reset_paused()
else:
current_batch.append(val)
except StopIteration:
finished[i] = True
paused[i] = True
if all(finished):
if len(current_batch) > 0:
# All generators exhausted, flush remaining batch and exit
yield current_batch
break
# %%
@overload
def to_projection_matrix(
transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...
@overload
def to_projection_matrix(
transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...
@jaxtyped(typechecker=beartype)
def to_projection_matrix(
transformation_matrix: Num[Any, "4 4"],
camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
return camera_matrix @ transformation_matrix[:3, :]
to_projection_matrix_jit = jax.jit(to_projection_matrix)
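# Small illustration with synthetic inputs: with K = I the projection matrix is
# just the top three rows of the transformation matrix.
display(to_projection_matrix(np.eye(4), np.eye(3)))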
@jaxtyped(typechecker=beartype)
def dlt(
H1: Num[NDArray, "3 4"],
H2: Num[NDArray, "3 4"],
p1: Num[NDArray, "2"],
p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
"""
Direct Linear Transformation
"""
A = [
p1[1] * H1[2, :] - H1[1, :],
H1[0, :] - p1[0] * H1[2, :],
p2[1] * H2[2, :] - H2[1, :],
H2[0, :] - p2[0] * H2[2, :],
]
A = np.array(A).reshape((4, 4))
B = A.transpose() @ A
from scipy import linalg
U, s, Vh = linalg.svd(B, full_matrices=False)
return Vh[3, 0:3] / Vh[3, 3]
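# Minimal sanity check for `dlt` with synthetic cameras (not part of the real
# pipeline): identity intrinsics, the second camera shifted by one unit along x.
# The known 3D point (1, 2, 5) should be recovered from its two projections.
_H1 = np.hstack([np.eye(3), np.zeros((3, 1))])
_H2 = np.hstack([np.eye(3), np.array([[-1.0], [0.0], [0.0]])])
_X = np.array([1.0, 2.0, 5.0, 1.0])
_p1 = (_H1 @ _X)[:2] / (_H1 @ _X)[2]
_p2 = (_H2 @ _X)[:2] / (_H2 @ _X)[2]
display(dlt(_H1, _H2, _p1, _p2))  # expected approximately [1., 2., 5.]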
@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
"""
Convert homogeneous coordinates to Euclidean coordinates
Args:
points: homogeneous coordinates (x, y, z, w) in numpy array or jax array
Returns:
euclidean coordinates (x, y, z) in numpy array or jax array
"""
return points[..., :-1] / points[..., -1:]
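# Example: the homogeneous point (2, 4, 6, 2) maps to the Euclidean point (1, 2, 3).
display(homogeneous_to_euclidean(jnp.array([[2.0, 4.0, 6.0, 2.0]])))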
# %%
FPS = 24
image_gen_5600 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5600],  # type: ignore
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5601 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5601],  # type: ignore
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5602 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5602],  # type: ignore
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
display(1 / FPS)
sync_gen = sync_batch_gen(
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)
# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
next(sync_gen), alpha_2d=2000
)
display(sorted_detections)
# %%
display(
list(
map(
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
sorted_detections,
)
)
)
with jnp.printoptions(precision=3, suppress=True):
display(affinity_matrix)
# %%
def clusters_to_detections(
clusters: list[list[int]], sorted_detections: list[Detection]
) -> list[list[Detection]]:
"""
given a list of clusters (each cluster is a list of indices into the `sorted_detections` list),
extract the detections from the sorted_detections list
Args:
clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list
sorted_detections: list of SORTED detections
Returns:
list of clusters, each cluster is a list of detections
"""
return [[sorted_detections[i] for i in cluster] for cluster in clusters]
solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)
# %%
WIDTH = 2560
HEIGHT = 1440
clusters_detections = clusters_to_detections(clusters, sorted_detections)
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
im = visualize_whole_body(np.asarray(el.keypoints), im)
p = plt.imshow(im)
display(p)
# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)
# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Args:
proj_matrices: projection matrices with shape (N, 3, 4)
points: point coordinates with shape (N, 2)
confidences: per-view confidences with shape (N,), in the range [0.0, 1.0]
Returns:
point_3d: the triangulated 3D point with shape (3,)
"""
assert len(proj_matrices) == len(points)
N = len(proj_matrices)
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
# weight by the square root of the confidences so low-confidence views are down-weighted less aggressively than with linear weights
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
A = A.at[2 * i].mul(confi[i])
A = A.at[2 * i + 1].mul(confi[i])
# https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# use jnp.where instead of a Python `if` so the sign flip stays traceable under jit/vmap
point_3d_homo = jnp.where(
point_3d_homo[3] < 0, # predicate (scalar bool tracer)
-point_3d_homo, # if True
point_3d_homo, # if False
)
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Batch-triangulate P points observed by N cameras, linearly via SVD.
Args:
proj_matrices: (N, 3, 4) projection matrices
points: (N, P, 2) image-coordinates per view
confidences: (N, P) optional per-view confidences in [0, 1]
Returns:
(P, 3) 3D point for each of the P tracks
"""
N, P, _ = points.shape
assert proj_matrices.shape[0] == N
if confidences is None:
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
# vectorize the one-point routine over P
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear,
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
out_axes=0,
)
return vmap_triangulate(proj_matrices, points, conf)
# %%
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
"""
The tracking id
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
"""
last_active_timestamp: datetime
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
cluster: list[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
points = jnp.array([el.keypoints_undistorted for el in cluster])
confidences = jnp.array([el.confidences for el in cluster])
latest_timestamp = max(el.timestamp for el in cluster)
return (
triangulate_points_from_multiple_views_linear(
proj_matrices, points, confidences=confidences
),
latest_timestamp,
)
# res = {
# "a": triangle_from_cluster(clusters_detections[0]).tolist(),
# "b": triangle_from_cluster(clusters_detections[1]).tolist(),
# }
# with open("samples/res.json", "wb") as f:
# f.write(orjson.dumps(res))
class GlobalTrackingState:
_last_id: int
_trackings: dict[int, Tracking]
def __init__(self):
self._last_id = 0
self._trackings = {}
def __repr__(self) -> str:
return (
f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
)
@property
def trackings(self) -> dict[int, Tracking]:
return shallow_copy(self._trackings)
def add_tracking(self, cluster: list[Detection]) -> Tracking:
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
)
self._trackings[next_id] = tracking
self._last_id = next_id
return tracking
global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
global_tracking_state.add_tracking(cluster)
display(global_tracking_state)
# %%
next_group = next(sync_gen)
display(next_group)
# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
left: Num[Array, "J 2"],
right: Num[Array, "J 2"],
image_size: tuple[int, int] = (1, 1),
):
"""
Calculate the *normalized* distance between two sets of keypoints.
Args:
left: The left keypoints
right: The right keypoints
image_size: The size of the image
"""
w, h = image_size
if w == 1 and h == 1:
# already normalized
left_normalized = left
right_normalized = right
else:
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
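# Illustration with synthetic values: two single-joint detections 30 px apart
# horizontally in a 2560x1440 frame; the distance comes out normalized by the
# image size (about 30 / 2560, i.e. roughly 0.012).
display(
    calculate_distance_2d(
        jnp.array([[100.0, 200.0]]),
        jnp.array([[130.0, 200.0]]),
        image_size=(WIDTH, HEIGHT),
    )
)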
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
) -> float:
"""
Calculate the affinity between two detections based on the distance between their keypoints.
Args:
distance_2d: The normalized distance between the two keypoints (see `calculate_distance_2d`)
w_2d: The weight of the distance (parameter)
alpha_2d: The alpha value for the distance calculation (parameter)
lambda_a: The lambda value for the distance calculation (parameter)
delta_t: The time delta between the two detections, in seconds
"""
return w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * np.exp(-lambda_a * delta_t)
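# Illustration only: the parameter values below are arbitrary (not tuned for this
# dataset) and merely show the shape of the formula -- the affinity falls off
# linearly with the normalized distance and decays exponentially with the time gap.
_dt = 1 / FPS
for _d in (0.0, 0.01, 0.05):
    display(calculate_affinity_2d(_d, w_2d=1.0, alpha_2d=1.0, lambda_a=5.0, delta_t=_dt))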
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
):
"""
Calculate the perpendicular distance between a point and a line.
where `line` is represented by two points: `(line_start, line_end)`
"""
line_start, line_end = line
distance = jnp.linalg.norm(
jnp.cross(line_end - line_start, line_start - point)
) / jnp.linalg.norm(line_end - line_start)
return distance
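# Quick sanity check with synthetic values: the point (0, 1) is one unit away
# from the x-axis, represented here by the two points (0, 0) and (2, 0).
display(
    perpendicular_distance_point_to_line_two_points(
        jnp.array([0.0, 1.0]), (jnp.array([0.0, 0.0]), jnp.array([2.0, 0.0]))
    )
)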
def predict_pose_3d(
tracking: Tracking,
delta_t: float,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
"""
if tracking.velocity is None:
return tracking.keypoints
return tracking.keypoints + tracking.velocity * delta_t
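# Tiny illustration with made-up values: a single-joint pose at the origin with
# an arbitrary velocity of 0.5 units/s along x is predicted 0.5 / FPS further
# along x after one frame.
_tracking_example = Tracking(
    id=-1,
    keypoints=jnp.zeros((1, 3)),
    last_active_timestamp=datetime(2024, 4, 2, 12, 0, 0),
    velocity=jnp.array([0.5, 0.0, 0.0]),
)
display(predict_pose_3d(_tracking_example, 1 / FPS))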
# %%
# let's do cross-view association
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of
# trackings, D is the number of detections
# layout:
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
# a_t2_c1_d1,...
# ...
# a_tt_c1_d1,... , a_tt_cc_dd
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
# detections from camera `n`
affinity = np.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
for c, detections in detection_by_camera.items():
camera = next(iter(detections)).camera
# pixel space, unnormalized
tracking_2d_projection = camera.project(tracking.keypoints)
for det in detections:
...