1
0
forked from HQU-gxy/CVTH3PE

feat: Enhance 3D pose prediction and camera affinity calculations

- Introduced a JAX-friendly implementation of `predict_pose_3d` to predict 3D poses based on tracking velocities, improving performance by avoiding Python control flow.
- Updated `calculate_camera_affinity_matrix_jax` to streamline the affinity matrix calculation, incorporating velocity handling for trackings and enhancing clarity in the distance computation.
- Added properties for normalized keypoints in the `Detection` class, allowing for lazy evaluation and improved usability.
- Enhanced documentation throughout to clarify function purposes and parameters.
This commit is contained in:
2025-04-29 12:06:07 +08:00
parent 6194f083cb
commit 86fcc5f283
2 changed files with 80 additions and 18 deletions

View File

@ -29,7 +29,6 @@ def undistort_points(
return res.reshape(-1, 2)
@jax.jit
@jaxtyped(typechecker=beartype)
def distortion(
points_2d: Num[Array, "N 2"],
@ -169,6 +168,9 @@ def unproject_points_to_z_plane(
Returns:
[..., 3] world-space intersection points
Note:
This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
"""
plane_normal = jnp.array([0.0, 0.0, 1.0])
plane_point = jnp.array([0.0, 0.0, z])
@ -406,6 +408,9 @@ class Camera:
Returns:
[..., 3] world-space intersection points
Note:
This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
"""
return unproject_points_to_z_plane(
points_2d,
@ -460,6 +465,32 @@ class Detection:
object.__setattr__(self, "_kp_undistorted", kpu)
return kpu
@property
def normalized_keypoints_undistorted(self) -> Num[Array, "N 2"]:
    """
    Undistorted keypoints divided by the camera's image size.

    Evaluated lazily: the first access computes
    ``self.keypoints_undistorted / self.camera.params.image_size`` and
    stores it on the instance as ``_norm_kps_undistorted``; later accesses
    return that stored value.
    """
    cached = getattr(self, "_norm_kps_undistorted", None)
    if cached is not None:
        return cached

    undistorted = self.keypoints_undistorted
    scaled = undistorted / self.camera.params.image_size
    # object.__setattr__ bypasses the class's own __setattr__
    # (presumably because Detection is a frozen dataclass -- confirm).
    object.__setattr__(self, "_norm_kps_undistorted", scaled)
    return scaled
@property
def normalized_keypoints(self) -> Num[Array, "N 2"]:
    """
    Raw keypoints divided by the camera's image size.

    Evaluated lazily: the first access computes
    ``self.keypoints / self.camera.params.image_size`` and stores it on the
    instance as ``_norm_kps``; later accesses return that stored value.
    """
    cached = getattr(self, "_norm_kps", None)
    if cached is not None:
        return cached

    raw = self.keypoints
    scaled = raw / self.camera.params.image_size
    # object.__setattr__ bypasses the class's own __setattr__
    # (presumably because Detection is a frozen dataclass -- confirm).
    object.__setattr__(self, "_norm_kps", scaled)
    return scaled
def __repr__(self) -> str:
    """Return a short debug string built from the camera and the timestamp."""
    text = f"Detection({self.camera}, {self.timestamp})"
    return text