"""Turn instance-segmentation results into fixed-size (64 x 44) float silhouettes."""

from collections.abc import Callable
import math
from typing import cast

import cv2
from beartype import beartype
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray

SIL_HEIGHT = 64
SIL_WIDTH = 44
SIL_FULL_WIDTH = 64
# Columns trimmed from each side of the SIL_FULL_WIDTH canvas: 64 - 2 * 10 == SIL_WIDTH.
SIDE_CUT = 10
MIN_MASK_AREA = 500

JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]

#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
BBoxXYXY = tuple[int, int, int, int]

#: Floating-point XYXY box as read from a detector, before rounding to whole pixels.
_BoxXYXYFloat = tuple[float, float, float, float]


def _read_attr(container: object, key: str) -> object | None:
    """Look up ``key`` on ``container`` without raising.

    Dicts are queried by key. Any other object is queried by attribute name
    through :func:`getattr`, so properties and ``__getattr__`` fallbacks
    (e.g. wrapper/proxy objects) are honoured.

    Args:
        container: A dict or an arbitrary object.
        key: Dict key or attribute name to read.

    Returns:
        The value found, or None when the key/attribute is absent.
    """
    if isinstance(container, dict):
        return cast(dict[object, object], container).get(key)
    return cast(object, getattr(container, key, None))


def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """Convert a NumPy array or tensor-like object to a NumPy array.

    Tensor-like objects are handled by duck typing, following the
    ``detach().cpu().numpy()`` chain. Each method is called only when present
    and callable. If ``numpy()`` is missing or does not return an ndarray,
    the (detached / moved) object is passed to :func:`numpy.asarray`.

    Args:
        value: Array-like input.

    Returns:
        A NumPy array. It is the input itself when it already is one.
    """
    if isinstance(value, np.ndarray):
        return value
    current: object = value
    for method_name in ("detach", "cpu"):
        method = _read_attr(current, method_name)
        if callable(method):
            current = cast(Callable[[], object], method)()
    to_numpy = _read_attr(current, "numpy")
    if callable(to_numpy):
        as_numpy = cast(Callable[[], object], to_numpy)()
        if isinstance(as_numpy, np.ndarray):
            return as_numpy
    return cast(NDArray[np.generic], np.asarray(current))


def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
    """Extract the tight bounding box of the non-zero pixels of a mask.

    Args:
        mask: Binary mask array of shape (H, W). It is converted to uint8 first,
            so fractional values below 1 count as background.

    Returns:
        Bounding box as (x1, y1, x2, y2) in XYXY format with exclusive x2/y2,
        or None if the mask is empty.
    """
    occupied = np.asarray(mask, dtype=np.uint8) > 0
    rows = np.flatnonzero(occupied.any(axis=1))
    if int(rows.size) == 0:
        return None
    cols = np.flatnonzero(occupied.any(axis=0))
    return (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)


def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
    """Clip a bounding box to the image and reject it if nothing remains.

    Every coordinate is clamped into ``[0, width]`` (x) or ``[0, height]`` (y).
    A box that lies entirely outside the image (including past the right or
    bottom edge) therefore collapses to zero size and is rejected. It is not
    turned into a 1-pixel strip along the border.

    Args:
        bbox: Bounding box in XYXY format (x1, y1, x2, y2), exclusive x2/y2.
        height: Image height.
        width: Image width.

    Returns:
        Clipped bounding box in XYXY format, or None if it is empty.
    """
    x1, y1, x2, y2 = (int(coord) for coord in bbox)
    x1c = max(0, min(x1, width))
    y1c = max(0, min(y1, height))
    x2c = max(0, min(x2, width))
    y2c = max(0, min(y2, height))
    if x2c <= x1c or y2c <= y1c:
        return None
    return (x1c, y1c, x2c, y2c)


def _read_xyxy_boxes(result: object) -> list[_BoxXYXYFloat] | None:
    """Read ``result.boxes.xyxy`` as a list of float (x1, y1, x2, y2) rows.

    A 1-D array with at least 4 values is treated as a single box. A 2-D
    array uses its first 4 columns. Extra values/columns are ignored.

    Returns:
        One tuple per box row (possibly an empty list), or None when boxes are
        missing or have an unsupported shape.
    """
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is None:
        return None
    xyxy_obj = _read_attr(boxes_obj, "xyxy")
    if xyxy_obj is None:
        return None
    xyxy = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
    if xyxy.ndim == 1 and int(xyxy.size) >= 4:
        rows = xyxy[:4].reshape(1, 4)
    elif xyxy.ndim == 2 and int(cast(tuple[int, int], xyxy.shape)[1]) >= 4:
        rows = xyxy[:, :4]
    else:
        return None
    # tolist() yields Python floats; the float32 -> float conversion is exact.
    return [
        (x1, y1, x2, y2)
        for x1, y1, x2, y2 in cast(list[list[float]], rows.tolist())
    ]


def _box_row_to_pixels(
    box: _BoxXYXYFloat, height: int, width: int
) -> BBoxXYXY | None:
    """Round a float XYXY box outward to whole pixels and clip it to the image.

    Returns:
        The integer box, or None when any coordinate is non-finite (NaN/inf
        cannot be rounded to a pixel index) or the clipped box is empty.
    """
    if not all(math.isfinite(coord) for coord in box):
        return None
    x1, y1, x2, y2 = box
    # floor the top-left and ceil the bottom-right so rounding never shrinks the box.
    candidate = (math.floor(x1), math.floor(y1), math.ceil(x2), math.ceil(y2))
    return _sanitize_bbox(candidate, height, width)


@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
    """Extract the largest person mask and its bounding box from a detection result.

    ``result.masks.data`` is read as an (N, H, W) stack, or a single (H, W)
    mask, and binarised at 0.5 to 0/255. Masks with fewer than ``min_area``
    foreground pixels are ignored. Of the rest, the one with the largest
    foreground area wins; ties keep the earliest.

    When ``result.boxes.xyxy`` is available, box ``i`` is paired with mask
    ``i``. A mask with no matching box row is skipped. If a mask's box is
    unusable (non-finite, or empty after clipping), the mask's own tight
    extent is used instead. Without boxes, every mask uses its own extent.

    NOTE(review): boxes are clipped against the mask array's (H, W), so this
    assumes boxes and masks share one pixel space. Confirm this for the
    detector in use.

    Args:
        result: Detection results object with ``masks`` and optionally
            ``boxes`` attributes (dicts with those keys also work).
        min_area: Minimum number of foreground pixels for a mask to count.

    Returns:
        Tuple of (mask, bbox), where mask is a uint8 0/255 array of shape
        (H, W) and bbox is in XYXY format (x1, y1, x2, y2). Returns None if
        there are no valid detections.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None
    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None
    masks_float = np.asarray(_to_numpy_array(masks_data_obj), dtype=np.float32)
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count, mask_h, mask_w = cast(tuple[int, int, int], masks_float.shape)
    if mask_count <= 0:
        return None

    box_rows = _read_xyxy_boxes(result)

    best_area = -1
    best: tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None = None
    for idx in range(mask_count):
        mask_u8 = cast(
            UInt8[ndarray, "h w"],
            np.where(masks_float[idx] > 0.5, np.uint8(255), np.uint8(0)).astype(
                np.uint8
            ),
        )
        area = int(np.count_nonzero(mask_u8))
        if area < min_area:
            continue

        bbox: BBoxXYXY | None = None
        if box_rows is not None:
            if idx >= len(box_rows):
                continue
            bbox = _box_row_to_pixels(box_rows[idx], mask_h, mask_w)
        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)
        if bbox is None:
            continue

        if area > best_area:
            best_area = area
            best = (mask_u8, bbox)
    return best


def _center_to_width(image: UInt8Array, target_width: int) -> UInt8Array:
    """Center-crop or zero-pad ``image`` horizontally to ``target_width`` columns.

    For an odd excess or deficit, the extra column is dropped from / added
    on the right.
    """
    current_width = int(cast(tuple[int, int], image.shape)[1])
    if current_width >= target_width:
        start = (current_width - target_width) // 2
        return image[:, start : start + target_width]
    pad_left = (target_width - current_width) // 2
    pad_right = target_width - current_width - pad_left
    return cast(
        UInt8Array,
        np.pad(
            image,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        ),
    )


@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: BBoxXYXY,
    min_area: int = MIN_MASK_AREA,
) -> Float[ndarray, "64 44"] | None:
    """Convert a mask to a standardized silhouette using a bounding box.

    Steps:
        1. Binarise the mask to 0/255 and require ``min_area`` foreground pixels.
        2. Crop to the clipped bbox, then trim empty rows (columns are kept).
        3. Resize with bicubic interpolation to height SIL_HEIGHT, keeping the
           aspect ratio.
        4. Center-crop or zero-pad to SIL_FULL_WIDTH columns.
        5. Drop SIDE_CUT columns on each side and scale to [0, 1].

    Args:
        mask: Binary mask array of shape (H, W) with dtype uint8.
        bbox: Bounding box in XYXY format (x1, y1, x2, y2).
        min_area: Minimum number of foreground pixels in the whole mask.

    Returns:
        Standardized silhouette array of shape (64, 44) with dtype float32,
        or None if conversion fails.
    """
    mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    if int(np.count_nonzero(mask_u8)) < min_area:
        return None

    height, width = cast(tuple[int, int], mask_u8.shape)
    bbox_clipped = _sanitize_bbox(bbox, int(height), int(width))
    if bbox_clipped is None:
        return None
    x1, y1, x2, y2 = bbox_clipped
    cropped = mask_u8[y1:y2, x1:x2]

    occupied_rows = np.flatnonzero(cropped.any(axis=1))
    if int(occupied_rows.size) == 0:
        return None
    tightened = cropped[int(occupied_rows[0]) : int(occupied_rows[-1]) + 1, :]

    tight_h, tight_w = cast(tuple[int, int], tightened.shape)
    resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )
    normalized_64 = _center_to_width(resized, SIL_FULL_WIDTH)

    silhouette = np.asarray(
        normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None
    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)