Split Python style skill into focused modules

2026-03-16 14:18:05 +08:00
parent 0bb3ec31a4
commit 19b12dbd17
9 changed files with 260 additions and 115 deletions
+68
@@ -0,0 +1,68 @@
---
name: python-tensor-typing
description: Use when working on tensor-heavy or numerical Python code in repositories that already use or are explicitly standardizing on jaxtyping and beartype. Apply shape and dtype annotations plus boundary-focused runtime validation without introducing these tools to unrelated code unless requested.
---
# Python Tensor Typing
Use this skill for tensor-heavy or numerical code that benefits from explicit shape and dtype contracts. It should not leak into unrelated Python work.
## Priority Order
1. Explicit user instructions
2. Existing repository tensor and verification conventions
3. This skill

Only apply this skill when the task is numerical or tensor-heavy and the repository already uses `jaxtyping` and `beartype`, or the user explicitly asks for shape-typed numerics.
## Before Applying This Skill
Check the local project first:
- whether the task actually involves arrays, tensors, or numerical kernels
- which array types are already used: NumPy, PyTorch, JAX, TensorFlow, MLX
- whether `jaxtyping` and `beartype` are already present
- what verification commands already exist

If the repository does not already use this stack and the task is not explicitly about numerical typing, do not introduce it.
## Defaults
- Use `DType[ArrayType, "shape names"]`, for example `Float[np.ndarray, "batch channels"]`.
- Reuse axis names to express shared dimensions across arguments and returns.
- Prefer reusable aliases for common tensor shapes.
- Prefer concrete array types after normalization; use broader input types only at ingestion boundaries.
- Use `@jaxtyped(typechecker=beartype)` on stable boundaries and test-targeted helpers when the runtime cost is acceptable.
- Avoid applying runtime checking blindly to hot inner loops.
- Avoid `from __future__ import annotations` only in modules that rely on runtime annotation inspection; elsewhere it is fine.
```python
import numpy as np
from beartype import beartype
from jaxtyping import Float, jaxtyped

# Reusable alias: both axis names are checked wherever the alias appears.
Batch = Float[np.ndarray, "batch channels"]

@jaxtyped(typechecker=beartype)
def normalize(x: Batch) -> Batch:
    # Zero-mean, unit-variance per channel; the output shape must match the input.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
```
Read `references/jaxtyping-summary.md` when writing or reviewing array or tensor annotations.
## Verification
Use the repository's existing verification workflow first.
If no local workflow exists and the repository is already aligned with this stack:
1. run the configured type checker
2. run the numerical test suite or `pytest`
3. run a module smoke test that exercises the typed tensor boundary when relevant
## Anti-Goals
- Do not introduce `jaxtyping` or `beartype` into non-numerical work just because this skill is loaded.
- Do not annotate every local scratch tensor when the extra ceremony does not improve clarity.
- Do not add runtime checking to hot loops unless the cost is acceptable.
+4
@@ -0,0 +1,4 @@
interface:
display_name: "Python Tensor Typing"
short_description: "Tensor typing with jaxtyping"
default_prompt: "Apply jaxtyping and beartype patterns for tensor-heavy Python code while preserving repository conventions. Only use this skill when the repo already uses these tools or the user explicitly asks to standardize on them."
@@ -0,0 +1,61 @@
# Jaxtyping Summary
Use this reference when the task involves NumPy, JAX, PyTorch, TensorFlow, MLX, or array-like inputs that should carry shape and dtype information.
## Core syntax
- Use `DType[array_type, "shape"]`.
- Examples:
- `Float[np.ndarray, "batch channels"]`
- `Int[np.ndarray, "persons"]`
- `Shaped[ArrayLike, "batch time features"]`
- `Float[Tensor, "... channels"]`
## Shape rules
- Reuse names to enforce equality across values: `"batch time"` with `"time features"`.
- Use fixed integers for exact sizes: `"3 3"`.
- Use `...` for zero or more anonymous axes.
- Use `*name` for a named variadic axis.
- Use `#name` when size `1` should also be accepted for broadcasting.
- Use `_` or `name=...` only for documentation when runtime enforcement is not wanted.
## Array type guidance
- Prefer concrete normalized types in core logic:
- `Float[np.ndarray, "..."]`
- `Float[torch.Tensor, "..."]`
- `Float[jax.Array, "..."]`
- Use `Shaped[ArrayLike, "..."]` or another broader input type only at ingestion boundaries.
- Create aliases for repeated shapes instead of rewriting them in every signature.
```python
from jaxtyping import Float
import numpy as np
FramePoints = Float[np.ndarray, "frames keypoints dims"]
```
## Runtime checking
- Pair `jaxtyping` with `beartype` for runtime validation:
```python
from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np
@jaxtyped(typechecker=beartype)
def center(x: Float[np.ndarray, "batch dims"]) -> Float[np.ndarray, "batch dims"]:
return x - x.mean(axis=0)
```
- Apply this at stable boundaries and in tests, not blindly on every hot loop.
- Avoid `from __future__ import annotations` when relying on runtime checking.
## Practical defaults
- Prefer meaningful axis names like `batch`, `frames`, `persons`, `keypoints`, `dims`, `channels`.
- Keep aliases near the module or domain where they are used.
- If static typing and runtime truth diverge, validate at runtime first, then use a commented `cast(...)` at the narrow boundary.