forked from HQU-gxy/CVTH3PE
feat: Enhance play notebook and camera module with new unprojection functionalities
- Updated the play notebook to include new methods for unprojecting 2D points onto a 3D plane. - Introduced `unproject_points_onto_plane` and `unproject_points_to_z_plane` functions in the camera module for improved point handling. - Enhanced the `Camera` class with a method for unprojecting points to a specified z-plane. - Cleaned up execution counts in the notebook for better organization and clarity.
This commit is contained in:
@ -156,19 +156,3 @@ class MyDataclass:
|
|||||||
x: Float[Array, "batch"]
|
x: Float[Array, "batch"]
|
||||||
y: Float[Array, "batch"]
|
y: Float[Array, "batch"]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Scalars, PRNG keys
|
|
||||||
|
|
||||||
For convenience, jaxtyping also includes `jaxtyping.Scalar`, `jaxtyping.ScalarLike`, and `jaxtyping.PRNGKeyArray`, defined as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
Scalar = Shaped[Array, ""]
|
|
||||||
ScalarLike = Shaped[ArrayLike, ""]
|
|
||||||
|
|
||||||
# Left: new-style typed keys; right: old-style keys. See JEP 9263.
|
|
||||||
PRNGKeyArray = Union[Key[Array, ""], UInt32[Array, "2"]]
|
|
||||||
```
|
|
||||||
|
|
||||||
Recalling that shape-and-dtype specified jaxtyping arrays can be nested, this means that e.g. you can annotate the output of `jax.random.split` with `Shaped[PRNGKeyArray, "2"]`, or e.g. an integer scalar with `Int[Scalar, ""]`.
|
|
||||||
|
|
||||||
Note that `jaxtyping.{Scalar, ScalarLike, PRNGKeyArray}` are only available if JAX has been installed.
|
|
||||||
@ -6,7 +6,7 @@ from typing import Any, TypeAlias, TypedDict, Optional
|
|||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jaxtyping import Num, jaxtyped, Array
|
from jaxtyping import Num, jaxtyped, Array, Float
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -87,6 +87,81 @@ def distortion(
|
|||||||
return jnp.stack([u, v], axis=1)
|
return jnp.stack([u, v], axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
|
def unproject_points_onto_plane(
|
||||||
|
points_2d: Float[Array, "N 2"],
|
||||||
|
plane_normal: Float[Array, "3"],
|
||||||
|
plane_point: Float[Array, "3"],
|
||||||
|
K: Float[Array, "3 3"], # pylint: disable=invalid-name
|
||||||
|
dist_coeffs: Float[Array, "5"],
|
||||||
|
pose_matrix: Float[Array, "4 4"],
|
||||||
|
) -> Float[Array, "N 3"]:
|
||||||
|
"""
|
||||||
|
Un-project 2D image points onto an arbitrary 3D plane.
|
||||||
|
This function computes the ray-plane intersections, since every `points_2d`
|
||||||
|
could be treated as a ray.
|
||||||
|
|
||||||
|
(i.e. back-project points onto a plane)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_2d: [..., 2] image pixel coordinates
|
||||||
|
plane_normal: (3,) normal vector of the plane in world coords
|
||||||
|
plane_point: (3,) a known point on the plane in world coords
|
||||||
|
K: Camera intrinsic matrix
|
||||||
|
dist_coeffs: Distortion coefficients
|
||||||
|
pose_matrix: Camera-to-World (C2W) transformation matrix
|
||||||
|
|
||||||
|
Note:
|
||||||
|
`pose_matrix` is NOT the same as camera extrinsic (World-to-Camera, W2C),
|
||||||
|
but the inverse of it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[..., 3] world-space intersection points
|
||||||
|
"""
|
||||||
|
# Step 1: undistort (no-op here)
|
||||||
|
pts = undistort_points(
|
||||||
|
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: normalize image coordinates into camera rays
|
||||||
|
fx, fy = K[0, 0], K[1, 1]
|
||||||
|
cx, cy = K[0, 2], K[1, 2]
|
||||||
|
dirs_cam = jnp.stack(
|
||||||
|
[(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
|
||||||
|
axis=-1,
|
||||||
|
) # (N, 3)
|
||||||
|
|
||||||
|
# Step 3: transform rays into world space
|
||||||
|
c2w = pose_matrix
|
||||||
|
ray_orig = c2w[:3, 3] # (3,)
|
||||||
|
R_world = c2w[:3, :3] # (3,3)
|
||||||
|
ray_dirs = (R_world @ dirs_cam.T).T # (N, 3)
|
||||||
|
|
||||||
|
# Step 4: plane intersection
|
||||||
|
n = plane_normal / jnp.linalg.norm(plane_normal)
|
||||||
|
p0 = plane_point
|
||||||
|
denom = jnp.dot(ray_dirs, n) # (N,)
|
||||||
|
numer = jnp.dot((p0 - ray_orig), n) # scalar
|
||||||
|
t = numer / denom # (N,)
|
||||||
|
points_world = ray_orig + ray_dirs * t[:, None]
|
||||||
|
return points_world
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
|
def unproject_points_to_z_plane(
|
||||||
|
points_2d: Float[Array, "N 2"],
|
||||||
|
K: Float[Array, "3 3"],
|
||||||
|
dist_coeffs: Float[Array, "5"],
|
||||||
|
pose_matrix: Float[Array, "4 4"],
|
||||||
|
z: float = 0.0,
|
||||||
|
) -> Float[Array, "N 3"]:
|
||||||
|
plane_normal = jnp.array([0.0, 0.0, 1.0])
|
||||||
|
plane_point = jnp.array([0.0, 0.0, z])
|
||||||
|
return unproject_points_onto_plane(
|
||||||
|
points_2d, plane_normal, plane_point, K, dist_coeffs, pose_matrix
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def project(
|
def project(
|
||||||
points_3d: Num[Array, "N 3"],
|
points_3d: Num[Array, "N 3"],
|
||||||
@ -242,6 +317,9 @@ class Camera:
|
|||||||
Camera parameters
|
Camera parameters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Camera id={self.id}>"
|
||||||
|
|
||||||
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
||||||
"""
|
"""
|
||||||
Project 3D points to 2D points using this camera's parameters
|
Project 3D points to 2D points using this camera's parameters
|
||||||
@ -292,6 +370,20 @@ class Camera:
|
|||||||
dist_coeffs=self.params.dist_coeffs,
|
dist_coeffs=self.params.dist_coeffs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def unproject_points_to_z_plane(
|
||||||
|
self, points_2d: Num[Array, "N 2"], z: float = 0.0
|
||||||
|
) -> Num[Array, "N 3"]:
|
||||||
|
"""
|
||||||
|
Unproject 2D points to 3D points on a plane at z = constant.
|
||||||
|
"""
|
||||||
|
return unproject_points_to_z_plane(
|
||||||
|
points_2d,
|
||||||
|
self.params.K,
|
||||||
|
self.params.dist_coeffs,
|
||||||
|
self.params.pose_matrix,
|
||||||
|
z,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -337,6 +429,9 @@ class Detection:
|
|||||||
object.__setattr__(self, "_kp_undistorted", kpu)
|
object.__setattr__(self, "_kp_undistorted", kpu)
|
||||||
return kpu
|
return kpu
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Detection({self.camera}, {self.timestamp})"
|
||||||
|
|
||||||
|
|
||||||
def classify_by_camera(
|
def classify_by_camera(
|
||||||
detections: list[Detection],
|
detections: list[Detection],
|
||||||
|
|||||||
84
play.ipynb
84
play.ipynb
@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 28,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -43,7 +43,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 30,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -75,7 +75,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 31,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -97,7 +97,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 33,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -174,7 +174,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 34,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -246,7 +246,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 35,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -339,7 +339,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 37,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -348,7 +348,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 38,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -432,7 +432,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 43,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -530,21 +530,32 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@jaxtyped(typechecker=beartype)\n",
|
"@jaxtyped(typechecker=beartype)\n",
|
||||||
"@dataclass\n",
|
"@dataclass(frozen=True)\n",
|
||||||
"class Tracking:\n",
|
"class Tracking:\n",
|
||||||
" id: int\n",
|
" id: int\n",
|
||||||
" keypoints: Float[Array, \"J 3\"]\n",
|
" keypoints: Float[Array, \"J 3\"]\n",
|
||||||
|
" last_active_timestamp: datetime\n",
|
||||||
|
"\n",
|
||||||
|
" def __repr__(self) -> str:\n",
|
||||||
|
" return f\"Tracking({self.id}, {self.last_active_timestamp})\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@jaxtyped(typechecker=beartype)\n",
|
"@jaxtyped(typechecker=beartype)\n",
|
||||||
"def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"N 3\"]:\n",
|
"def triangle_from_cluster(\n",
|
||||||
|
" cluster: list[Detection],\n",
|
||||||
|
") -> tuple[Float[Array, \"N 3\"], datetime]:\n",
|
||||||
" proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
|
" proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
|
||||||
" points = jnp.array([el.keypoints_undistorted for el in cluster])\n",
|
" points = jnp.array([el.keypoints_undistorted for el in cluster])\n",
|
||||||
" confidences = jnp.array([el.confidences for el in cluster])\n",
|
" confidences = jnp.array([el.confidences for el in cluster])\n",
|
||||||
" return triangulate_points_from_multiple_views_linear(\n",
|
" latest_timestamp = max(el.timestamp for el in cluster)\n",
|
||||||
" proj_matrices, points, confidences=confidences\n",
|
" return (\n",
|
||||||
|
" triangulate_points_from_multiple_views_linear(\n",
|
||||||
|
" proj_matrices, points, confidences=confidences\n",
|
||||||
|
" ),\n",
|
||||||
|
" latest_timestamp,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"# res = {\n",
|
"# res = {\n",
|
||||||
"# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n",
|
"# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n",
|
||||||
"# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n",
|
"# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n",
|
||||||
@ -552,28 +563,39 @@
|
|||||||
"# with open(\"samples/res.json\", \"wb\") as f:\n",
|
"# with open(\"samples/res.json\", \"wb\") as f:\n",
|
||||||
"# f.write(orjson.dumps(res))\n",
|
"# f.write(orjson.dumps(res))\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"class GlobalTrackingState:\n",
|
"class GlobalTrackingState:\n",
|
||||||
" _last_id: int\n",
|
" _last_id: int\n",
|
||||||
" _trackings: list[Tracking]\n",
|
" _trackings: dict[int, Tracking]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def __init__(self):\n",
|
" def __init__(self):\n",
|
||||||
" self._last_id = 0\n",
|
" self._last_id = 0\n",
|
||||||
" self._trackings = []\n",
|
" self._trackings = {}\n",
|
||||||
|
"\n",
|
||||||
|
" def __repr__(self) -> str:\n",
|
||||||
|
" return (\n",
|
||||||
|
" f\"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})\"\n",
|
||||||
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @property\n",
|
" @property\n",
|
||||||
" def trackings(self) -> list[Tracking]:\n",
|
" def trackings(self) -> dict[int, Tracking]:\n",
|
||||||
" return shallow_copy(self._trackings)\n",
|
" return shallow_copy(self._trackings)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def add_tracking(self, cluster: list[Detection]) -> Tracking:\n",
|
" def add_tracking(self, cluster: list[Detection]) -> Tracking:\n",
|
||||||
" tracking = Tracking(id=self._last_id, keypoints=triangle_from_cluster(cluster))\n",
|
" kps_3d, latest_timestamp = triangle_from_cluster(cluster)\n",
|
||||||
" self._last_id += 1\n",
|
" next_id = self._last_id + 1\n",
|
||||||
" self._trackings.append(tracking)\n",
|
" tracking = Tracking(\n",
|
||||||
|
" id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp\n",
|
||||||
|
" )\n",
|
||||||
|
" self._trackings[next_id] = tracking\n",
|
||||||
|
" self._last_id = next_id\n",
|
||||||
" return tracking\n",
|
" return tracking\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"global_tracking_state = GlobalTrackingState()\n",
|
"global_tracking_state = GlobalTrackingState()\n",
|
||||||
"for cluster in clusters_detections:\n",
|
"for cluster in clusters_detections:\n",
|
||||||
" global_tracking_state.add_tracking(cluster)"
|
" global_tracking_state.add_tracking(cluster)\n",
|
||||||
|
"display(global_tracking_state)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -585,6 +607,28 @@
|
|||||||
"next_group = next(sync_gen)\n",
|
"next_group = next(sync_gen)\n",
|
||||||
"display(next_group)"
|
"display(next_group)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from app.camera import classify_by_camera\n",
|
||||||
|
"\n",
|
||||||
|
"# let's do cross-view association\n",
|
||||||
|
"trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)\n",
|
||||||
|
"detections = shallow_copy(next_group)\n",
|
||||||
|
"# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections\n",
|
||||||
|
"affinity = np.zeros((len(trackings), len(detections)))\n",
|
||||||
|
"detection_by_camera = classify_by_camera(detections)\n",
|
||||||
|
"for i, tracking in enumerate(trackings):\n",
|
||||||
|
" for c, detections in detection_by_camera.items():\n",
|
||||||
|
" camera = next(iter(detections)).camera\n",
|
||||||
|
" # pixel space, unnormalized\n",
|
||||||
|
" tracking_2d_projection = camera.project(tracking.keypoints)\n",
|
||||||
|
" \n"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Reference in New Issue
Block a user