forked from HQU-gxy/CVTH3PE
revert
This commit is contained in:
@ -227,11 +227,13 @@ def project(
|
||||
# Fall back to normalized coordinates if image_size not provided
|
||||
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
|
||||
|
||||
# only valid points need distortion
|
||||
if jnp.any(valid):
|
||||
valid_p2d = p2d[valid]
|
||||
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
|
||||
p2d = p2d.at[valid].set(distorted_valid)
|
||||
# Distort *all* points, then blend results using `where` to keep
|
||||
# numerical traces inside JAX – this avoids Python ``if`` with a traced
|
||||
# value (which triggers TracerBoolConversionError when the function is
|
||||
# vmapped/jitted).
|
||||
distorted_all = distortion(p2d, K, dist_coeffs)
|
||||
# Broadcast the valid mask over the last (x,y) dimension
|
||||
p2d = jnp.where(valid[:, None], distorted_all, p2d)
|
||||
elif dist_coeffs is None and K is None:
|
||||
pass
|
||||
else:
|
||||
@ -239,7 +241,7 @@ def project(
|
||||
"dist_coeffs and K must be provided together to compute distortion"
|
||||
)
|
||||
|
||||
return jnp.squeeze(p2d)
|
||||
return p2d # type: ignore
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
|
||||
Reference in New Issue
Block a user