Merge AnyIO defaults and inline tensor docs
@@ -7,32 +7,13 @@ description: Use when working on tensor-heavy or numerical Python code in reposi

Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.

## Priority Order

1. Explicit user instructions
2. Existing repository tensor and verification conventions
3. This skill

Only apply this skill when the task is numerical or tensor-heavy and the repository already uses `jaxtyping` and `beartype`, or the user explicitly asks for shape-typed numerics.

## Before Applying This Skill

Check the local project first:

- whether the task actually involves arrays, tensors, or numerical kernels
- which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
- whether `jaxtyping` and `beartype` are already present
- what verification commands already exist

If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.

## Defaults

- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Use `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
- Prefer `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers by default so tensor shape or dtype mismatches fail early.
- Avoid applying runtime checking blindly to hot inner loops.
- Only avoid `from __future__ import annotations` in modules that rely on runtime annotation inspection.

@@ -49,17 +30,105 @@ def normalize(x: Batch) -> Batch:
    ...
```

Read `references/jaxtyping-summary.md` when writing or reviewing array or tensor annotations.
## Integrated Jaxtyping Reference

## Verification
Use this section as the built-in cheat sheet. Do not rely on a separate reference file.

Use the repository's existing verification workflow first.
### Mental Model

If no local workflow exists and the repository is already aligned with this stack:
- `jaxtyping` provides shape and dtype annotations plus runtime type checking for JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.
- The name is historical; it is not JAX-only.
- Static type checkers do not fully understand shape constraints. In practice `dtype[array, shape]` is mostly treated as just `array`, so use runtime checks to catch shape or dtype mistakes early.
- This skill defaults to `beartype` as the runtime checker.

1. run the configured type checker
2. run the numerical test suite or `pytest`
3. run a module smoke test that exercises the typed tensor boundary when relevant
### Core Syntax

- Annotate arrays as `DType[ArrayType, "shape"]`.
- Common dtype families:
  - `Shaped` for any dtype
  - `Bool`
  - `Num` for numeric tensors
  - `Real`, `Float`, `Complex`, `Int`, `UInt`
  - precision-specific forms like `Float32`, `Int64`, `UInt8`
- Representative examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`

### Shape Rules

- Use fixed integers for exact axes: `"3 3"`.
- Use named axes to enforce equality across values: `"batch time"` and `"time features"`.
- Use symbolic expressions when the return shape is derived from argument shapes: `"dim-1"`.
- Use `...` for anonymous zero-or-more axes. You may only use one variadic axis per annotation.
- Use `*name` for a named variadic axis.
- Use `#name` when broadcasting with size `1` should be accepted.
- Use `_name` or just `_` when an axis is documentation-only and should not be runtime-checked.
- Use `name=...` for documentation-only labels like `rows=4`.
- Use `""` for a scalar shape.
- Use `"..."` when you only want dtype checking and do not want to constrain shape.
- Prefer meaningful names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.

### Array And Alias Guidance

- Prefer concrete normalized array types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
  - `Float[tf.Tensor, "..."]`
  - `Float[mx.array, "..."]`
- Use broader input types like `Shaped[ArrayLike, "..."]` only at ingestion boundaries.
- Keep repeated shapes in local aliases instead of rewriting them in every signature.
- You can nest existing `jaxtyping` aliases:

```python
from jaxtyping import Float
import jax

Image = Float[jax.Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```

- Duck-typed arrays also work if they expose `.shape` and `.dtype`.
- `typing.Any`, unions, type aliases, and bounded `TypeVar`s can all be used as the array type parameter when needed.

### Runtime Checking Patterns

- Prefer `@jaxtyped(typechecker=beartype)` for function-level checks.
- Use it on adapters, public tensor utilities, deserializers, boundary normalization functions, and tests.
- Do not decorate the hottest inner loops unless the runtime cost is acceptable.
- Avoid stringized annotations and avoid `from __future__ import annotations` in code paths that depend on runtime inspection.
- If static typing and runtime truth diverge, validate at runtime first and then use a narrowly scoped commented `cast(...)`.

```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    return x - x.mean(axis=0)
```

- For package-wide or test-only enforcement, `jaxtyping.install_import_hook` exists.
- Pytest can enable it with `--jaxtyping-packages=foo,bar.baz,beartype.beartype` or equivalent `addopts`.
- In notebooks, you can load the IPython extension and set `%jaxtyping.typechecker beartype.beartype`.
- With `jax.jit`, runtime shape checks occur during tracing, so the compiled code does not keep the checking overhead.

### Advanced Notes

- `print_bindings` can help inspect axis bindings when debugging shape mismatches.
- `AbstractDtype` can define custom dtypes for duck-typed arrays when a project needs that level of control.
- For unusual edge cases or less common features, see the upstream docs: `https://github.com/patrick-kidger/jaxtyping/tree/main/docs`.

### Tooling Gotchas

- `pyright` and `mypy` usually treat `dtype[array, shape]` as just `array`; do not expect them to prove shape safety.
- Ruff or `flake8` may complain about shape strings in annotations. Disabling `F722` is the usual fix for multidimensional forms.
- For one-dimensional annotations that trigger undefined-name lint errors, prefixing the shape with a space can route the linter to `F722` instead.
- Dataclass fields with stringified annotations can be skipped or mis-checked at runtime. Avoid them.

## Anti-Goals

@@ -1,61 +0,0 @@
# Jaxtyping Summary

Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.

## Core syntax

- Use `DType[array_type, "shape"]`.
- Examples:
  - `Float[np.ndarray, "batch channels"]`
  - `Int[np.ndarray, "persons"]`
  - `Shaped[ArrayLike, "batch time features"]`
  - `Float[Tensor, "... channels"]`

## Shape rules

- Reuse names to enforce equality across values: `"batch time"` with `"time features"`.
- Use fixed integers for exact sizes: `"3 3"`.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic axis.
- Use `#name` when size `1` should also be accepted for broadcasting.
- Use `_` or `name=...` only for documentation when runtime enforcement is not wanted.

## Array type guidance

- Prefer concrete normalized types in core logic:
  - `Float[np.ndarray, "..."]`
  - `Float[torch.Tensor, "..."]`
  - `Float[jax.Array, "..."]`
- Use `Shaped[ArrayLike, "..."]` or another broader input type only at ingestion boundaries.
- Create aliases for repeated shapes instead of rewriting them in every signature.

```python
from jaxtyping import Float
import numpy as np

FramePoints = Float[np.ndarray, "frames keypoints dims"]
```

## Runtime checking

- Pair `jaxtyping` with `beartype` for runtime validation:

```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np


@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
    return x - x.mean(axis=0)
```

- Apply this at stable boundaries and in tests, not blindly on every hot loop.
- Avoid `from __future__ import annotations` when relying on runtime checking.

## Practical defaults

- Prefer meaningful axis names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.
- Keep aliases near the module or domain where they are used.
- If static typing and runtime truth diverge, validate at runtime first, then use a commented `cast(...)` at the narrow boundary.