# Jaxtyping Summary

Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.

## Core syntax

- Use `DType[array_type, "shape"]`.
- Examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`

## Shape rules

- Reuse a name to require equal sizes across values. For example, annotating one argument `"batch time"` and another `"time features"` forces both `time` axes to have the same size.
- Use fixed integers for exact sizes, e.g. `"3 3"` for a 3-by-3 array.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic axis: zero or more axes that must match wherever `name` reappears.
- Use `#name` when size `1` should also be accepted for broadcasting.
- Use `_` for an anonymous axis, or prefix a name with `_` (e.g. `_batch`), when the axis is for documentation only and should not be checked at runtime. `name=...` attaches a name that is likewise documentation only.

## Array type guidance

- Prefer concrete normalized types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
- Use `Shaped[ArrayLike, "..."]` or another broader input type only at ingestion boundaries.
- Create aliases for repeated shapes instead of rewriting them in every signature.

```python
from jaxtyping import Float
import numpy as np

# Reusable alias: a float array shaped (frames, keypoints, dims).
FramePoints = Float[np.ndarray, "frames keypoints dims"]
```

## Runtime checking

- Pair `jaxtyping` with `beartype` for runtime validation:

```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np

# Checks dtype and shape of the argument and the return value on every call;
# `batch` and `dims` must agree between input and output.
@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    # Subtract the per-dimension mean taken over the batch axis.
    return x - x.mean(axis=0)
```

- Runtime checks add overhead to every call, so apply them at stable API boundaries and in tests rather than inside hot loops.
- Avoid `from __future__ import annotations` when relying on runtime checking.

## Practical defaults

- Prefer meaningful axis names such as `batch`, `frames`, `persons`, `keypoints`, `dims`, and `channels`.
- Keep aliases near the module or domain where they are used.
- If the static type checker's view and the runtime values diverge, validate at runtime first, then apply `typing.cast(...)` at that single narrow boundary, with a comment explaining why the cast is safe.