---
name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.
---

# Python Tensor Typing

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

## Priority Order

1. Explicit user instructions
2. Existing repository tensor and verification conventions
3. This skill

Only apply this skill when the task is numerical or tensor-heavy and the repository already uses `jaxtyping` and `beartype`, or the user explicitly asks for shape-typed numerics.

## Before Applying This Skill

Check the local project first:

- whether the task actually involves arrays, tensors, or numerical kernels
- which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
- whether `jaxtyping` and `beartype` are already present
- what verification commands already exist

If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.

## Defaults

- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Use `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
- Avoid applying runtime checking blindly to hot inner loops.
- Only avoid `from __future__ import annotations` in modules that rely on runtime annotation inspection.

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch: ...
```

Read `references/jaxtyping-summary.md` when writing or reviewing array or tensor annotations.

## Verification

Use the repository's existing verification workflow first.

If no local workflow exists and the repository is already aligned with this stack:

1. run the configured type checker
2. run the numerical test suite or `pytest`
3. run a module smoke test that exercises the typed tensor boundary when relevant

## Anti-Goals

- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.