This commit is contained in:
2025-04-28 18:01:24 +08:00
parent 7ee4002567
commit b3ed20296a
2 changed files with 103 additions and 143 deletions

View File

@ -227,11 +227,13 @@ def project(
# Fall back to normalized coordinates if image_size not provided
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
# only valid points need distortion
if jnp.any(valid):
valid_p2d = p2d[valid]
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
p2d = p2d.at[valid].set(distorted_valid)
# Distort *all* points, then blend results using `where` to keep
# numerical traces inside JAX this avoids Python ``if`` with a traced
# value (which triggers TracerBoolConversionError when the function is
# vmapped/jitted).
distorted_all = distortion(p2d, K, dist_coeffs)
# Broadcast the valid mask over the last (x,y) dimension
p2d = jnp.where(valid[:, None], distorted_all, p2d)
elif dist_coeffs is None and K is None:
pass
else:
@ -239,7 +241,7 @@ def project(
"dist_coeffs and K must be provided together to compute distortion"
)
return jnp.squeeze(p2d)
return p2d # type: ignore
@jaxtyped(typechecker=beartype)