feat: Enhance 3D pose prediction and camera affinity calculations
- Introduced a JAX-friendly implementation of `predict_pose_3d` to predict 3D poses based on tracking velocities, improving performance by avoiding Python control flow. - Updated `calculate_camera_affinity_matrix_jax` to streamline the affinity matrix calculation, incorporating velocity handling for trackings and enhancing clarity in the distance computation. - Added properties for normalized keypoints in the `Detection` class, allowing for lazy evaluation and improved usability. - Enhanced documentation throughout to clarify function purposes and parameters.
This commit is contained in:
@ -29,7 +29,6 @@ def undistort_points(
|
||||
return res.reshape(-1, 2)
|
||||
|
||||
|
||||
@jax.jit
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def distortion(
|
||||
points_2d: Num[Array, "N 2"],
|
||||
@ -169,6 +168,9 @@ def unproject_points_to_z_plane(
|
||||
|
||||
Returns:
|
||||
[..., 3] world-space intersection points
|
||||
|
||||
Note:
|
||||
This function is not JAX-friendly, since it use `cv2.undistortPoints` internally.
|
||||
"""
|
||||
plane_normal = jnp.array([0.0, 0.0, 1.0])
|
||||
plane_point = jnp.array([0.0, 0.0, z])
|
||||
@ -406,6 +408,9 @@ class Camera:
|
||||
|
||||
Returns:
|
||||
[..., 3] world-space intersection points
|
||||
|
||||
Note:
|
||||
This function is not JAX-friendly, since it use `cv2.undistortPoints` internally.
|
||||
"""
|
||||
return unproject_points_to_z_plane(
|
||||
points_2d,
|
||||
@ -460,6 +465,32 @@ class Detection:
|
||||
object.__setattr__(self, "_kp_undistorted", kpu)
|
||||
return kpu
|
||||
|
||||
@property
|
||||
def normalized_keypoints_undistorted(self) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Returns normalized keypoints, without distortion.
|
||||
|
||||
The result is cached on first access. (lazy evaluation)
|
||||
"""
|
||||
norm_kps = getattr(self, "_norm_kps_undistorted", None)
|
||||
if norm_kps is None:
|
||||
norm_kps = self.keypoints_undistorted / self.camera.params.image_size
|
||||
object.__setattr__(self, "_norm_kps_undistorted", norm_kps)
|
||||
return norm_kps
|
||||
|
||||
@property
|
||||
def normalized_keypoints(self) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Returns normalized keypoints.
|
||||
|
||||
The result is cached on first access. (lazy evaluation)
|
||||
"""
|
||||
norm_kps = getattr(self, "_norm_kps", None)
|
||||
if norm_kps is None:
|
||||
norm_kps = self.keypoints / self.camera.params.image_size
|
||||
object.__setattr__(self, "_norm_kps", norm_kps)
|
||||
return norm_kps
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Detection({self.camera}, {self.timestamp})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user