feat: Enhance play notebook and camera module with new unprojection functionalities
- Updated the play notebook to include new methods for unprojecting 2D points onto a 3D plane. - Introduced `unproject_points_onto_plane` and `unproject_points_to_z_plane` functions in the camera module for improved point handling. - Enhanced the `Camera` class with a method for unprojecting points to a specified z-plane. - Cleaned up execution counts in the notebook for better organization and clarity.
This commit is contained in:
@ -156,19 +156,3 @@ class MyDataclass:
|
||||
x: Float[Array, "batch"]
|
||||
y: Float[Array, "batch"]
|
||||
```
|
||||
|
||||
## Scalars, PRNG keys
|
||||
|
||||
For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
|
||||
|
||||
```python
|
||||
Scalar = Shaped[Array, ""]
|
||||
ScalarLike = Shaped[ArrayLike, ""]
|
||||
|
||||
# Left: new-style typed keys; right: old-style keys. See JEP 9263.
|
||||
PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
|
||||
```
|
||||
|
||||
Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
|
||||
|
||||
Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.
|
||||
@ -6,7 +6,7 @@ from typing import Any, TypeAlias, TypedDict, Optional
|
||||
from beartype import beartype
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jaxtyping import Num, jaxtyped, Array
|
||||
from jaxtyping import Num, jaxtyped, Array, Float
|
||||
from cv2 import undistortPoints
|
||||
import numpy as np
|
||||
|
||||
@ -87,6 +87,81 @@ def distortion(
|
||||
return jnp.stack([u, v], axis=1)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def unproject_points_onto_plane(
|
||||
points_2d: Float[Array, "N 2"],
|
||||
plane_normal: Float[Array, "3"],
|
||||
plane_point: Float[Array, "3"],
|
||||
K: Float[Array, "3 3"], # pylint: disable=invalid-name
|
||||
dist_coeffs: Float[Array, "5"],
|
||||
pose_matrix: Float[Array, "4 4"],
|
||||
) -> Float[Array, "N 3"]:
|
||||
"""
|
||||
Un-project 2D image points onto an arbitrary 3D plane.
|
||||
This function computes the ray-plane intersections, since every `points_2d`
|
||||
could be treated as a ray.
|
||||
|
||||
(i.e. back-project points onto a plane)
|
||||
|
||||
Args:
|
||||
points_2d: [..., 2] image pixel coordinates
|
||||
plane_normal: (3,) normal vector of the plane in world coords
|
||||
plane_point: (3,) a known point on the plane in world coords
|
||||
K: Camera intrinsic matrix
|
||||
dist_coeffs: Distortion coefficients
|
||||
pose_matrix: Camera-to-World (C2W) transformation matrix
|
||||
|
||||
Note:
|
||||
`pose_matrix` is NOT the same as camera extrinsic (World-to-Camera, W2C),
|
||||
but the inverse of it.
|
||||
|
||||
Returns:
|
||||
[..., 3] world-space intersection points
|
||||
"""
|
||||
# Step 1: undistort (no-op here)
|
||||
pts = undistort_points(
|
||||
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
|
||||
)
|
||||
|
||||
# Step 2: normalize image coordinates into camera rays
|
||||
fx, fy = K[0, 0], K[1, 1]
|
||||
cx, cy = K[0, 2], K[1, 2]
|
||||
dirs_cam = jnp.stack(
|
||||
[(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
|
||||
axis=-1,
|
||||
) # (N, 3)
|
||||
|
||||
# Step 3: transform rays into world space
|
||||
c2w = pose_matrix
|
||||
ray_orig = c2w[:3, 3] # (3,)
|
||||
R_world = c2w[:3, :3] # (3,3)
|
||||
ray_dirs = (R_world @ dirs_cam.T).T # (N, 3)
|
||||
|
||||
# Step 4: plane intersection
|
||||
n = plane_normal / jnp.linalg.norm(plane_normal)
|
||||
p0 = plane_point
|
||||
denom = jnp.dot(ray_dirs, n) # (N,)
|
||||
numer = jnp.dot((p0 - ray_orig), n) # scalar
|
||||
t = numer / denom # (N,)
|
||||
points_world = ray_orig + ray_dirs * t[:, None]
|
||||
return points_world
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def unproject_points_to_z_plane(
|
||||
points_2d: Float[Array, "N 2"],
|
||||
K: Float[Array, "3 3"],
|
||||
dist_coeffs: Float[Array, "5"],
|
||||
pose_matrix: Float[Array, "4 4"],
|
||||
z: float = 0.0,
|
||||
) -> Float[Array, "N 3"]:
|
||||
plane_normal = jnp.array([0.0, 0.0, 1.0])
|
||||
plane_point = jnp.array([0.0, 0.0, z])
|
||||
return unproject_points_onto_plane(
|
||||
points_2d, plane_normal, plane_point, K, dist_coeffs, pose_matrix
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def project(
|
||||
points_3d: Num[Array, "N 3"],
|
||||
@ -242,6 +317,9 @@ class Camera:
|
||||
Camera parameters
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Camera id={self.id}>"
|
||||
|
||||
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Project 3D points to 2D points using this camera's parameters
|
||||
@ -292,6 +370,20 @@ class Camera:
|
||||
dist_coeffs=self.params.dist_coeffs,
|
||||
)
|
||||
|
||||
def unproject_points_to_z_plane(
|
||||
self, points_2d: Num[Array, "N 2"], z: float = 0.0
|
||||
) -> Num[Array, "N 3"]:
|
||||
"""
|
||||
Unproject 2D points to 3D points on a plane at z = constant.
|
||||
"""
|
||||
return unproject_points_to_z_plane(
|
||||
points_2d,
|
||||
self.params.K,
|
||||
self.params.dist_coeffs,
|
||||
self.params.pose_matrix,
|
||||
z,
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
@ -337,6 +429,9 @@ class Detection:
|
||||
object.__setattr__(self, "_kp_undistorted", kpu)
|
||||
return kpu
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Detection({self.camera}, {self.timestamp})"
|
||||
|
||||
|
||||
def classify_by_camera(
|
||||
detections: list[Detection],
|
||||
|
||||
82
play.ipynb
82
play.ipynb
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -43,7 +43,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -75,7 +75,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -97,7 +97,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -174,7 +174,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -246,7 +246,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -339,7 +339,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -432,7 +432,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -530,21 +530,32 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"@jaxtyped(typechecker=beartype)\n",
|
||||
"@dataclass\n",
|
||||
"@dataclass(frozen=True)\n",
|
||||
"class Tracking:\n",
|
||||
" id: int\n",
|
||||
" keypoints: Float[Array, \"J 3\"]\n",
|
||||
" last_active_timestamp: datetime\n",
|
||||
"\n",
|
||||
" def __repr__(self) -> str:\n",
|
||||
" return f\"Tracking({self.id}, {self.last_active_timestamp})\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@jaxtyped(typechecker=beartype)\n",
|
||||
"def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"N 3\"]:\n",
|
||||
"def triangle_from_cluster(\n",
|
||||
" cluster: list[Detection],\n",
|
||||
") -> tuple[Float[Array, \"N 3\"], datetime]:\n",
|
||||
" proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
|
||||
" points = jnp.array([el.keypoints_undistorted for el in cluster])\n",
|
||||
" confidences = jnp.array([el.confidences for el in cluster])\n",
|
||||
" return triangulate_points_from_multiple_views_linear(\n",
|
||||
" latest_timestamp = max(el.timestamp for el in cluster)\n",
|
||||
" return (\n",
|
||||
" triangulate_points_from_multiple_views_linear(\n",
|
||||
" proj_matrices, points, confidences=confidences\n",
|
||||
" ),\n",
|
||||
" latest_timestamp,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# res = {\n",
|
||||
"# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n",
|
||||
"# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n",
|
||||
@ -552,28 +563,39 @@
|
||||
"# with open(\"samples/res.json\", \"wb\") as f:\n",
|
||||
"# f.write(orjson.dumps(res))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GlobalTrackingState:\n",
|
||||
" _last_id: int\n",
|
||||
" _trackings: list[Tracking]\n",
|
||||
" _trackings: dict[int, Tracking]\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" self._last_id = 0\n",
|
||||
" self._trackings = []\n",
|
||||
" self._trackings = {}\n",
|
||||
"\n",
|
||||
" def __repr__(self) -> str:\n",
|
||||
" return (\n",
|
||||
" f\"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def trackings(self) -> list[Tracking]:\n",
|
||||
" def trackings(self) -> dict[int, Tracking]:\n",
|
||||
" return shallow_copy(self._trackings)\n",
|
||||
"\n",
|
||||
" def add_tracking(self, cluster: list[Detection]) -> Tracking:\n",
|
||||
" tracking = Tracking(id=self._last_id, keypoints=triangle_from_cluster(cluster))\n",
|
||||
" self._last_id += 1\n",
|
||||
" self._trackings.append(tracking)\n",
|
||||
" kps_3d, latest_timestamp = triangle_from_cluster(cluster)\n",
|
||||
" next_id = self._last_id + 1\n",
|
||||
" tracking = Tracking(\n",
|
||||
" id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp\n",
|
||||
" )\n",
|
||||
" self._trackings[next_id] = tracking\n",
|
||||
" self._last_id = next_id\n",
|
||||
" return tracking\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"global_tracking_state = GlobalTrackingState()\n",
|
||||
"for cluster in clusters_detections:\n",
|
||||
" global_tracking_state.add_tracking(cluster)"
|
||||
" global_tracking_state.add_tracking(cluster)\n",
|
||||
"display(global_tracking_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -585,6 +607,28 @@
|
||||
"next_group = next(sync_gen)\n",
|
||||
"display(next_group)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from app.camera import classify_by_camera\n",
|
||||
"\n",
|
||||
"# let's do cross-view association\n",
|
||||
"trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)\n",
|
||||
"detections = shallow_copy(next_group)\n",
|
||||
"# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections\n",
|
||||
"affinity = np.zeros((len(trackings), len(detections)))\n",
|
||||
"detection_by_camera = classify_by_camera(detections)\n",
|
||||
"for i, tracking in enumerate(trackings):\n",
|
||||
" for c, detections in detection_by_camera.items():\n",
|
||||
" camera = next(iter(detections)).camera\n",
|
||||
" # pixel space, unnormalized\n",
|
||||
" tracking_2d_projection = camera.project(tracking.keypoints)\n",
|
||||
" \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user