# Typing prelude: a generic type variable, jaxtyping annotation markers and a
# NumPy array alias.

# --- Standard library -------------------------------------------------------
from typing import TypeVar

# --- Third-party ------------------------------------------------------------
import numpy as np
# NOTE(review): these markers are not used in the visible code; presumably they
# are imported so other modules can pull them from here. Verify before pruning.
from jaxtyping import (
    Bool,
    Float,
    Int,
    Num,
)

# Unconstrained generic type variable, registered under the name "T".
T = TypeVar("T")

# Bare alias of the concrete ndarray class. This is NOT numpy.typing.NDArray,
# which is a subscriptable generic alias. The name collides with that one, so
# confirm which one callers expect.
NDArray = np.ndarray