fix(demo): stabilize visualizer bbox and mask rendering

Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
This commit is contained in:
2026-02-28 18:05:33 +08:00
parent 06a6cd1ccf
commit 7f073179d7
7 changed files with 416 additions and 73 deletions
+47 -8
View File
@@ -23,6 +23,9 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
#: Bounding box in XYXY format: (x1, y1, x2, y2), where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
BBoxXYXY = tuple[int, int, int, int]
def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict):
@@ -59,7 +62,15 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]:
return cast(NDArray[np.generic], np.asarray(current))
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
"""Extract bounding box from binary mask in XYXY format.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
Returns:
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
"""
mask_u8 = np.asarray(mask, dtype=np.uint8)
coords = np.argwhere(mask_u8 > 0)
if int(coords.size) == 0:
@@ -76,9 +87,17 @@ def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] |
return (x1, y1, x2, y2)
def _sanitize_bbox(
bbox: tuple[int, int, int, int], height: int, width: int
) -> tuple[int, int, int, int] | None:
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
"""Sanitize bounding box to ensure it's within image bounds.
Args:
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
height: Image height.
width: Image width.
Returns:
Sanitized bounding box in XYXY format, or None if invalid.
"""
x1, y1, x2, y2 = bbox
x1c = max(0, min(int(x1), width - 1))
y1c = max(0, min(int(y1), height - 1))
@@ -92,7 +111,17 @@ def _sanitize_bbox(
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
"""Extract person mask and bounding box from detection result.
Args:
result: Detection results object with boxes and masks attributes.
min_area: Minimum mask area to consider valid.
Returns:
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
or None if no valid detections.
"""
masks_obj = _read_attr(result, "masks")
if masks_obj is None:
return None
@@ -152,7 +181,7 @@ def frame_to_person_mask(
best_area = -1
best_mask: UInt8[ndarray, "h w"] | None = None
best_bbox: tuple[int, int, int, int] | None = None
best_bbox: BBoxXYXY | None = None
for idx in range(mask_count):
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
@@ -167,7 +196,7 @@ def frame_to_person_mask(
if area < min_area:
continue
bbox: tuple[int, int, int, int] | None = None
bbox: BBoxXYXY | None = None
shape_2d = cast(tuple[int, int], mask_binary.shape)
h = int(shape_2d[0])
w = int(shape_2d[1])
@@ -204,8 +233,18 @@ def frame_to_person_mask(
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
mask: UInt8[ndarray, "h w"],
bbox: tuple[int, int, int, int],
bbox: BBoxXYXY,
) -> Float[ndarray, "64 44"] | None:
"""Convert mask to standardized silhouette using bounding box.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
Returns:
Standardized silhouette array of shape (64, 44) with dtype float32,
or None if conversion fails.
"""
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
return None