feat: Enhance play notebook and camera module with new unprojection functionalities

- Updated the play notebook to include new methods for unprojecting 2D points onto a 3D plane.
- Introduced `unproject_points_onto_plane` and `unproject_points_to_z_plane` functions in the camera module for improved point handling.
- Enhanced the `Camera` class with a method for unprojecting points to a specified z-plane.
- Cleaned up execution counts in the notebook for better organization and clarity.
This commit is contained in:
2025-04-24 18:55:24 +08:00
parent 00481a0d6f
commit c3c93f6ca6
3 changed files with 160 additions and 37 deletions

View File

@ -156,19 +156,3 @@ class MyDataclass:
x: Float[Array, "batch"] x: Float[Array, "batch"]
y: Float[Array, "batch"] y: Float[Array, "batch"]
``` ```
## Scalars, PRNG keys
For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
```python
Scalar = Shaped[Array, ""]
ScalarLike = Shaped[ArrayLike, ""]
# Left: new-style typed keys; right: old-style keys. See JEP 9263.
PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
```
Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias, TypedDict, Optional
from beartype import beartype from beartype import beartype
import jax import jax
from jax import numpy as jnp from jax import numpy as jnp
from jaxtyping import Num, jaxtyped, Array from jaxtyping import Num, jaxtyped, Array, Float
from cv2 import undistortPoints from cv2 import undistortPoints
import numpy as np import numpy as np
@ -87,6 +87,81 @@ def distortion(
return jnp.stack([u, v], axis=1) return jnp.stack([u, v], axis=1)
@jaxtyped(typechecker=beartype)
def unproject_points_onto_plane(
    points_2d: Float[Array, "N 2"],
    plane_normal: Float[Array, "3"],
    plane_point: Float[Array, "3"],
    K: Float[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
) -> Float[Array, "N 3"]:
    """
    Un-project 2D image points onto an arbitrary 3D plane.

    Each pixel in `points_2d` defines a viewing ray through the camera
    center; this function intersects those rays with the given plane
    (i.e. back-projects the points onto the plane).

    Args:
        points_2d: (N, 2) image pixel coordinates.
        plane_normal: (3,) normal vector of the plane in world coords
            (need not be unit length; it is normalized internally).
        plane_point: (3,) a known point on the plane in world coords.
        K: (3, 3) camera intrinsic matrix.
        dist_coeffs: (5,) distortion coefficients.
        pose_matrix: (4, 4) Camera-to-World (C2W) transformation matrix.

    Note:
        `pose_matrix` is NOT the same as the camera extrinsic
        (World-to-Camera, W2C), but the inverse of it.
        NOTE(review): a ray parallel to the plane gives a zero
        denominator and yields inf/nan coordinates — confirm callers
        never pass such configurations, or add a guard.

    Returns:
        (N, 3) world-space intersection points.
    """
    # Step 1: undistort (no-op here)
    pts = undistort_points(
        np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
    )
    # Step 2: normalize image coordinates into camera rays
    # Pinhole back-projection: ((u - cx) / fx, (v - cy) / fy, 1).
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    dirs_cam = jnp.stack(
        [(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
        axis=-1,
    )  # (N, 3)
    # Step 3: transform rays into world space
    c2w = pose_matrix
    ray_orig = c2w[:3, 3]  # (3,) camera center in world coords
    R_world = c2w[:3, :3]  # (3,3) camera-to-world rotation
    ray_dirs = (R_world @ dirs_cam.T).T  # (N, 3)
    # Step 4: plane intersection
    # Solve (ray_orig + t * ray_dirs - p0) . n == 0 for t per ray.
    n = plane_normal / jnp.linalg.norm(plane_normal)
    p0 = plane_point
    denom = jnp.dot(ray_dirs, n)  # (N,)
    numer = jnp.dot((p0 - ray_orig), n)  # scalar
    t = numer / denom  # (N,)
    points_world = ray_orig + ray_dirs * t[:, None]
    return points_world
@jaxtyped(typechecker=beartype)
def unproject_points_to_z_plane(
    points_2d: Float[Array, "N 2"],
    K: Float[Array, "3 3"],
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
    z: float = 0.0,
) -> Float[Array, "N 3"]:
    """
    Back-project image points onto the horizontal world plane z = const.

    Convenience wrapper around `unproject_points_onto_plane` with the
    plane fixed to normal (0, 0, 1) passing through (0, 0, z).

    Args:
        points_2d: (N, 2) image pixel coordinates.
        K: (3, 3) camera intrinsic matrix.
        dist_coeffs: (5,) distortion coefficients.
        pose_matrix: (4, 4) Camera-to-World (C2W) transformation matrix.
        z: Height of the target plane in world coords (default 0.0).

    Returns:
        (N, 3) world-space intersection points.
    """
    up_normal = jnp.array([0.0, 0.0, 1.0])
    anchor = jnp.array([0.0, 0.0, z])
    return unproject_points_onto_plane(
        points_2d,
        up_normal,
        anchor,
        K,
        dist_coeffs,
        pose_matrix,
    )
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def project( def project(
points_3d: Num[Array, "N 3"], points_3d: Num[Array, "N 3"],
@ -242,6 +317,9 @@ class Camera:
Camera parameters Camera parameters
""" """
def __repr__(self) -> str:
return f"<Camera id={self.id}>"
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]: def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
""" """
Project 3D points to 2D points using this camera's parameters Project 3D points to 2D points using this camera's parameters
@ -292,6 +370,20 @@ class Camera:
dist_coeffs=self.params.dist_coeffs, dist_coeffs=self.params.dist_coeffs,
) )
def unproject_points_to_z_plane(
self, points_2d: Num[Array, "N 2"], z: float = 0.0
) -> Num[Array, "N 3"]:
"""
Unproject 2D points to 3D points on a plane at z = constant.
"""
return unproject_points_to_z_plane(
points_2d,
self.params.K,
self.params.dist_coeffs,
self.params.pose_matrix,
z,
)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -337,6 +429,9 @@ class Detection:
object.__setattr__(self, "_kp_undistorted", kpu) object.__setattr__(self, "_kp_undistorted", kpu)
return kpu return kpu
def __repr__(self) -> str:
return f"Detection({self.camera}, {self.timestamp})"
def classify_by_camera( def classify_by_camera(
detections: list[Detection], detections: list[Detection],

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -43,7 +43,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -75,7 +75,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 31,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -97,7 +97,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 33,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -174,7 +174,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 34,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -246,7 +246,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 35,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -339,7 +339,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 37,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -348,7 +348,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -432,7 +432,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -530,21 +530,32 @@
"\n", "\n",
"\n", "\n",
"@jaxtyped(typechecker=beartype)\n", "@jaxtyped(typechecker=beartype)\n",
"@dataclass\n", "@dataclass(frozen=True)\n",
"class Tracking:\n", "class Tracking:\n",
" id: int\n", " id: int\n",
" keypoints: Float[Array, \"J 3\"]\n", " keypoints: Float[Array, \"J 3\"]\n",
" last_active_timestamp: datetime\n",
"\n",
" def __repr__(self) -> str:\n",
" return f\"Tracking({self.id}, {self.last_active_timestamp})\"\n",
"\n", "\n",
"\n", "\n",
"@jaxtyped(typechecker=beartype)\n", "@jaxtyped(typechecker=beartype)\n",
"def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"N 3\"]:\n", "def triangle_from_cluster(\n",
" cluster: list[Detection],\n",
") -> tuple[Float[Array, \"N 3\"], datetime]:\n",
" proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n", " proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
" points = jnp.array([el.keypoints_undistorted for el in cluster])\n", " points = jnp.array([el.keypoints_undistorted for el in cluster])\n",
" confidences = jnp.array([el.confidences for el in cluster])\n", " confidences = jnp.array([el.confidences for el in cluster])\n",
" return triangulate_points_from_multiple_views_linear(\n", " latest_timestamp = max(el.timestamp for el in cluster)\n",
" return (\n",
" triangulate_points_from_multiple_views_linear(\n",
" proj_matrices, points, confidences=confidences\n", " proj_matrices, points, confidences=confidences\n",
" ),\n",
" latest_timestamp,\n",
" )\n", " )\n",
"\n", "\n",
"\n",
"# res = {\n", "# res = {\n",
"# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n", "# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n",
"# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n", "# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n",
@ -552,28 +563,39 @@
"# with open(\"samples/res.json\", \"wb\") as f:\n", "# with open(\"samples/res.json\", \"wb\") as f:\n",
"# f.write(orjson.dumps(res))\n", "# f.write(orjson.dumps(res))\n",
"\n", "\n",
"\n",
"class GlobalTrackingState:\n", "class GlobalTrackingState:\n",
" _last_id: int\n", " _last_id: int\n",
" _trackings: list[Tracking]\n", " _trackings: dict[int, Tracking]\n",
"\n", "\n",
" def __init__(self):\n", " def __init__(self):\n",
" self._last_id = 0\n", " self._last_id = 0\n",
" self._trackings = []\n", " self._trackings = {}\n",
"\n",
" def __repr__(self) -> str:\n",
" return (\n",
" f\"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})\"\n",
" )\n",
"\n", "\n",
" @property\n", " @property\n",
" def trackings(self) -> list[Tracking]:\n", " def trackings(self) -> dict[int, Tracking]:\n",
" return shallow_copy(self._trackings)\n", " return shallow_copy(self._trackings)\n",
"\n", "\n",
" def add_tracking(self, cluster: list[Detection]) -> Tracking:\n", " def add_tracking(self, cluster: list[Detection]) -> Tracking:\n",
" tracking = Tracking(id=self._last_id, keypoints=triangle_from_cluster(cluster))\n", " kps_3d, latest_timestamp = triangle_from_cluster(cluster)\n",
" self._last_id += 1\n", " next_id = self._last_id + 1\n",
" self._trackings.append(tracking)\n", " tracking = Tracking(\n",
" id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp\n",
" )\n",
" self._trackings[next_id] = tracking\n",
" self._last_id = next_id\n",
" return tracking\n", " return tracking\n",
"\n", "\n",
"\n", "\n",
"global_tracking_state = GlobalTrackingState()\n", "global_tracking_state = GlobalTrackingState()\n",
"for cluster in clusters_detections:\n", "for cluster in clusters_detections:\n",
" global_tracking_state.add_tracking(cluster)" " global_tracking_state.add_tracking(cluster)\n",
"display(global_tracking_state)"
] ]
}, },
{ {
@ -585,6 +607,28 @@
"next_group = next(sync_gen)\n", "next_group = next(sync_gen)\n",
"display(next_group)" "display(next_group)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from app.camera import classify_by_camera\n",
"\n",
"# let's do cross-view association\n",
"trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)\n",
"detections = shallow_copy(next_group)\n",
"# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections\n",
"affinity = np.zeros((len(trackings), len(detections)))\n",
"detection_by_camera = classify_by_camera(detections)\n",
"for i, tracking in enumerate(trackings):\n",
" for c, detections in detection_by_camera.items():\n",
" camera = next(iter(detections)).camera\n",
" # pixel space, unnormalized\n",
" tracking_2d_projection = camera.project(tracking.keypoints)\n",
" \n"
]
} }
], ],
"metadata": { "metadata": {