@jaxtyped(typechecker=beartype)
def unproject_points_onto_plane(
    points_2d: Float[Array, "N 2"],
    plane_normal: Float[Array, "3"],
    plane_point: Float[Array, "3"],
    K: Float[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
) -> Float[Array, "N 3"]:
    """Back-project 2D pixel coordinates onto an arbitrary 3D plane.

    Each entry of ``points_2d`` is treated as a ray through the camera
    centre; the result is the ray/plane intersection in world space.

    Args:
        points_2d: (N, 2) image pixel coordinates.
        plane_normal: (3,) plane normal vector in world coordinates.
        plane_point: (3,) any known point on the plane, world coordinates.
        K: (3, 3) camera intrinsic matrix.
        dist_coeffs: (5,) distortion coefficients.
        pose_matrix: (4, 4) camera-to-world (C2W) transform.

    Note:
        ``pose_matrix`` is NOT the camera extrinsic (world-to-camera, W2C)
        matrix, but its inverse.
        A ray parallel to the plane makes the denominator zero and yields
        non-finite output for that point — callers must avoid that
        configuration (no guard is applied here).

    Returns:
        (N, 3) world-space intersection points.
    """
    # 1) Undistort the raw pixel coordinates.
    undistorted = undistort_points(
        np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
    )

    # 2) Lift pixels to normalized camera-frame ray directions.
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    dirs_cam = jnp.stack(
        [
            (undistorted[:, 0] - cx) / fx,
            (undistorted[:, 1] - cy) / fy,
            jnp.ones_like(undistorted[:, 0]),
        ],
        axis=-1,
    )  # (N, 3)

    # 3) Rotate ray directions into world space; the ray origin is the
    #    camera centre taken from the C2W translation column.
    origin = pose_matrix[:3, 3]  # (3,)
    rotation = pose_matrix[:3, :3]  # (3, 3)
    dirs_world = (rotation @ dirs_cam.T).T  # (N, 3)

    # 4) Ray/plane intersection: t = n·(p0 - o) / n·d, per ray.
    unit_normal = plane_normal / jnp.linalg.norm(plane_normal)
    denom = jnp.dot(dirs_world, unit_normal)  # (N,)
    numer = jnp.dot(plane_point - origin, unit_normal)  # scalar
    t = numer / denom  # (N,)
    return origin + dirs_world * t[:, None]


@jaxtyped(typechecker=beartype)
def unproject_points_to_z_plane(
    points_2d: Float[Array, "N 2"],
    K: Float[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
    z: float = 0.0,
) -> Float[Array, "N 3"]:
    """Back-project pixels onto the horizontal world plane ``z = const``.

    Convenience wrapper around :func:`unproject_points_onto_plane` with a
    +Z normal and a point at height ``z``.
    """
    return unproject_points_onto_plane(
        points_2d,
        jnp.array([0.0, 0.0, 1.0]),
        jnp.array([0.0, 0.0, z]),
        K,
        dist_coeffs,
        pose_matrix,
    )
points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]: """ Project 3D points to 2D points using this camera's parameters @@ -292,6 +370,20 @@ class Camera: dist_coeffs=self.params.dist_coeffs, ) + def unproject_points_to_z_plane( + self, points_2d: Num[Array, "N 2"], z: float = 0.0 + ) -> Num[Array, "N 3"]: + """ + Unproject 2D points to 3D points on a plane at z = constant. + """ + return unproject_points_to_z_plane( + points_2d, + self.params.K, + self.params.dist_coeffs, + self.params.pose_matrix, + z, + ) + @jaxtyped(typechecker=beartype) @dataclass(frozen=True) @@ -337,6 +429,9 @@ class Detection: object.__setattr__(self, "_kp_undistorted", kpu) return kpu + def __repr__(self) -> str: + return f"Detection({self.camera}, {self.timestamp})" + def classify_by_camera( detections: list[Detection], diff --git a/play.ipynb b/play.ipynb index d191fbb..567f1a1 100644 --- a/play.ipynb +++ b/play.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -339,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -348,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 38, "metadata": {}, "outputs": [], 
class GlobalTrackingState:
    """Registry of active 3D trackings, keyed by a monotonically increasing id."""

    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        # Shallow copy so callers cannot mutate the internal registry.
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: list[Detection]) -> Tracking:
        """Triangulate *cluster* and register the result under a fresh id.

        Ids start at 1 (the first allocated id is ``_last_id + 1``).
        """
        keypoints_3d, last_seen = triangle_from_cluster(cluster)
        new_id = self._last_id + 1
        new_tracking = Tracking(
            id=new_id, keypoints=keypoints_3d, last_active_timestamp=last_seen
        )
        self._trackings[new_id] = new_tracking
        self._last_id = new_id
        return new_tracking