---
name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.
---
# Python Tensor Typing

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

## Priority Order

1. Explicit user instructions
2. Existing repository tensor and verification conventions
3. This skill

Only apply this skill when the task is numerical or tensor-heavy and the repository already uses `jaxtyping` and `beartype`, or the user explicitly asks for shape-typed numerics.

## Before Applying This Skill

Check the local project first:

- whether the task actually involves arrays, tensors, or numerical kernels
- which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
- whether `jaxtyping` and `beartype` are already present
- what verification commands already exist

If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.

## Defaults

- Use `DType[ArrayType, "shape names"]`, where `DType` is a jaxtyping dtype category such as `Float`, `Int`, or `Bool`; for example, `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and return values.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Use `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
- Avoid applying runtime checking blindly to hot inner loops.
- Avoid `from __future__ import annotations` only in modules that rely on runtime annotation inspection.

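```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Reusable alias for a floating-point NumPy array with "batch" and "channels" axes.
Batch = Float[np.ndarray, "batch channels"]


@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Body elided in this example. Because the input and return share the alias,
    # the returned array must have the same batch and channels sizes as `x`.
    ...
```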

Read `references/jaxtyping-summary.md` when writing or reviewing array or tensor annotations.

## Verification

Use the repository's existing verification workflow first.

If no local workflow exists and the repository is already aligned with this stack:

1. run the configured type checker
2. run the numerical test suite or `pytest`
3. run a module smoke test that exercises the typed tensor boundary when relevant

## Anti-Goals

- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.