Jaxtyping Summary
Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.
Core syntax
- Use `DType[array_type, "shape"]`.
- Examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`
Shape rules
- Reuse names to enforce equality across values: `"batch time"` with `"time features"`.
- Use fixed integers for exact sizes: `"3 3"`.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic axis.
- Use `#name` when size `1` should also be accepted for broadcasting.
- Use `_` or `name=...` only for documentation when runtime enforcement is not wanted.
Array type guidance
- Prefer concrete normalized types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
- Use `Shaped[ArrayLike, "..."]` or another broader input type only at ingestion boundaries.
- Create aliases for repeated shapes instead of rewriting them in every signature.
from jaxtyping import Float
import numpy as np
FramePoints = Float[np.ndarray, "frames keypoints dims"]
Runtime checking
- Pair `jaxtyping` with `beartype` for runtime validation:
from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np
@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    return x - x.mean(axis=0)
- Apply this at stable boundaries and in tests, not blindly in every hot loop.
- Avoid `from __future__ import annotations` when relying on runtime checking.
Practical defaults
- Prefer meaningful axis names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.
- Keep aliases near the module or domain where they are used.
- If static typing and runtime truth diverge, validate at runtime first, then use a commented `cast(...)` at the narrow boundary.