fix(pipeline): enhance silhouette selection with structured output
This commit is contained in:
+39
-14
@@ -5,7 +5,7 @@ from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
|
||||
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
@@ -75,6 +75,24 @@ class _TrackCallable(Protocol):
|
||||
) -> object: ...
|
||||
|
||||
|
||||
class _SelectedSilhouette(TypedDict):
    """Structured result of picking one person's silhouette from detector outputs.

    Keys:
        silhouette: Normalized `(64, 44)` silhouette tensor that is fed into ScoNet.
        mask_raw: Binary person mask at full resolution, in mask/image space.
        bbox_frame: `(x1, y1, x2, y2)` person box in frame coordinates; used
            for visualization.
        bbox_mask: `(x1, y1, x2, y2)` person box in mask coordinates; used
            for cropping.
        track_id: Tracking ID reported by the detector; the fallback path
            uses `0`.
    """

    silhouette: Float[ndarray, "64 44"]
    mask_raw: UInt8[ndarray, "h w"]
    bbox_frame: BBoxXYXY
    bbox_mask: BBoxXYXY
    track_id: int
|
||||
|
||||
|
||||
class _FramePacer:
|
||||
_interval_ns: int
|
||||
_next_emit_ns: int | None
|
||||
@@ -206,16 +224,7 @@ class ScoliosisPipeline:
|
||||
def _select_silhouette(
|
||||
self,
|
||||
result: _DetectionResultsLike,
|
||||
) -> (
|
||||
tuple[
|
||||
Float[ndarray, "64 44"],
|
||||
UInt8[ndarray, "h w"],
|
||||
BBoxXYXY,
|
||||
BBoxXYXY,
|
||||
int,
|
||||
]
|
||||
| None
|
||||
):
|
||||
) -> _SelectedSilhouette | None:
|
||||
selected = select_person(result)
|
||||
if selected is not None:
|
||||
mask_raw, bbox_mask, bbox_frame, track_id = selected
|
||||
@@ -224,7 +233,13 @@ class ScoliosisPipeline:
|
||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||
)
|
||||
if silhouette is not None:
|
||||
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
|
||||
return {
|
||||
"silhouette": silhouette,
|
||||
"mask_raw": mask_raw,
|
||||
"bbox_frame": bbox_frame,
|
||||
"bbox_mask": bbox_mask,
|
||||
"track_id": int(track_id),
|
||||
}
|
||||
|
||||
fallback = cast(
|
||||
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||
@@ -266,7 +281,13 @@ class ScoliosisPipeline:
|
||||
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||
bbox_frame = bbox_mask
|
||||
# For fallback case, mask_raw is the same as mask_u8
|
||||
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
|
||||
return {
|
||||
"silhouette": silhouette,
|
||||
"mask_raw": mask_u8,
|
||||
"bbox_frame": bbox_frame,
|
||||
"bbox_mask": bbox_mask,
|
||||
"track_id": 0,
|
||||
}
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def process_frame(
|
||||
@@ -297,7 +318,11 @@ class ScoliosisPipeline:
|
||||
if selected is None:
|
||||
return None
|
||||
|
||||
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
|
||||
silhouette = selected["silhouette"]
|
||||
mask_raw = selected["mask_raw"]
|
||||
bbox = selected["bbox_frame"]
|
||||
bbox_mask = selected["bbox_mask"]
|
||||
track_id = selected["track_id"]
|
||||
|
||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||
|
||||
Reference in New Issue
Block a user