feat: Migrate play notebook to Python script and update dependencies
- Removed the `play.ipynb` notebook and created a new `playground.py` script to enhance code organization and maintainability. - Updated `pyproject.toml` to include `jupytext` for Jupyter notebook conversion support. - Added instructions in `README.md` for converting notebooks using Jupytext. - Enhanced the `uv.lock` file to reflect the new dependency on Jupytext.
This commit is contained in:
562
playground.py
Normal file
562
playground.py
Normal file
@ -0,0 +1,562 @@
|
||||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.17.0
|
||||
# kernelspec:
|
||||
# display_name: .venv
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %%
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import awkward as ak
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import orjson
|
||||
from beartype import beartype
|
||||
from cv2 import undistortPoints
|
||||
from jaxtyping import Array, Float, Num, jaxtyped
|
||||
from matplotlib import pyplot as plt
|
||||
from numpy.typing import ArrayLike
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from app.camera import Camera, CameraParams, Detection
|
||||
from app.visualize.whole_body import visualize_whole_body
|
||||
from IPython.display import display
|
||||
|
||||
NDArray: TypeAlias = np.ndarray
|
||||
|
||||
# %%
|
||||
DATASET_PATH = Path("samples") / "04_02"
|
||||
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
|
||||
display(AK_CAMERA_DATASET)
|
||||
|
||||
|
||||
# %%
|
||||
class Resolution(TypedDict):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
class Intrinsic(TypedDict):
|
||||
camera_matrix: Num[Array, "3 3"]
|
||||
"""
|
||||
K
|
||||
"""
|
||||
distortion_coefficients: Num[Array, "N"]
|
||||
"""
|
||||
distortion coefficients; usually 5
|
||||
"""
|
||||
|
||||
|
||||
class Extrinsic(TypedDict):
|
||||
rvec: Num[NDArray, "3"]
|
||||
tvec: Num[NDArray, "3"]
|
||||
|
||||
|
||||
class ExternalCameraParams(TypedDict):
|
||||
name: str
|
||||
port: int
|
||||
intrinsic: Intrinsic
|
||||
extrinsic: Extrinsic
|
||||
resolution: Resolution
|
||||
|
||||
|
||||
# %%
|
||||
def read_dataset_by_port(port: int) -> ak.Array:
|
||||
P = DATASET_PATH / f"{port}.parquet"
|
||||
return ak.from_parquet(P)
|
||||
|
||||
|
||||
KEYPOINT_DATASET = {
|
||||
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
|
||||
}
|
||||
|
||||
|
||||
# %%
|
||||
class KeypointDataset(TypedDict):
|
||||
frame_index: int
|
||||
boxes: Num[NDArray, "N 4"]
|
||||
kps: Num[NDArray, "N J 2"]
|
||||
kps_scores: Num[NDArray, "N J"]
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def to_transformation_matrix(
|
||||
rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
|
||||
) -> Num[NDArray, "4 4"]:
|
||||
res = np.eye(4)
|
||||
res[:3, :3] = R.from_rotvec(rvec).as_matrix()
|
||||
res[:3, 3] = tvec
|
||||
return res
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def undistort_points(
|
||||
points: Num[NDArray, "M 2"],
|
||||
camera_matrix: Num[NDArray, "3 3"],
|
||||
dist_coeffs: Num[NDArray, "N"],
|
||||
) -> Num[NDArray, "M 2"]:
|
||||
K = camera_matrix
|
||||
dist = dist_coeffs
|
||||
res = undistortPoints(points, K, dist, P=K) # type: ignore
|
||||
return res.reshape(-1, 2)
|
||||
|
||||
|
||||
def from_camera_params(camera: ExternalCameraParams) -> Camera:
|
||||
rt = jnp.array(
|
||||
to_transformation_matrix(
|
||||
ak.to_numpy(camera["extrinsic"]["rvec"]),
|
||||
ak.to_numpy(camera["extrinsic"]["tvec"]),
|
||||
)
|
||||
)
|
||||
K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
|
||||
dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
|
||||
image_size = jnp.array(
|
||||
(camera["resolution"]["width"], camera["resolution"]["height"])
|
||||
)
|
||||
return Camera(
|
||||
id=camera["name"],
|
||||
params=CameraParams(
|
||||
K=K,
|
||||
Rt=rt,
|
||||
dist_coeffs=dist_coeffs,
|
||||
image_size=image_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def preprocess_keypoint_dataset(
|
||||
dataset: Sequence[KeypointDataset],
|
||||
camera: Camera,
|
||||
fps: float,
|
||||
start_timestamp: datetime,
|
||||
) -> Generator[Detection, None, None]:
|
||||
frame_interval_s = 1 / fps
|
||||
for el in dataset:
|
||||
frame_index = el["frame_index"]
|
||||
timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
|
||||
for kp, kp_score in zip(el["kps"], el["kps_scores"]):
|
||||
yield Detection(
|
||||
keypoints=jnp.array(kp),
|
||||
confidences=jnp.array(kp_score),
|
||||
camera=camera,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
|
||||
|
||||
|
||||
def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta):
|
||||
"""
|
||||
given a list of detection generators, return a generator that yields a batch of detections
|
||||
|
||||
Args:
|
||||
gens: list of detection generators
|
||||
diff: maximum timestamp difference between detections to consider them part of the same batch
|
||||
"""
|
||||
N = len(gens)
|
||||
last_batch_timestamp: Optional[datetime] = None
|
||||
next_batch_timestamp: Optional[datetime] = None
|
||||
current_batch: list[Detection] = []
|
||||
next_batch: list[Detection] = []
|
||||
paused: list[bool] = [False] * N
|
||||
finished: list[bool] = [False] * N
|
||||
|
||||
def reset_paused():
|
||||
"""
|
||||
reset paused list based on finished list
|
||||
"""
|
||||
for i in range(N):
|
||||
if not finished[i]:
|
||||
paused[i] = False
|
||||
else:
|
||||
paused[i] = True
|
||||
|
||||
EPS = 1e-6
|
||||
# a small epsilon to avoid floating point precision issues
|
||||
diff_esp = diff - timedelta(seconds=EPS)
|
||||
while True:
|
||||
for i, gen in enumerate(gens):
|
||||
try:
|
||||
if finished[i] or paused[i]:
|
||||
continue
|
||||
val = next(gen)
|
||||
if last_batch_timestamp is None:
|
||||
last_batch_timestamp = val.timestamp
|
||||
current_batch.append(val)
|
||||
else:
|
||||
if abs(val.timestamp - last_batch_timestamp) >= diff_esp:
|
||||
next_batch.append(val)
|
||||
if next_batch_timestamp is None:
|
||||
next_batch_timestamp = val.timestamp
|
||||
paused[i] = True
|
||||
if all(paused):
|
||||
yield current_batch
|
||||
current_batch = next_batch
|
||||
next_batch = []
|
||||
last_batch_timestamp = next_batch_timestamp
|
||||
next_batch_timestamp = None
|
||||
reset_paused()
|
||||
else:
|
||||
current_batch.append(val)
|
||||
except StopIteration:
|
||||
finished[i] = True
|
||||
paused[i] = True
|
||||
if all(finished):
|
||||
if len(current_batch) > 0:
|
||||
# All generators exhausted, flush remaining batch and exit
|
||||
yield current_batch
|
||||
break
|
||||
|
||||
|
||||
# %%
|
||||
@overload
|
||||
def to_projection_matrix(
|
||||
transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
|
||||
) -> Num[NDArray, "3 4"]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def to_projection_matrix(
|
||||
transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
|
||||
) -> Num[Array, "3 4"]: ...
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def to_projection_matrix(
|
||||
transformation_matrix: Num[Any, "4 4"],
|
||||
camera_matrix: Num[Any, "3 3"],
|
||||
) -> Num[Any, "3 4"]:
|
||||
return camera_matrix @ transformation_matrix[:3, :]
|
||||
|
||||
|
||||
to_projection_matrix_jit = jax.jit(to_projection_matrix)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def dlt(
|
||||
H1: Num[NDArray, "3 4"],
|
||||
H2: Num[NDArray, "3 4"],
|
||||
p1: Num[NDArray, "2"],
|
||||
p2: Num[NDArray, "2"],
|
||||
) -> Num[NDArray, "3"]:
|
||||
"""
|
||||
Direct Linear Transformation
|
||||
"""
|
||||
A = [
|
||||
p1[1] * H1[2, :] - H1[1, :],
|
||||
H1[0, :] - p1[0] * H1[2, :],
|
||||
p2[1] * H2[2, :] - H2[1, :],
|
||||
H2[0, :] - p2[0] * H2[2, :],
|
||||
]
|
||||
A = np.array(A).reshape((4, 4))
|
||||
|
||||
B = A.transpose() @ A
|
||||
from scipy import linalg
|
||||
|
||||
U, s, Vh = linalg.svd(B, full_matrices=False)
|
||||
return Vh[3, 0:3] / Vh[3, 3]
|
||||
|
||||
|
||||
@overload
|
||||
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def homogeneous_to_euclidean(
|
||||
points: Num[Any, "N 4"],
|
||||
) -> Num[Any, "N 3"]:
|
||||
"""
|
||||
将齐次坐标转换为欧几里得坐标
|
||||
|
||||
Args:
|
||||
points: homogeneous coordinates (x, y, z, w) in numpy array or jax array
|
||||
|
||||
Returns:
|
||||
euclidean coordinates (x, y, z) in numpy array or jax array
|
||||
"""
|
||||
return points[..., :-1] / points[..., -1:]
|
||||
|
||||
|
||||
# %%
|
||||
FPS = 24
|
||||
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||
|
||||
display(1 / FPS)
|
||||
sync_gen = sync_batch_gen(
|
||||
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
|
||||
)
|
||||
|
||||
# %%
|
||||
detections = next(sync_gen)
|
||||
|
||||
# %%
|
||||
from app.camera import calculate_affinity_matrix_by_epipolar_constraint
|
||||
|
||||
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
|
||||
detections, alpha_2d=2000
|
||||
)
|
||||
|
||||
# %%
|
||||
display(
|
||||
list(
|
||||
map(
|
||||
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
|
||||
sorted_detections,
|
||||
)
|
||||
)
|
||||
)
|
||||
with jnp.printoptions(precision=3, suppress=True):
|
||||
display(affinity_matrix)
|
||||
|
||||
# %%
|
||||
from app.solver._old import GLPKSolver
|
||||
|
||||
|
||||
def clusters_to_detections(
|
||||
clusters: list[list[int]], sorted_detections: list[Detection]
|
||||
) -> list[list[Detection]]:
|
||||
"""
|
||||
given a list of clusters (which is the indices of the detections in the sorted_detections list),
|
||||
extract the detections from the sorted_detections list
|
||||
|
||||
Args:
|
||||
clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list
|
||||
sorted_detections: list of SORTED detections
|
||||
|
||||
Returns:
|
||||
list of clusters, each cluster is a list of detections
|
||||
"""
|
||||
return [[sorted_detections[i] for i in cluster] for cluster in clusters]
|
||||
|
||||
|
||||
solver = GLPKSolver()
|
||||
aff_np = np.asarray(affinity_matrix).astype(np.float64)
|
||||
clusters, sol_matrix = solver.solve(aff_np)
|
||||
display(clusters)
|
||||
display(sol_matrix)
|
||||
|
||||
# %%
|
||||
WIDTH = 2560
|
||||
HEIGHT = 1440
|
||||
|
||||
clusters_detections = clusters_to_detections(clusters, sorted_detections)
|
||||
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||
for el in clusters_detections[0]:
|
||||
im = visualize_whole_body(np.asarray(el.keypoints), im)
|
||||
|
||||
p = plt.imshow(im)
|
||||
display(p)
|
||||
|
||||
# %%
|
||||
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||
for el in clusters_detections[1]:
|
||||
im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
|
||||
|
||||
p_prime = plt.imshow(im_prime)
|
||||
display(p_prime)
|
||||
|
||||
|
||||
# %%
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def triangulate_one_point_from_multiple_views_linear(
|
||||
proj_matrices: Float[Array, "N 3 4"],
|
||||
points: Num[Array, "N 2"],
|
||||
confidences: Optional[Float[Array, "N"]] = None,
|
||||
) -> Float[Array, "3"]:
|
||||
"""
|
||||
Args:
|
||||
proj_matrices: 形状为(N, 3, 4)的投影矩阵序列
|
||||
points: 形状为(N, 2)的点坐标序列
|
||||
confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0]
|
||||
|
||||
Returns:
|
||||
point_3d: 形状为(3,)的三角测量得到的3D点
|
||||
"""
|
||||
assert len(proj_matrices) == len(points)
|
||||
|
||||
N = len(proj_matrices)
|
||||
confi: Float[Array, "N"]
|
||||
if confidences is None:
|
||||
confi = jnp.ones(N, dtype=np.float32)
|
||||
else:
|
||||
# Use square root of confidences for weighting - more balanced approach
|
||||
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
|
||||
|
||||
A = jnp.zeros((N * 2, 4), dtype=np.float32)
|
||||
for i in range(N):
|
||||
x, y = points[i]
|
||||
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
|
||||
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
|
||||
A = A.at[2 * i].mul(confi[i])
|
||||
A = A.at[2 * i + 1].mul(confi[i])
|
||||
|
||||
# https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
|
||||
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
|
||||
point_3d_homo = vh[-1] # shape (4,)
|
||||
|
||||
# replace the Python `if` with a jnp.where
|
||||
point_3d_homo = jnp.where(
|
||||
point_3d_homo[3] < 0, # predicate (scalar bool tracer)
|
||||
-point_3d_homo, # if True
|
||||
point_3d_homo, # if False
|
||||
)
|
||||
|
||||
point_3d = point_3d_homo[:3] / point_3d_homo[3]
|
||||
return point_3d
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def triangulate_points_from_multiple_views_linear(
|
||||
proj_matrices: Float[Array, "N 3 4"],
|
||||
points: Num[Array, "N P 2"],
|
||||
confidences: Optional[Float[Array, "N P"]] = None,
|
||||
) -> Float[Array, "P 3"]:
|
||||
"""
|
||||
Batch-triangulate P points observed by N cameras, linearly via SVD.
|
||||
|
||||
Args:
|
||||
proj_matrices: (N, 3, 4) projection matrices
|
||||
points: (N, P, 2) image-coordinates per view
|
||||
confidences: (N, P, 1) optional per-view confidences in [0,1]
|
||||
|
||||
Returns:
|
||||
(P, 3) 3D point for each of the P tracks
|
||||
"""
|
||||
N, P, _ = points.shape
|
||||
assert proj_matrices.shape[0] == N
|
||||
if confidences is None:
|
||||
conf = jnp.ones((N, P), dtype=jnp.float32)
|
||||
else:
|
||||
conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
|
||||
|
||||
# vectorize your one‐point routine over P
|
||||
vmap_triangulate = jax.vmap(
|
||||
triangulate_one_point_from_multiple_views_linear,
|
||||
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
|
||||
out_axes=0,
|
||||
)
|
||||
|
||||
# returns (P, 3)
|
||||
return vmap_triangulate(proj_matrices, points, conf)
|
||||
|
||||
|
||||
# %%
|
||||
from dataclasses import dataclass
|
||||
from copy import copy as shallow_copy, deepcopy as deep_copy
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
class Tracking:
|
||||
id: int
|
||||
keypoints: Float[Array, "J 3"]
|
||||
last_active_timestamp: datetime
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def triangle_from_cluster(
|
||||
cluster: list[Detection],
|
||||
) -> tuple[Float[Array, "N 3"], datetime]:
|
||||
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
|
||||
points = jnp.array([el.keypoints_undistorted for el in cluster])
|
||||
confidences = jnp.array([el.confidences for el in cluster])
|
||||
latest_timestamp = max(el.timestamp for el in cluster)
|
||||
return (
|
||||
triangulate_points_from_multiple_views_linear(
|
||||
proj_matrices, points, confidences=confidences
|
||||
),
|
||||
latest_timestamp,
|
||||
)
|
||||
|
||||
|
||||
# res = {
|
||||
# "a": triangle_from_cluster(clusters_detections[0]).tolist(),
|
||||
# "b": triangle_from_cluster(clusters_detections[1]).tolist(),
|
||||
# }
|
||||
# with open("samples/res.json", "wb") as f:
|
||||
# f.write(orjson.dumps(res))
|
||||
|
||||
|
||||
class GlobalTrackingState:
|
||||
_last_id: int
|
||||
_trackings: dict[int, Tracking]
|
||||
|
||||
def __init__(self):
|
||||
self._last_id = 0
|
||||
self._trackings = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
|
||||
)
|
||||
|
||||
@property
|
||||
def trackings(self) -> dict[int, Tracking]:
|
||||
return shallow_copy(self._trackings)
|
||||
|
||||
def add_tracking(self, cluster: list[Detection]) -> Tracking:
|
||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||
next_id = self._last_id + 1
|
||||
tracking = Tracking(
|
||||
id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
|
||||
)
|
||||
self._trackings[next_id] = tracking
|
||||
self._last_id = next_id
|
||||
return tracking
|
||||
|
||||
|
||||
global_tracking_state = GlobalTrackingState()
|
||||
for cluster in clusters_detections:
|
||||
global_tracking_state.add_tracking(cluster)
|
||||
display(global_tracking_state)
|
||||
|
||||
# %%
|
||||
next_group = next(sync_gen)
|
||||
display(next_group)
|
||||
|
||||
# %%
|
||||
from app.camera import classify_by_camera
|
||||
|
||||
# let's do cross-view association
|
||||
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
||||
detections = shallow_copy(next_group)
|
||||
# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections
|
||||
affinity = np.zeros((len(trackings), len(detections)))
|
||||
detection_by_camera = classify_by_camera(detections)
|
||||
for i, tracking in enumerate(trackings):
|
||||
for c, detections in detection_by_camera.items():
|
||||
camera = next(iter(detections)).camera
|
||||
# pixel space, unnormalized
|
||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
||||
Reference in New Issue
Block a user