b24644f16e
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
271 lines
8.6 KiB
Python
271 lines
8.6 KiB
Python
from collections.abc import Callable
|
|
import math
|
|
from typing import cast
|
|
|
|
import cv2
|
|
from beartype import beartype
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
|
|
# Silhouette geometry. A cropped mask is resized to SIL_HEIGHT rows, centred
# in a SIL_FULL_WIDTH-wide canvas, and then SIDE_CUT columns are trimmed from
# each side, leaving SIL_WIDTH columns.
SIL_HEIGHT = 64
SIL_FULL_WIDTH = 64
SIDE_CUT = 10
# Derived rather than hard-coded so it cannot drift out of sync with
# SIL_FULL_WIDTH / SIDE_CUT (mask_to_silhouette rejects any other width).
SIL_WIDTH = SIL_FULL_WIDTH - 2 * SIDE_CUT

# Minimum number of foreground pixels a mask needs in order to be accepted.
MIN_MASK_AREA = 500
|
|
|
|
# Type of the decorator obtained by calling ``jaxtyping.jaxtyped(...)``:
# it maps a callable to a callable.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]

# ``jaxtyping.jaxtyped`` viewed as a decorator factory taking arbitrary
# arguments (used in this file as ``@jaxtyped(typechecker=beartype)``).
JaxtypedFactory = Callable[..., JaxtypedDecorator]

# Same object as ``jaxtyping.jaxtyped``; ``cast`` does nothing at runtime.
# NOTE(review): presumably re-typed this way to give static type checkers an
# explicit signature for the decorator -- confirm against the project's
# type-checking configuration.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Shorthand NumPy array types. They are not referenced in the code visible
# here; presumably used by other modules -- verify before removing.
UInt8Array = NDArray[np.uint8]

Float32Array = NDArray[np.float32]
|
|
|
|
|
|
def _read_attr(container: object, key: str) -> object | None:
|
|
if isinstance(container, dict):
|
|
dict_obj = cast(dict[object, object], container)
|
|
return dict_obj.get(key)
|
|
try:
|
|
return cast(object, object.__getattribute__(container, key))
|
|
except AttributeError:
|
|
return None
|
|
|
|
|
|
def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """Best-effort conversion of an array-like object to a NumPy array.

    A NumPy array is returned unchanged (same object, no copy). For anything
    else, a callable ``detach`` attribute is called, then a callable ``cpu``
    attribute, each only when present. This is presumably meant for
    torch-style tensors -- confirm with callers. Next, a callable ``numpy``
    attribute is tried; if it yields an ``ndarray``, that array is returned.
    Otherwise the object obtained after the detach/cpu steps is passed to
    ``np.asarray``.
    """
    if isinstance(value, np.ndarray):
        return value

    staged: object = value
    for step_name in ("detach", "cpu"):
        step = _read_attr(staged, step_name)
        if callable(step):
            staged = cast(Callable[[], object], step)()

    to_numpy = _read_attr(staged, "numpy")
    converted = cast(Callable[[], object], to_numpy)() if callable(to_numpy) else None
    if isinstance(converted, np.ndarray):
        return converted
    # Fall back to the pre-``numpy()`` object, not to ``converted``.
    return cast(NDArray[np.generic], np.asarray(staged))
|
|
|
|
|
|
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
    """Return the tight ``(x1, y1, x2, y2)`` box around the mask's foreground.

    The mask is first converted to ``uint8``, and any non-zero pixel counts as
    foreground. ``x2``/``y2`` are exclusive (one past the last foreground
    column/row), so ``mask[y1:y2, x1:x2]`` selects the box. Returns ``None``
    when there is no foreground.
    """
    foreground = np.asarray(mask, dtype=np.uint8) > 0

    # Indices of rows / columns that contain at least one foreground pixel.
    row_hits = np.flatnonzero(foreground.any(axis=1))
    if row_hits.size == 0:
        return None
    col_hits = np.flatnonzero(foreground.any(axis=0))

    return (
        int(col_hits[0]),
        int(row_hits[0]),
        int(col_hits[-1]) + 1,
        int(row_hits[-1]) + 1,
    )
|
|
|
|
|
|
def _sanitize_bbox(
|
|
bbox: tuple[int, int, int, int], height: int, width: int
|
|
) -> tuple[int, int, int, int] | None:
|
|
x1, y1, x2, y2 = bbox
|
|
x1c = max(0, min(int(x1), width - 1))
|
|
y1c = max(0, min(int(y1), height - 1))
|
|
x2c = max(0, min(int(x2), width))
|
|
y2c = max(0, min(int(y2), height))
|
|
if x2c <= x1c or y2c <= y1c:
|
|
return None
|
|
return (x1c, y1c, x2c, y2c)
|
|
|
|
|
|
def _box_values_from_result(
    result: object,
) -> list[tuple[float, float, float, float]] | None:
    """Read per-detection ``(x1, y1, x2, y2)`` boxes from ``result.boxes.xyxy``.

    Accepts either a single box (1-D, at least 4 values, of which only the
    first 4 are used) or a 2-D array with at least 4 columns (only the first
    4 columns are used). Returns ``None`` when ``boxes``/``xyxy`` is missing
    or has any other shape. A 2-D array with zero rows yields an empty list,
    not ``None``.
    """
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is None:
        return None
    xyxy_obj = _read_attr(boxes_obj, "xyxy")
    if xyxy_obj is None:
        return None

    xyxy = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
    if xyxy.ndim == 1:
        if int(xyxy.size) < 4:
            return None
        xyxy = xyxy[:4].reshape(1, 4)
    elif xyxy.ndim != 2 or int(xyxy.shape[1]) < 4:
        return None

    rows = cast(
        list[list[float]], np.asarray(xyxy[:, :4], dtype=np.float64).tolist()
    )
    return [(row[0], row[1], row[2], row[3]) for row in rows]


@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
    """Pick the largest person mask from a segmentation result.

    ``result`` is read duck-typed (attributes or dict keys):

    * ``result.masks.data`` holds the masks, either 2-D (one mask) or 3-D
      (masks stacked along the first axis), in any form ``_to_numpy_array``
      accepts.
    * ``result.boxes.xyxy`` optionally holds one ``(x1, y1, x2, y2)`` box per
      mask.

    Each mask is binarised at 0.5 into a 0/255 ``uint8`` image. Masks with
    fewer than ``min_area`` foreground pixels are ignored. Box ``i`` is paired
    with mask ``i``: it is rounded outward, then clipped to the mask. If the
    box does not survive clipping, the box is recomputed from the mask's own
    pixels. When boxes are present but fewer than the masks, masks without a
    matching box are skipped.

    NOTE(review): boxes are used directly as pixel coordinates inside the
    mask, so both are assumed to share one resolution/coordinate frame --
    confirm against the detector configuration.

    Args:
        result: Segmentation result exposing ``masks`` and optionally
            ``boxes``.
        min_area: Minimum foreground pixel count for a mask to be eligible.

    Returns:
        ``(mask, (x1, y1, x2, y2))`` for the eligible mask with the largest
        foreground area. On ties, the first such mask wins. Returns ``None``
        if masks are missing, are not 2-D/3-D, or none qualifies.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None
    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None

    masks_float = np.asarray(_to_numpy_array(masks_data_obj), dtype=np.float32)
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
    if mask_count <= 0:
        return None

    box_values = _box_values_from_result(result)

    best_area = -1
    best_mask: UInt8[ndarray, "h w"] | None = None
    best_bbox: tuple[int, int, int, int] | None = None

    for idx in range(mask_count):
        # masks_float is 3-D here, so each slice is a 2-D mask.
        mask_float = np.asarray(masks_float[idx], dtype=np.float32)
        mask_u8 = cast(
            UInt8[ndarray, "h w"],
            np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8),
        )

        area = int(np.count_nonzero(mask_u8))
        # A mask that is too small, or that cannot beat the current best, is
        # never selected, so skip computing its bbox altogether.
        if area < min_area or area <= best_area:
            continue

        bbox: tuple[int, int, int, int] | None = None
        if box_values is not None:
            if idx >= len(box_values):
                continue
            bx1, by1, bx2, by2 = box_values[idx]
            mask_h, mask_w = cast(tuple[int, int], mask_u8.shape)
            # Round outward (floor the start, ceil the end) before clipping.
            bbox = _sanitize_bbox(
                (math.floor(bx1), math.floor(by1), math.ceil(bx2), math.ceil(by2)),
                int(mask_h),
                int(mask_w),
            )
        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)
        if bbox is None:
            continue

        best_area, best_mask, best_bbox = area, mask_u8, bbox

    if best_mask is None or best_bbox is None:
        return None
    return best_mask, best_bbox
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: tuple[int, int, int, int],
    min_area: int = MIN_MASK_AREA,
) -> Float[ndarray, "64 44"] | None:
    """Convert a person mask into a normalised ``SIL_HEIGHT x SIL_WIDTH`` silhouette.

    Steps:
      1. Binarise ``mask`` (any non-zero value -> 255). Reject it if the whole
         mask has fewer than ``min_area`` foreground pixels.
      2. Clip ``bbox`` (``x1, y1, x2, y2`` with exclusive ends) to the mask
         and crop to it.
      3. Trim empty rows above and below the foreground. Columns are kept as
         cropped.
      4. Resize to ``SIL_HEIGHT`` rows with bicubic interpolation, preserving
         the aspect ratio. The new width is truncated and is at least 1.
      5. Centre horizontally in a ``SIL_FULL_WIDTH``-wide canvas: centre-crop
         if wider, zero-pad otherwise. Then drop ``SIDE_CUT`` columns from
         each side.
      6. Scale to ``float32`` in ``[0, 1]``.

    NOTE(review): centring uses the geometric centre of the crop, not the
    foreground's centre of mass -- confirm this matches the preprocessing the
    downstream model was trained with.

    Args:
        mask: 2-D mask; any non-zero value counts as foreground.
        bbox: Region of interest within ``mask``.
        min_area: Minimum foreground pixel count over the whole mask. Defaults
            to ``MIN_MASK_AREA``. If a different threshold was passed to
            ``frame_to_person_mask``, pass the same value here; otherwise its
            masks may be rejected.

    Returns:
        A ``(SIL_HEIGHT, SIL_WIDTH)`` float32 array, or ``None`` in any of
        these cases: the mask is too small, the box does not intersect the
        mask, or the crop contains no foreground.
    """
    mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    if int(np.count_nonzero(mask_u8)) < min_area:
        return None

    mask_h, mask_w = cast(tuple[int, int], mask_u8.shape)
    bbox_clipped = _sanitize_bbox(bbox, int(mask_h), int(mask_w))
    if bbox_clipped is None:
        return None
    x1, y1, x2, y2 = bbox_clipped
    cropped = mask_u8[y1:y2, x1:x2]
    if cropped.size == 0:
        return None

    # Keep only the rows between the first and last row containing foreground.
    occupied_rows = np.flatnonzero(np.sum(cropped, axis=1, dtype=np.int64) > 0)
    if occupied_rows.size == 0:
        return None
    tightened = cropped[int(occupied_rows[0]) : int(occupied_rows[-1]) + 1, :]
    tight_h, tight_w = cast(tuple[int, int], tightened.shape)
    if tight_h <= 0 or tight_w <= 0:
        return None

    # Height drives the scale factor; width follows the aspect ratio.
    resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )

    if resized_w >= SIL_FULL_WIDTH:
        start = (resized_w - SIL_FULL_WIDTH) // 2
        normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
    else:
        pad_left = (SIL_FULL_WIDTH - resized_w) // 2
        pad_right = SIL_FULL_WIDTH - resized_w - pad_left
        normalized_64 = np.pad(
            resized,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        )

    silhouette = np.asarray(
        normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    # Defensive: only reachable if the geometry constants disagree.
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None

    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)
|