| name | description |
|---|---|
| python-tensor-typing | Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested. |
Python Tensor Typing
Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.
Defaults
- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Prefer `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers by default so tensor shape or dtype mismatches fail early.
- Avoid applying runtime checking blindly to hot inner loops.
- Only avoid `from __future__ import annotations` in modules that rely on runtime annotation inspection.
```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    ...
```
Integrated Jaxtyping Reference
Use this section as the built-in cheat sheet. Do not rely on a separate reference file.
Mental Model
- `jaxtyping` provides shape and dtype annotations plus runtime type checking for JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.
- The name is historical; it is not JAX-only.
- Static type checkers do not fully understand shape constraints. In practice `dtype[array, shape]` is mostly treated as just `array`, so use runtime checks to catch shape or dtype mistakes early.
- This skill defaults to `beartype` as the runtime checker.
Core Syntax
- Annotate arrays as `DType[ArrayType, "shape"]`.
- Common dtype families:
  - `Shaped` for any dtype
  - `Bool`
  - `Num` for numeric tensors
  - `Real`, `Float`, `Complex`, `Int`, `UInt`
  - precision-specific forms like `Float32`, `Int64`, `UInt8`
- Representative examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`
Shape Rules
- Use fixed integers for exact axes: `"3 3"`.
- Use named axes to enforce equality across values: `"batch time"` and `"time features"`.
- Use symbolic expressions when the return shape is derived from argument shapes: `"dim-1"`.
- Use `...` for anonymous zero-or-more axes. You may only use one variadic axis per annotation.
- Use `*name` for a named variadic axis.
- Use `#name` when broadcasting with size `1` should be accepted.
- Use `_name` or just `_` when an axis is documentation-only and should not be runtime-checked.
- Use `name=...` for documentation-only labels like `rows=4`.
- Use `""` for a scalar shape.
- Use `"..."` when you only want dtype checking and do not want to constrain shape.
- Prefer meaningful names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.
Array And Alias Guidance
- Prefer concrete normalized array types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
  - `Float[tf.Tensor, "..."]`
  - `Float[mx.array, "..."]`
- Use broader input types like `Shaped[ArrayLike, "..."]` only at ingestion boundaries.
- Keep repeated shapes in local aliases instead of rewriting them in every signature.
- You can nest existing `jaxtyping` aliases:
```python
import jax
from jaxtyping import Float

Image = Float[jax.Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```
- Duck-typed arrays also work if they expose `.shape` and `.dtype`.
- `typing.Any`, unions, type aliases, and bounded `TypeVar`s can all be used as the array type parameter when needed.
Runtime Checking Patterns
- Prefer `@jaxtyped(typechecker=beartype)` for function-level checks.
- Use it on adapters, public tensor utilities, deserializers, boundary normalization functions, and tests.
- Do not decorate the hottest inner loops unless the runtime cost is acceptable.
- Avoid stringized annotations and avoid `from __future__ import annotations` in code paths that depend on runtime inspection.
- If static typing and runtime truth diverge, validate at runtime first and then use a narrowly scoped, commented `cast(...)`.
```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    return x - x.mean(axis=0)
```
- For package-wide or test-only enforcement, `jaxtyping.install_import_hook` exists.
- Pytest can enable it with `--jaxtyping-packages=foo,bar.baz,beartype.beartype` or equivalent `addopts`.
- In notebooks, you can load the IPython extension and set `%jaxtyping.typechecker beartype.beartype`.
- With `jax.jit`, runtime shape checks occur during tracing, so the compiled code does not keep the checking overhead.
Advanced Notes
- `print_bindings` can help inspect axis bindings when debugging shape mismatches.
- `AbstractDtype` can define custom dtypes for duck-typed arrays when a project needs that level of control.
- For unusual edge cases or less common features, see the upstream docs: https://github.com/patrick-kidger/jaxtyping/tree/main/docs.
Tooling Gotchas
- `pyright` and `mypy` usually treat `dtype[array, shape]` as just `array`; do not expect them to prove shape safety.
- Ruff or `flake8` may complain about shape strings in annotations. Disabling `F722` is the usual fix for multidimensional forms.
- For one-dimensional annotations that trigger undefined-name lint errors, prefixing the shape with a space can route the linter to `F722` instead.
- Dataclass fields with stringified annotations can be skipped or mis-checked at runtime. Avoid them.
Anti-Goals
- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.