# Source listing metadata: 336 lines, 11 KiB, Python.
from collections.abc import Callable
|
|
import math
|
|
from typing import cast
|
|
|
|
import cv2
|
|
from beartype import beartype
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
|
|
# Height (rows) of every output silhouette.
SIL_HEIGHT = 64
# Width of the intermediate canvas the resized crop is centre-cropped or
# zero-padded to, before the side cut is applied.
SIL_FULL_WIDTH = 64
# Columns discarded from EACH side of the intermediate canvas.
SIDE_CUT = 10
# Final silhouette width: the canvas minus both side cuts (64 - 2 * 10 = 44).
SIL_WIDTH = SIL_FULL_WIDTH - 2 * SIDE_CUT
# Minimum number of foreground pixels for a mask to be considered usable.
MIN_MASK_AREA = 500
|
|
|
|
# A decorator: maps a callable to a (possibly wrapped) callable.
JaxtypedDecorator = Callable[
    [Callable[..., object]],
    Callable[..., object],
]
# A decorator *factory*: called with options (e.g. ``typechecker=...``) and
# returning a ``JaxtypedDecorator``.
JaxtypedFactory = Callable[..., JaxtypedDecorator]
# ``jaxtyping.jaxtyped`` re-bound under an explicit static type so that the
# ``@jaxtyped(typechecker=beartype)`` decorators in this module type-check.
# ``cast`` is a no-op at runtime: this is the very same function object.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
|
|
|
# Concrete NumPy array aliases used by the internal helpers.
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]

# Bounding box in XYXY format: (x1, y1, x2, y2), where (x1, y1) is the
# top-left corner and (x2, y2) the bottom-right corner. The helpers in this
# module treat x2 / y2 as exclusive, slice-style bounds (``mask[y1:y2, x1:x2]``).
BBoxXYXY = tuple[int, int, int, int]
|
|
|
|
|
|
def _read_attr(container: object, key: str) -> object | None:
|
|
if isinstance(container, dict):
|
|
dict_obj = cast(dict[object, object], container)
|
|
return dict_obj.get(key)
|
|
try:
|
|
return cast(object, object.__getattribute__(container, key))
|
|
except AttributeError:
|
|
return None
|
|
|
|
|
|
def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """Convert an array-like or tensor-like object to a NumPy array.

    NumPy arrays are returned unchanged (no copy). Otherwise, if the object
    exposes callable ``detach`` and/or ``cpu`` attributes, they are applied
    in that order. This matches the method chain of tensor libraries such as,
    presumably, PyTorch. A callable ``numpy`` attribute is then tried. If it
    does not produce an ``np.ndarray``, the fallback is ``np.asarray`` on the
    (possibly detached / moved) object.

    Args:
        value: Array, tensor, nested sequence, or other ``np.asarray`` input.

    Returns:
        A NumPy array view of or copy of the data.
    """
    if isinstance(value, np.ndarray):
        return value

    current: object = value
    # Apply the optional zero-argument conversion steps in order.
    for step_name in ("detach", "cpu"):
        step = _read_attr(current, step_name)
        if callable(step):
            current = cast(Callable[[], object], step)()

    to_numpy = _read_attr(current, "numpy")
    if callable(to_numpy):
        converted = cast(Callable[[], object], to_numpy)()
        if isinstance(converted, np.ndarray):
            return converted

    return cast(NDArray[np.generic], np.asarray(current))
|
|
|
|
|
|
def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array:
    """Fill enclosed background holes in a binary mask.

    Any non-zero input pixel is treated as foreground (255). A *hole* is a
    background (0) region that is not 4-connected to the image border; its
    pixels are set to 255 in the result. Background that touches the border
    anywhere is left untouched.

    The mask is first padded with a one-pixel background frame, and the
    flood fill starts at the padded corner. This makes EVERY border-touching
    background region reachable from that single seed. A seed taken from one
    in-image corner misses regions in two cases:
      * the foreground spans the image (e.g. a person touching both the top
        and bottom edges) and splits the background into several
        border-connected pieces, which would wrongly be filled as "holes";
      * all four corners are foreground, in which case no valid seed exists.

    Args:
        mask_u8: 2-D mask; values > 0 are foreground.

    Returns:
        A new uint8 mask with values in {0, 255}.
    """
    mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    h, w = cast(tuple[int, int], mask_bin.shape)
    if h <= 2 or w <= 2:
        # Every pixel lies on the border, so no enclosed hole can exist.
        return mask_bin

    # One-pixel background frame: all border-connected background becomes a
    # single connected component containing (0, 0).
    flood = np.pad(mask_bin, 1, mode="constant", constant_values=0)
    # cv2.floodFill requires its mask to be 2 px larger than the image in
    # each dimension; the padded image is (h + 2, w + 2).
    flood_mask = np.zeros((h + 4, w + 4), dtype=np.uint8)
    _ = cv2.floodFill(flood, flood_mask, (0, 0), 255)

    # Background pixels the fill could not reach are enclosed holes.
    holes = flood[1:-1, 1:-1] == 0
    filled = mask_bin.copy()
    filled[holes] = 255
    return filled
|
|
|
|
|
|
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
    """Compute the tight bounding box of the non-zero pixels of a mask.

    Args:
        mask: Mask array of shape (H, W); any value > 0 counts as foreground.

    Returns:
        ``(x1, y1, x2, y2)`` in XYXY format with exclusive ``x2`` / ``y2``
        (one past the last foreground column / row), or ``None`` when the
        mask has no foreground pixels.
    """
    foreground = np.asarray(mask, dtype=np.uint8) > 0
    positions = np.argwhere(foreground)
    if int(positions.size) == 0:
        return None

    rows = positions[:, 0].astype(np.int64)
    cols = positions[:, 1].astype(np.int64)
    left, right = int(cols.min()), int(cols.max()) + 1
    top, bottom = int(rows.min()), int(rows.max()) + 1

    # Defensive: a non-empty coordinate set always yields a positive extent.
    if right <= left or bottom <= top:
        return None
    return (left, top, right, bottom)
|
|
|
|
|
|
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
    """Clip an XYXY bounding box to the bounds of a ``height`` x ``width`` image.

    The start corner is clamped to ``[0, dim - 1]`` and the (exclusive) end
    corner to ``[0, dim]``.

    Args:
        bbox: Bounding box ``(x1, y1, x2, y2)``.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        The clipped box, or ``None`` if it is empty after clipping.
    """

    def _clamp(value: int, upper: int) -> int:
        return max(0, min(int(value), upper))

    x1, y1, x2, y2 = bbox
    left = _clamp(x1, width - 1)
    top = _clamp(y1, height - 1)
    right = _clamp(x2, width)
    bottom = _clamp(y2, height)

    if left < right and top < bottom:
        return (left, top, right, bottom)
    return None
|
|
|
|
|
|
def _read_box_rows(
    result: object,
) -> list[tuple[float, float, float, float]] | None:
    """Read ``result.boxes.xyxy`` as a list of ``(x1, y1, x2, y2)`` float rows.

    A 1-D array with at least four values is treated as a single box (first
    four values). A 2-D array must have at least four columns (first four
    used). Missing boxes or any other layout yield ``None``, so the caller
    falls back to mask-derived bounding boxes. A 2-D array with zero rows
    yields ``[]``.
    """
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is None:
        return None
    xyxy_obj = _read_attr(boxes_obj, "xyxy")
    if xyxy_obj is None:
        return None

    # Values are converted through float32; ``tolist`` then yields plain
    # Python floats with exactly those float32 values.
    xyxy = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
    if xyxy.ndim == 1:
        if int(xyxy.size) < 4:
            return None
        xyxy = xyxy[:4].reshape(1, 4)
    elif xyxy.ndim != 2 or int(xyxy.shape[1]) < 4:
        return None

    rows = cast(list[list[float]], xyxy[:, :4].tolist())
    return [(r[0], r[1], r[2], r[3]) for r in rows]


def _bbox_from_box_row(
    row: tuple[float, float, float, float], height: int, width: int
) -> BBoxXYXY | None:
    """Snap a float XYXY box outward to whole pixels and clip it to the mask.

    Returns ``None`` when any coordinate is non-finite (``math.floor`` /
    ``math.ceil`` raise on NaN / inf) or when the clipped box is empty.
    """
    if not all(math.isfinite(v) for v in row):
        return None
    x1, y1, x2, y2 = row
    candidate = (
        int(math.floor(x1)),
        int(math.floor(y1)),
        int(math.ceil(x2)),
        int(math.ceil(y2)),
    )
    return _sanitize_bbox(candidate, height, width)


@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
    """Extract the largest person mask and its bounding box from a detection result.

    ``result`` is read as attributes or dict keys: ``result.masks.data``
    holds one mask (H, W) or a stack of masks (N, H, W); the optional
    ``result.boxes.xyxy`` holds the matching boxes. Each mask is binarized at
    ``> 0.5`` to {0, 255}. Masks with fewer than ``min_area`` foreground
    pixels are ignored. Among the rest, the one with the largest area wins;
    on ties, the first one is kept.

    The bbox of mask ``i`` comes from box ``i`` when boxes are present,
    snapped outward to whole pixels and clipped to the mask. If that box is
    non-finite or empty after clipping, the tight bbox of the mask is used
    instead. When boxes are present but fewer than the masks, masks without
    a matching box are skipped.

    NOTE(review): boxes are clipped against the mask's own (H, W). If the
    detector reports boxes in a different coordinate frame than
    ``masks.data`` (e.g. original image vs. inference resolution), the two
    would be misaligned — confirm against the producing model.

    Args:
        result: Detection result object (or dict) with ``masks`` / ``boxes``.
        min_area: Minimum number of foreground pixels for a mask to count.

    Returns:
        ``(mask, bbox)`` with ``mask`` a uint8 (H, W) array in {0, 255} and
        ``bbox`` in XYXY format (exclusive ``x2`` / ``y2``), or ``None`` if
        no mask qualifies.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None
    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None

    masks_float = np.asarray(_to_numpy_array(masks_data_obj), dtype=np.float32)
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count = int(masks_float.shape[0])
    if mask_count <= 0:
        return None

    box_rows = _read_box_rows(result)

    best_area = -1
    best: tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None = None
    for idx in range(mask_count):
        mask_binary = np.where(
            masks_float[idx] > 0.5, np.uint8(255), np.uint8(0)
        ).astype(np.uint8)
        mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)

        area = int(np.count_nonzero(mask_u8))
        if area < min_area:
            continue

        h, w = cast(tuple[int, int], mask_binary.shape)
        bbox: BBoxXYXY | None = None
        if box_rows is not None:
            if idx >= len(box_rows):
                # No box for this mask: skip rather than guess the pairing.
                continue
            bbox = _bbox_from_box_row(box_rows[idx], int(h), int(w))

        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)
        if bbox is None:
            continue

        if area > best_area:
            best_area = area
            best = (mask_u8, bbox)

    return best
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: BBoxXYXY,
    min_area: int = MIN_MASK_AREA,
) -> Float[ndarray, "64 44"] | None:
    """Convert a person mask into a normalized ``SIL_HEIGHT`` x ``SIL_WIDTH`` silhouette.

    Pipeline:
      1. Binarize (``> 0`` -> 255) and fill enclosed holes.
      2. Reject masks with fewer than ``min_area`` foreground pixels
         (counted after hole filling, over the whole mask).
      3. Crop to ``bbox`` (clipped to the mask), then trim empty rows at the
         top and bottom of the crop. Columns are not trimmed.
      4. Resize to height ``SIL_HEIGHT``, preserving the aspect ratio
         (bicubic interpolation).
      5. Centre-crop or zero-pad horizontally to ``SIL_FULL_WIDTH``, then
         drop ``SIDE_CUT`` columns from each side.
      6. Scale to float32 in [0, 1].

    Args:
        mask: Binary mask of shape (H, W), dtype uint8; non-zero is foreground.
        bbox: Bounding box in XYXY format (x1, y1, x2, y2), exclusive x2 / y2.
        min_area: Minimum foreground pixel count required. Defaults to
            ``MIN_MASK_AREA`` (the previously hard-coded threshold), so it can
            be kept in sync with ``frame_to_person_mask(min_area=...)``.

    Returns:
        Silhouette of shape (64, 44), dtype float32, or ``None`` when the
        mask is too small, the bbox is empty after clipping, or the crop
        contains no foreground.
    """
    mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    mask_u8 = _fill_binary_holes(mask_u8)
    if int(np.count_nonzero(mask_u8)) < min_area:
        return None

    h, w = cast(tuple[int, int], mask_u8.shape)
    bbox_sanitized = _sanitize_bbox(bbox, int(h), int(w))
    if bbox_sanitized is None:
        return None

    x1, y1, x2, y2 = bbox_sanitized
    cropped = mask_u8[y1:y2, x1:x2]
    if cropped.size == 0:
        return None

    # Trim empty rows so the silhouette spans the full output height.
    row_sums = np.sum(cropped, axis=1, dtype=np.int64)
    rows_with_fg = np.nonzero(row_sums > 0)[0]
    if int(rows_with_fg.size) == 0:
        return None
    top = int(rows_with_fg[0])
    bottom = int(rows_with_fg[-1]) + 1
    tightened = cropped[top:bottom, :]
    # tight_h >= 1 (a foreground row exists) and tight_w >= 1 (crop non-empty).
    tight_h, tight_w = cast(tuple[int, int], tightened.shape)

    resized_w = max(1, int(SIL_HEIGHT * (int(tight_w) / int(tight_h))))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )

    if resized_w >= SIL_FULL_WIDTH:
        # Too wide: keep the central SIL_FULL_WIDTH columns.
        start = (resized_w - SIL_FULL_WIDTH) // 2
        normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
    else:
        # Too narrow: centre on a zero (background) canvas.
        pad_left = (SIL_FULL_WIDTH - resized_w) // 2
        pad_right = SIL_FULL_WIDTH - resized_w - pad_left
        normalized_64 = np.pad(
            resized,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        )

    silhouette = np.asarray(
        normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    # Defensive guard: the crop/pad above always yields this shape.
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None

    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)
|