---
name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.
---

# Python Tensor Typing

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

## Defaults

- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- By default, apply `@jaxtyped(typechecker=beartype)` to stable boundaries and test-targeted helpers so that shape or dtype mismatches fail early.
- Avoid applying runtime checking blindly to hot inner loops.
- Avoid `from __future__ import annotations` only in modules that rely on runtime annotation inspection; elsewhere it is fine.

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Scale each row to unit L2 norm; the output keeps the input's shape.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)
```

## Integrated Jaxtyping Reference

Use this section as the built-in cheat sheet. Do not rely on a separate reference file.

### Mental Model

- `jaxtyping` provides shape and dtype annotations, plus runtime type checking, for JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.
- The name is historical; the library is not JAX-only.
- Static type checkers do not understand shape constraints. In practice they treat `dtype[array, shape]` as just `array`, so rely on runtime checks to catch shape or dtype mistakes early.
- This skill defaults to `beartype` as the runtime checker.
### Core Syntax

- Annotate arrays as `DType[ArrayType, "shape"]`.
- Common dtype families:
  - `Shaped` for any dtype
  - `Bool`
  - `Num` for numeric tensors
  - `Real`, `Float`, `Complex`, `Int`, `UInt`
  - precision-specific forms like `Float32`, `Int64`, `UInt8`
- Representative examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`

### Shape Rules

- Use fixed integers for exact axes: `"3 3"`.
- Use named axes to enforce equality across values: with `"batch time"` on one value and `"time features"` on another, both `time` axes must have the same size.
- Use symbolic expressions when the return shape is derived from argument shapes, for example `"dim-1"`.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic axis. An annotation may contain at most one variadic axis (`...` or `*name`).
- Use `#name` when a size-`1` axis should be accepted for broadcasting.
- Use `_name`, or just `_`, when an axis is documentation-only and should not be checked at runtime.
- Prefix an axis with `name=` to give it a documentation-only label, for example `rows=4`.
- Use `""` for a scalar shape.
- Use `"..."` when you only want dtype checking and do not want to constrain the shape.
- Prefer meaningful names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.

### Array And Alias Guidance

- Prefer concrete, normalized array types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
  - `Float[tf.Tensor, "..."]`
  - `Float[mx.array, "..."]`
- Use broader input types like `Shaped[ArrayLike, "..."]` only at ingestion boundaries.
- Keep repeated shapes in local aliases instead of rewriting them in every signature.
- You can nest existing `jaxtyping` aliases; the outer shape is prepended to the inner one:

  ```python
  import jax
  from jaxtyping import Float

  Image = Float[jax.Array, "channels height width"]
  # Equivalent to Float[jax.Array, "batch channels height width"].
  BatchImage = Float[Image, "batch"]
  ```

- Duck-typed arrays also work if they expose `.shape` and `.dtype`.
- `typing.Any`, unions, type aliases, and bounded `TypeVar`s can all be used as the array type parameter when needed.

### Runtime Checking Patterns

- Prefer `@jaxtyped(typechecker=beartype)` for function-level checks.
- Use it on adapters, public tensor utilities, deserializers, boundary normalization functions, and tests.
- Do not decorate the hottest inner loops unless the runtime cost is acceptable.
- Avoid stringized annotations, and avoid `from __future__ import annotations` in code paths that depend on runtime inspection.
- If static typing and runtime truth diverge, validate at runtime first, then use a narrowly scoped `cast(...)` with a comment explaining why.

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    # Subtract the per-feature mean over the batch axis; the shape is unchanged.
    return x - x.mean(axis=0)
```

- For package-wide or test-only enforcement, use `jaxtyping.install_import_hook`.
- Pytest can enable the hook with `--jaxtyping-packages=foo,bar.baz,beartype.beartype` (the final entry names the typechecker) or with an equivalent `addopts` entry.
- In notebooks, load the IPython extension with `%load_ext jaxtyping`, then set `%jaxtyping.typechecker beartype.beartype`.
- With `jax.jit`, runtime shape checks run during tracing, so the compiled code carries no checking overhead.

### Advanced Notes

- `print_bindings` can help inspect axis bindings when debugging shape mismatches.
- `AbstractDtype` can define custom dtypes for duck-typed arrays when a project needs that level of control.
- For unusual edge cases or less common features, see the upstream docs: `https://github.com/patrick-kidger/jaxtyping/tree/main/docs`.

### Tooling Gotchas

- `pyright` and `mypy` usually treat `dtype[array, shape]` as just `array`; do not expect them to prove shape safety.
- Ruff or `flake8` may complain about shape strings in annotations. Disabling `F722` is the usual fix for multidimensional forms.
- For one-dimensional annotations that trigger undefined-name (`F821`) errors, prefix the shape with a space, as in `Float[np.ndarray, " channels"]`, so the linter reports `F722` instead and the same suppression covers it.
- Dataclass fields with stringified annotations can be skipped or mis-checked at runtime. Avoid them.

## Anti-Goals

- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.