fix(pipeline): enhance silhouette selection with structured output

This commit is contained in:
2026-03-02 17:22:33 +08:00
parent ab738c1615
commit 654409ff50
+39 -14
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging import logging
from pathlib import Path from pathlib import Path
import time import time
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
from beartype import beartype from beartype import beartype
import click import click
@@ -75,6 +75,24 @@ class _TrackCallable(Protocol):
) -> object: ... ) -> object: ...
class _SelectedSilhouette(TypedDict):
"""Selected silhouette payload produced from detector outputs.
Fields:
silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
mask_raw: Full-resolution binary person mask in mask/image space.
bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for visualization.
bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for cropping.
track_id: Tracking ID from detector, or `0` for fallback path.
"""
silhouette: Float[ndarray, "64 44"]
mask_raw: UInt8[ndarray, "h w"]
bbox_frame: BBoxXYXY
bbox_mask: BBoxXYXY
track_id: int
class _FramePacer: class _FramePacer:
_interval_ns: int _interval_ns: int
_next_emit_ns: int | None _next_emit_ns: int | None
@@ -206,16 +224,7 @@ class ScoliosisPipeline:
def _select_silhouette( def _select_silhouette(
self, self,
result: _DetectionResultsLike, result: _DetectionResultsLike,
) -> ( ) -> _SelectedSilhouette | None:
tuple[
Float[ndarray, "64 44"],
UInt8[ndarray, "h w"],
BBoxXYXY,
BBoxXYXY,
int,
]
| None
):
selected = select_person(result) selected = select_person(result)
if selected is not None: if selected is not None:
mask_raw, bbox_mask, bbox_frame, track_id = selected mask_raw, bbox_mask, bbox_frame, track_id = selected
@@ -224,7 +233,13 @@ class ScoliosisPipeline:
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask), mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
) )
if silhouette is not None: if silhouette is not None:
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id) return {
"silhouette": silhouette,
"mask_raw": mask_raw,
"bbox_frame": bbox_frame,
"bbox_mask": bbox_mask,
"track_id": int(track_id),
}
fallback = cast( fallback = cast(
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None, tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
@@ -266,7 +281,13 @@ class ScoliosisPipeline:
# Fallback: use mask-space bbox if orig_shape unavailable # Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8 # For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox_frame, bbox_mask, 0 return {
"silhouette": silhouette,
"mask_raw": mask_u8,
"bbox_frame": bbox_frame,
"bbox_mask": bbox_mask,
"track_id": 0,
}
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def process_frame( def process_frame(
@@ -297,7 +318,11 @@ class ScoliosisPipeline:
if selected is None: if selected is None:
return None return None
silhouette, mask_raw, bbox, bbox_mask, track_id = selected silhouette = selected["silhouette"]
mask_raw = selected["mask_raw"]
bbox = selected["bbox_frame"]
bbox_mask = selected["bbox_mask"]
track_id = selected["track_id"]
# Store silhouette for export if in preprocess-only mode or if export requested # Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only: if self._silhouette_export_path is not None or self._preprocess_only: