feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
This commit is contained in:
@@ -0,0 +1,270 @@
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
from typing import cast
|
||||
|
||||
import cv2
|
||||
from beartype import beartype
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
|
||||
SIL_HEIGHT = 64
|
||||
SIL_WIDTH = 44
|
||||
SIL_FULL_WIDTH = 64
|
||||
SIDE_CUT = 10
|
||||
MIN_MASK_AREA = 500
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
UInt8Array = NDArray[np.uint8]
|
||||
Float32Array = NDArray[np.float32]
|
||||
|
||||
|
||||
def _read_attr(container: object, key: str) -> object | None:
|
||||
if isinstance(container, dict):
|
||||
dict_obj = cast(dict[object, object], container)
|
||||
return dict_obj.get(key)
|
||||
try:
|
||||
return cast(object, object.__getattribute__(container, key))
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy_array(value: object) -> NDArray[np.generic]:
|
||||
current: object = value
|
||||
if isinstance(current, np.ndarray):
|
||||
return current
|
||||
|
||||
detach_obj = _read_attr(current, "detach")
|
||||
if callable(detach_obj):
|
||||
detach_fn = cast(Callable[[], object], detach_obj)
|
||||
current = detach_fn()
|
||||
|
||||
cpu_obj = _read_attr(current, "cpu")
|
||||
if callable(cpu_obj):
|
||||
cpu_fn = cast(Callable[[], object], cpu_obj)
|
||||
current = cpu_fn()
|
||||
|
||||
numpy_obj = _read_attr(current, "numpy")
|
||||
if callable(numpy_obj):
|
||||
numpy_fn = cast(Callable[[], object], numpy_obj)
|
||||
as_numpy = numpy_fn()
|
||||
if isinstance(as_numpy, np.ndarray):
|
||||
return as_numpy
|
||||
|
||||
return cast(NDArray[np.generic], np.asarray(current))
|
||||
|
||||
|
||||
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
|
||||
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
||||
coords = np.argwhere(mask_u8 > 0)
|
||||
if int(coords.size) == 0:
|
||||
return None
|
||||
|
||||
ys = coords[:, 0].astype(np.int64)
|
||||
xs = coords[:, 1].astype(np.int64)
|
||||
x1 = int(np.min(xs))
|
||||
x2 = int(np.max(xs)) + 1
|
||||
y1 = int(np.min(ys))
|
||||
y2 = int(np.max(ys)) + 1
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return None
|
||||
return (x1, y1, x2, y2)
|
||||
|
||||
|
||||
def _sanitize_bbox(
|
||||
bbox: tuple[int, int, int, int], height: int, width: int
|
||||
) -> tuple[int, int, int, int] | None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
x1c = max(0, min(int(x1), width - 1))
|
||||
y1c = max(0, min(int(y1), height - 1))
|
||||
x2c = max(0, min(int(x2), width))
|
||||
y2c = max(0, min(int(y2), height))
|
||||
if x2c <= x1c or y2c <= y1c:
|
||||
return None
|
||||
return (x1c, y1c, x2c, y2c)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def frame_to_person_mask(
|
||||
result: object, min_area: int = MIN_MASK_AREA
|
||||
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
|
||||
masks_obj = _read_attr(result, "masks")
|
||||
if masks_obj is None:
|
||||
return None
|
||||
|
||||
masks_data_obj = _read_attr(masks_obj, "data")
|
||||
if masks_data_obj is None:
|
||||
return None
|
||||
|
||||
masks_raw = _to_numpy_array(masks_data_obj)
|
||||
masks_float = np.asarray(masks_raw, dtype=np.float32)
|
||||
if masks_float.ndim == 2:
|
||||
masks_float = masks_float[np.newaxis, ...]
|
||||
if masks_float.ndim != 3:
|
||||
return None
|
||||
mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
|
||||
if mask_count <= 0:
|
||||
return None
|
||||
|
||||
box_values: list[tuple[float, float, float, float]] | None = None
|
||||
boxes_obj = _read_attr(result, "boxes")
|
||||
if boxes_obj is not None:
|
||||
xyxy_obj = _read_attr(boxes_obj, "xyxy")
|
||||
if xyxy_obj is not None:
|
||||
xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
|
||||
if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
|
||||
xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
|
||||
x1f = cast(np.float64, xyxy_2d[0, 0])
|
||||
y1f = cast(np.float64, xyxy_2d[0, 1])
|
||||
x2f = cast(np.float64, xyxy_2d[0, 2])
|
||||
y2f = cast(np.float64, xyxy_2d[0, 3])
|
||||
box_values = [
|
||||
(
|
||||
float(x1f),
|
||||
float(y1f),
|
||||
float(x2f),
|
||||
float(y2f),
|
||||
)
|
||||
]
|
||||
elif xyxy_raw.ndim == 2:
|
||||
shape_2d = cast(tuple[int, int], xyxy_raw.shape)
|
||||
if int(shape_2d[1]) >= 4:
|
||||
xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
|
||||
box_values = []
|
||||
for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
|
||||
x1f = cast(np.float64, xyxy_2d[row_idx, 0])
|
||||
y1f = cast(np.float64, xyxy_2d[row_idx, 1])
|
||||
x2f = cast(np.float64, xyxy_2d[row_idx, 2])
|
||||
y2f = cast(np.float64, xyxy_2d[row_idx, 3])
|
||||
box_values.append(
|
||||
(
|
||||
float(x1f),
|
||||
float(y1f),
|
||||
float(x2f),
|
||||
float(y2f),
|
||||
)
|
||||
)
|
||||
|
||||
best_area = -1
|
||||
best_mask: UInt8[ndarray, "h w"] | None = None
|
||||
best_bbox: tuple[int, int, int, int] | None = None
|
||||
|
||||
for idx in range(mask_count):
|
||||
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
|
||||
if mask_float.ndim != 2:
|
||||
continue
|
||||
mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)
|
||||
|
||||
area = int(np.count_nonzero(mask_u8))
|
||||
if area < min_area:
|
||||
continue
|
||||
|
||||
bbox: tuple[int, int, int, int] | None = None
|
||||
shape_2d = cast(tuple[int, int], mask_binary.shape)
|
||||
h = int(shape_2d[0])
|
||||
w = int(shape_2d[1])
|
||||
if box_values is not None:
|
||||
box_count = len(box_values)
|
||||
if idx >= box_count:
|
||||
continue
|
||||
row0, row1, row2, row3 = box_values[idx]
|
||||
bbox_candidate = (
|
||||
int(math.floor(row0)),
|
||||
int(math.floor(row1)),
|
||||
int(math.ceil(row2)),
|
||||
int(math.ceil(row3)),
|
||||
)
|
||||
bbox = _sanitize_bbox(bbox_candidate, h, w)
|
||||
|
||||
if bbox is None:
|
||||
bbox = _bbox_from_mask(mask_u8)
|
||||
|
||||
if bbox is None:
|
||||
continue
|
||||
|
||||
if area > best_area:
|
||||
best_area = area
|
||||
best_mask = mask_u8
|
||||
best_bbox = bbox
|
||||
|
||||
if best_mask is None or best_bbox is None:
|
||||
return None
|
||||
|
||||
return best_mask, best_bbox
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def mask_to_silhouette(
|
||||
mask: UInt8[ndarray, "h w"],
|
||||
bbox: tuple[int, int, int, int],
|
||||
) -> Float[ndarray, "64 44"] | None:
|
||||
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
||||
return None
|
||||
|
||||
mask_shape = cast(tuple[int, int], mask_u8.shape)
|
||||
h = int(mask_shape[0])
|
||||
w = int(mask_shape[1])
|
||||
bbox_sanitized = _sanitize_bbox(bbox, h, w)
|
||||
if bbox_sanitized is None:
|
||||
return None
|
||||
|
||||
x1, y1, x2, y2 = bbox_sanitized
|
||||
cropped = mask_u8[y1:y2, x1:x2]
|
||||
if cropped.size == 0:
|
||||
return None
|
||||
|
||||
cropped_u8 = np.asarray(cropped, dtype=np.uint8)
|
||||
row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64)
|
||||
row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64)
|
||||
if int(row_nonzero.size) == 0:
|
||||
return None
|
||||
top = int(cast(np.int64, row_nonzero[0]))
|
||||
bottom = int(cast(np.int64, row_nonzero[-1])) + 1
|
||||
tightened = cropped[top:bottom, :]
|
||||
if tightened.size == 0:
|
||||
return None
|
||||
|
||||
tight_shape = cast(tuple[int, int], tightened.shape)
|
||||
tight_h = int(tight_shape[0])
|
||||
tight_w = int(tight_shape[1])
|
||||
if tight_h <= 0 or tight_w <= 0:
|
||||
return None
|
||||
|
||||
resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
|
||||
resized = np.asarray(
|
||||
cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
if resized_w >= SIL_FULL_WIDTH:
|
||||
start = (resized_w - SIL_FULL_WIDTH) // 2
|
||||
normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
|
||||
else:
|
||||
pad_left = (SIL_FULL_WIDTH - resized_w) // 2
|
||||
pad_right = SIL_FULL_WIDTH - resized_w - pad_left
|
||||
normalized_64 = np.pad(
|
||||
resized,
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
|
||||
silhouette = np.asarray(
|
||||
normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
|
||||
)
|
||||
if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
|
||||
return None
|
||||
|
||||
silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
|
||||
np.float32
|
||||
)
|
||||
return cast(Float[ndarray, "64 44"], silhouette_norm)
|
||||
Reference in New Issue
Block a user