fix(pipeline): enhance silhouette selection with structured output
This commit is contained in:
+39
-14
@@ -5,7 +5,7 @@ from contextlib import suppress
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
|
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
|
||||||
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
import click
|
import click
|
||||||
@@ -75,6 +75,24 @@ class _TrackCallable(Protocol):
|
|||||||
) -> object: ...
|
) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _SelectedSilhouette(TypedDict):
|
||||||
|
"""Selected silhouette payload produced from detector outputs.
|
||||||
|
|
||||||
|
Fields:
|
||||||
|
silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
|
||||||
|
mask_raw: Full-resolution binary person mask in mask/image space.
|
||||||
|
bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for visualization.
|
||||||
|
bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for cropping.
|
||||||
|
track_id: Tracking ID from detector, or `0` for fallback path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
silhouette: Float[ndarray, "64 44"]
|
||||||
|
mask_raw: UInt8[ndarray, "h w"]
|
||||||
|
bbox_frame: BBoxXYXY
|
||||||
|
bbox_mask: BBoxXYXY
|
||||||
|
track_id: int
|
||||||
|
|
||||||
|
|
||||||
class _FramePacer:
|
class _FramePacer:
|
||||||
_interval_ns: int
|
_interval_ns: int
|
||||||
_next_emit_ns: int | None
|
_next_emit_ns: int | None
|
||||||
@@ -206,16 +224,7 @@ class ScoliosisPipeline:
|
|||||||
def _select_silhouette(
|
def _select_silhouette(
|
||||||
self,
|
self,
|
||||||
result: _DetectionResultsLike,
|
result: _DetectionResultsLike,
|
||||||
) -> (
|
) -> _SelectedSilhouette | None:
|
||||||
tuple[
|
|
||||||
Float[ndarray, "64 44"],
|
|
||||||
UInt8[ndarray, "h w"],
|
|
||||||
BBoxXYXY,
|
|
||||||
BBoxXYXY,
|
|
||||||
int,
|
|
||||||
]
|
|
||||||
| None
|
|
||||||
):
|
|
||||||
selected = select_person(result)
|
selected = select_person(result)
|
||||||
if selected is not None:
|
if selected is not None:
|
||||||
mask_raw, bbox_mask, bbox_frame, track_id = selected
|
mask_raw, bbox_mask, bbox_frame, track_id = selected
|
||||||
@@ -224,7 +233,13 @@ class ScoliosisPipeline:
|
|||||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||||
)
|
)
|
||||||
if silhouette is not None:
|
if silhouette is not None:
|
||||||
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
|
return {
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"bbox_frame": bbox_frame,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
|
"track_id": int(track_id),
|
||||||
|
}
|
||||||
|
|
||||||
fallback = cast(
|
fallback = cast(
|
||||||
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||||
@@ -266,7 +281,13 @@ class ScoliosisPipeline:
|
|||||||
# Fallback: use mask-space bbox if orig_shape unavailable
|
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||||
bbox_frame = bbox_mask
|
bbox_frame = bbox_mask
|
||||||
# For fallback case, mask_raw is the same as mask_u8
|
# For fallback case, mask_raw is the same as mask_u8
|
||||||
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
|
return {
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"mask_raw": mask_u8,
|
||||||
|
"bbox_frame": bbox_frame,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
|
"track_id": 0,
|
||||||
|
}
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def process_frame(
|
def process_frame(
|
||||||
@@ -297,7 +318,11 @@ class ScoliosisPipeline:
|
|||||||
if selected is None:
|
if selected is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
|
silhouette = selected["silhouette"]
|
||||||
|
mask_raw = selected["mask_raw"]
|
||||||
|
bbox = selected["bbox_frame"]
|
||||||
|
bbox_mask = selected["bbox_mask"]
|
||||||
|
track_id = selected["track_id"]
|
||||||
|
|
||||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||||
|
|||||||
Reference in New Issue
Block a user