2026-03-13 14:54:25 +08:00


Jaxtyping Summary

Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.

Core syntax

  • Use DType[array_type, "shape"], where DType is a dtype wrapper such as Float, Int, Bool, or Shaped (any dtype).
  • Examples:
    • Float[np.ndarray, "batch channels"]
    • Int[np.ndarray, "persons"]
    • Shaped[ArrayLike, "batch time features"]
    • Float[Tensor, "... channels"]

Shape rules

  • Reuse an axis name to require equal sizes across values: if one argument is "batch time" and another is "time features", both "time" axes must match within a single checked call.
  • Use fixed integers for exact sizes: "3 3".
  • Use ... for zero or more anonymous axes.
  • Use *name for a named variadic axis.
  • Use #name when size 1 should also be accepted for broadcasting.
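  • Use _ (on its own, or as a prefix such as _time) for an axis that should not be checked at runtime; it then serves only as documentation.
  • Use name=value (e.g. "rows=4") to label a fixed size: the name is documentation only and is not bound as a variable, while the size itself is still checked.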

Array type guidance

  • Prefer concrete normalized types in core logic:
    • Float[np.ndarray, "..."]
    • Float[torch.Tensor, "..."]
    • Float[jax.Array, "..."]
  • Use Shaped[ArrayLike, "..."] or another broader input type only at ingestion boundaries.
  • Create aliases for repeated shapes instead of rewriting them in every signature:

```python
import numpy as np
from jaxtyping import Float

FramePoints = Float[np.ndarray, "frames keypoints dims"]
```

Runtime checking

  • Pair jaxtyping with beartype for runtime validation:

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    return x - x.mean(axis=0)
```

  • Apply the decorator at stable API boundaries and in tests, not to every function inside a hot loop: each checked call adds validation overhead.
  • Avoid from __future__ import annotations when relying on runtime checking.

Practical defaults

  • Prefer meaningful axis names like batch, frames, persons, keypoints, dims, channels.
  • Keep aliases near the module or domain where they are used.
  • If the static type checker and the runtime types disagree, validate at runtime first, then apply cast(...) at that narrow boundary with a comment explaining why it is safe.