---
name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.
---

Python Tensor Typing

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

Defaults

  • Use DType[ArrayType, "shape names"], for example Float[np.ndarray, "batch channels"].
  • Reuse axis names to express shared dimensions across arguments and returns.
  • Prefer reusable aliases for common tensor shapes.
  • Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
  • By default, put @jaxtyped(typechecker=beartype) on stable boundaries and test-targeted helpers so tensor shape or dtype mismatches fail early.
  • Avoid applying runtime checking blindly to hot inner loops.
  • Avoid from __future__ import annotations in modules whose annotations are inspected at runtime (for example, modules using @jaxtyped). Other modules can keep it.
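```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Reusable alias: a 2-D float array with named "batch" and "channels" axes.
Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Standardize each channel across the batch; the output keeps x's shape.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
```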

Integrated Jaxtyping Reference

Use this section as the built-in cheat sheet. Do not rely on a separate reference file.

Mental Model

  • jaxtyping provides shape and dtype annotations plus runtime type checking for JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.
  • The name is historical; it is not JAX-only.
  • Static type checkers do not check shape constraints. In practice they treat dtype[array, shape] as plain array, so rely on runtime checks to catch shape or dtype mistakes early.
  • This skill defaults to beartype as the runtime checker.

Core Syntax

  • Annotate arrays as DType[ArrayType, "shape"].
  • Common dtype families:
    • Shaped for any dtype
    • Bool
    • Num for numeric tensors
    • Real, Float, Complex, Int, UInt
    • precision-specific forms like Float32, Int64, UInt8
  • Representative examples:
    • Float[np.ndarray, "batch channels"]
    • Int[np.ndarray, "persons"]
    • Shaped[ArrayLike, "batch time features"]
    • Float[Tensor, "... channels"]

Shape Rules

  • Use fixed integers for exact axes: "3 3".
  • Use named axes to enforce equality across values: "batch time" and "time features".
  • Use symbolic expressions when the return shape is derived from argument shapes: "dim-1".
  • Use ... for zero or more anonymous axes.
  • Use *name for zero or more named axes; every occurrence of *name must match the same sequence of sizes.
  • Each annotation may contain at most one variadic axis (... or *name).
  • Use #name when broadcasting with size 1 should be accepted.
  • Use _name or just _ when an axis is documentation-only and should not be runtime-checked.
  • Prefix an axis with name=, for example rows=4, to attach a documentation-only label. The label itself is not runtime-checked.
  • Use "" for a scalar shape.
  • Use "..." when you only want dtype checking and do not want to constrain shape.
  • Prefer meaningful names like batch, frames, persons, keypoints, dims, channels.

Array And Alias Guidance

  • Prefer concrete normalized array types in core logic:
    • Float[np.ndarray, "..."]
    • Float[torch.Tensor, "..."]
    • Float[jax.Array, "..."]
    • Float[tf.Tensor, "..."]
    • Float[mx.array, "..."]
  • Use broader input types like Shaped[ArrayLike, "..."] only at ingestion boundaries.
  • Keep repeated shapes in local aliases instead of rewriting them in every signature.
  • You can nest existing jaxtyping aliases:
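```python
import jax
from jaxtyping import Float

Image = Float[jax.Array, "channels height width"]
# Nesting prepends axes: equivalent to Float[jax.Array, "batch channels height width"].
BatchImage = Float[Image, "batch"]
```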
  • Duck-typed arrays also work if they expose .shape and .dtype.
  • typing.Any, unions, type aliases, and bounded TypeVars can all be used as the array type parameter when needed.

Runtime Checking Patterns

  • Prefer @jaxtyped(typechecker=beartype) for function-level checks.
  • Use it on adapters, public tensor utilities, deserializers, boundary normalization functions, and tests.
  • Do not decorate the hottest inner loops unless the runtime cost is acceptable.
  • Avoid stringized annotations and avoid from __future__ import annotations in code paths that depend on runtime inspection.
  • If static typing and runtime truth diverge, validate at runtime first. Then narrow with a single, narrowly scoped cast(...) and a comment explaining why it is safe.
```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    # Subtract the per-dimension mean; broadcasting keeps the (batch, dims) shape.
    return x - x.mean(axis=0)
```
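  • For package-wide or test-only enforcement, jaxtyping.install_import_hook decorates every function and dataclass in the named modules as they are imported. Use it as a context manager around the import.
  • Under pytest, pass --jaxtyping-packages=foo,bar.baz,beartype.beartype (the packages to instrument, followed by the typechecker), or put the same flag in addopts.
  • In notebooks, load the IPython extension with %load_ext jaxtyping, then set %jaxtyping.typechecker beartype.beartype.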
  • With jax.jit, runtime shape checks occur during tracing, so the compiled code does not keep the checking overhead.

Advanced Notes

  • jaxtyping.print_bindings(), called inside a @jaxtyped function, prints the current axis bindings, which helps when debugging shape mismatches.
  • AbstractDtype can define custom dtypes for duck-typed arrays when a project needs that level of control.
  • For unusual edge cases or less common features, see the upstream docs: https://github.com/patrick-kidger/jaxtyping/tree/main/docs.

Tooling Gotchas

  • pyright and mypy usually treat dtype[array, shape] as just array; do not expect them to prove shape safety.
  • Ruff or flake8 may report F722 (syntax error in forward annotation) for shape strings that contain spaces. Disabling F722 is the usual fix for multidimensional forms.
  • A one-axis string such as "n" can instead trigger F821 (undefined name). Prefixing the shape with a space (" n") turns that into F722, which the same ignore then covers.
  • Dataclass fields with stringified annotations can be skipped or mis-checked at runtime. Avoid them.

Anti-Goals

  • Do not introduce jaxtyping or beartype into non-numerical work just because this skill is loaded.
  • Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
  • Do not add runtime checking to hot loops unless the cost is acceptable.