name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.

Python Tensor Typing

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

Priority Order

  1. Explicit user instructions
  2. Existing repository tensor and verification conventions
  3. This skill

Only apply this skill when the task is numerical or tensor-heavy and the repository already uses jaxtyping and beartype, or the user explicitly asks for shape-typed numerics.

Before Applying This Skill

Check the local project first:

  • whether the task actually involves arrays, tensors, or numerical kernels
  • which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
  • whether jaxtyping and beartype are already present
  • what verification commands already exist

If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.
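The import-related checks in the list above can be approximated by scanning the repository's Python sources. The sketch below is a hypothetical helper (find_stack_imports and imported_top_level are illustrative names, not part of any library). It uses only the standard library ast module, skips a few common virtual-environment and vendored directories, and does not inspect dependency files such as pyproject.toml, so treat its answer as a first pass rather than a definitive check.

```python
import ast
from pathlib import Path

# Packages whose presence decides how this skill applies: array libraries plus
# the typing stack. This set is an illustrative assumption; extend as needed.
WATCHED = {"numpy", "torch", "jax", "tensorflow", "mlx", "jaxtyping", "beartype"}
# Directories that usually hold third-party code rather than project code.
SKIP_DIRS = {".git", ".venv", "venv", "site-packages", "node_modules"}


def imported_top_level(source: str) -> set[str]:
    """Return the top-level package names imported by one module's source."""
    names: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # Relative imports (level > 0) refer to the project itself; skip them.
            names.add(node.module.split(".")[0])
    return names


def find_stack_imports(root: str) -> set[str]:
    """Hypothetical helper: which WATCHED packages are imported under root."""
    found: set[str] = set()
    for path in Path(root).rglob("*.py"):
        if any(part in SKIP_DIRS for part in path.parts):
            continue
        try:
            found |= imported_top_level(path.read_text(encoding="utf-8"))
        except (SyntaxError, ValueError, OSError):
            continue  # Unparsable or unreadable file; this is only a heuristic.
    return found & WATCHED
```

For example, a result of {"numpy"} with neither jaxtyping nor beartype would suggest, under the rule above, leaving the typing stack out unless the user asks for it.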

Defaults

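  • Use DType[ArrayType, "axis names"], where DType is a jaxtyping dtype such as Float or Int and the string lists space-separated axis names, for example Float[np.ndarray, "batch channels"].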
  • Reuse axis names to express shared dimensions across arguments and returns.
  • Prefer reusable aliases for common tensor shapes.
  • Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
  • Use @jaxtyped(typechecker=beartype) on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
  • Avoid applying runtime checking blindly to hot inner loops.
  • Avoid from __future__ import annotations only in modules that rely on runtime annotation inspection; elsewhere it is acceptable.
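
```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Reusable alias: a 2-D float array with named "batch" and "channels" axes.
Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Illustrative body: standardize each channel across the batch. The shape
    # is unchanged, so the return annotation reuses the same axis names.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
```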

Read references/jaxtyping-summary.md when writing or reviewing array or tensor annotations.

Verification

Use the repository's existing verification workflow first.

If no local workflow exists and the repository is already aligned with this stack:

  1. run the configured type checker
  2. run the numerical test suite, or the full pytest suite if there is no dedicated numerical one
  3. run a module smoke test that exercises the typed tensor boundary when relevant

Anti-Goals

  • Do not introduce jaxtyping or beartype into non-numerical work just because this skill is loaded.
  • Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
  • Do not add runtime checking to hot loops unless the cost is acceptable.