| name | description |
|---|---|
| python-tensor-typing | Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested. |
Python Tensor Typing
Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.
Priority Order
- Explicit user instructions
- Existing repository tensor and verification conventions
- This skill
Only apply this skill when the task is numerical or tensor-heavy and the repository already uses jaxtyping and beartype, or the user explicitly asks for shape-typed numerics.
Before Applying This Skill
Check the local project first:
- whether the task actually involves arrays, tensors, or numerical kernels
- which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
- whether `jaxtyping` and `beartype` are already present
- what verification commands already exist
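One quick way to answer the third question is to scan the project's source for existing imports. The sketch below is a hypothetical helper (`imported_top_level_modules` and `repo_uses_tensor_typing` are illustrative names, not part of any library). It assumes the code lives in `.py` files under a single root and ignores dependency manifests such as `pyproject.toml`, which are worth checking as well.

```python
import ast
from pathlib import Path


def imported_top_level_modules(source: str) -> set[str]:
    """Return the top-level package names imported by one Python source string."""
    modules: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            # `import a.b as c` contributes "a"
            modules.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # `from a.b import c` contributes "a"; relative imports are skipped
            modules.add(node.module.split(".")[0])
    return modules


def repo_uses_tensor_typing(root: Path) -> bool:
    """Hypothetical check: True if .py files under `root` import both libraries."""
    found: set[str] = set()
    for path in root.rglob("*.py"):
        try:
            found |= imported_top_level_modules(path.read_text(encoding="utf-8"))
        except (SyntaxError, UnicodeDecodeError):
            continue  # skip files that cannot be read or parsed
    return {"jaxtyping", "beartype"} <= found
```

A positive result only shows the libraries are imported somewhere; it does not prove the conventions are applied consistently.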
If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.
Defaults
- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Use `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
- Avoid applying runtime checking blindly to hot inner loops.
- Avoid `from __future__ import annotations` only in modules that rely on runtime annotation inspection.
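```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Reusable alias: a floating-point ndarray with axes named "batch" and "channels".
Batch = Float[np.ndarray, "batch channels"]

@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Illustrative body: standardize each channel across the batch axis.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
```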
Read `references/jaxtyping-summary.md` when writing or reviewing array or tensor annotations.
Verification
Use the repository's existing verification workflow first.
If no local workflow exists and the repository is already aligned with this stack:
- run the configured type checker
- run the numerical test suite or `pytest`
- run a module smoke test that exercises the typed tensor boundary when relevant
Anti-Goals
- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.