# Jaxtyping Summary

Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.

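## Core syntax

- Use `DType[array_type, "shape"]`, where `DType` is a dtype marker such as `Float`, `Int`, `Bool`, or `Shaped` (any dtype), `array_type` is the array class, and `"shape"` is a space-separated string of axis specs.
- Examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`
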
## Shape rules

- Reuse an axis name to require equal sizes across values. For example, with `"batch time"` on one argument and `"time features"` on another, both `time` axes must have the same size. Cross-argument checks apply within a single `@jaxtyped` call.
- Use fixed integers for exact sizes: `"3 3"`.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic group of zero or more axes whose sizes must match wherever `*name` recurs.
- Use `#name` when size `1` should also be accepted for broadcasting.
- Use `_` or `name=...` only for documentation when runtime enforcement is not wanted.

## Array type guidance

- Prefer concrete, normalized types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
- Use `Shaped[ArrayLike, "..."]` or another broader input type only at ingestion boundaries.
- Create aliases for repeated shapes instead of rewriting them in every signature.

```python
from jaxtyping import Float
import numpy as np

FramePoints = Float[np.ndarray, "frames keypoints dims"]
```

## Runtime checking

- Pair `jaxtyping` with `beartype` for runtime validation:

```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    # Subtract the per-feature mean computed over the batch axis.
    return x - x.mean(axis=0)
```

- Apply runtime checks at stable boundaries and in tests rather than on every function in a hot loop, because each decorated call re-checks dtypes and shapes and adds overhead.
- Avoid `from __future__ import annotations` when relying on runtime checking.

## Practical defaults

- Prefer meaningful axis names such as `batch`, `frames`, `persons`, `keypoints`, `dims`, and `channels`.
- Keep aliases near the module or domain where they are used.
- If static typing and runtime truth diverge, validate at runtime first, then use `typing.cast(...)` with a comment explaining why, at the narrowest possible boundary.
