Files
crosstyan 00fcda4fe3 feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
2026-03-07 18:14:13 +08:00

336 lines
11 KiB
Python

from collections.abc import Callable
import math
from typing import cast
import cv2
from beartype import beartype
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
# Silhouette geometry: a cropped mask is resized to SIL_HEIGHT rows, fitted to a
# SIL_FULL_WIDTH-column canvas, and then SIDE_CUT columns are removed from each
# side, leaving SIL_WIDTH columns.
SIL_HEIGHT = 64
SIL_FULL_WIDTH = 64
SIDE_CUT = 10
SIL_WIDTH = SIL_FULL_WIDTH - 2 * SIDE_CUT  # == 44
# Minimum number of foreground pixels for a mask to be considered usable.
MIN_MASK_AREA = 500
# Plain NumPy array aliases used for helper signatures.
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]

# Re-typed view of ``jaxtyping.jaxtyped`` (via ``cast``) so that type checkers
# see it as a factory: ``jaxtyped(typechecker=...)`` -> function decorator.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Bounding box in XYXY format: (x1, y1, x2, y2), where (x1, y1) is the top-left
# pixel and (x2, y2) is the exclusive bottom-right corner, so the box region is
# ``array[y1:y2, x1:x2]``.
BBoxXYXY = tuple[int, int, int, int]
def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict):
dict_obj = cast(dict[object, object], container)
return dict_obj.get(key)
try:
return cast(object, object.__getattribute__(container, key))
except AttributeError:
return None
def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """Convert an array-like value (NumPy array, tensor, sequence) to NumPy.

    A NumPy array is returned unchanged (same object, no copy). For anything
    else, ``detach()`` and then ``cpu()`` are applied in turn when present and
    callable. Then ``numpy()`` is tried, and its result is used if it is an
    ``np.ndarray``. Otherwise the (possibly detached / moved) object is passed
    to ``np.asarray``.

    Args:
        value: Object to convert.

    Returns:
        The value as a NumPy array.
    """
    if isinstance(value, np.ndarray):
        return value

    obj: object = value
    # Tensor-style hooks, applied in this order when available.
    for method_name in ("detach", "cpu"):
        method = _read_attr(obj, method_name)
        if callable(method):
            obj = cast(Callable[[], object], method)()

    to_numpy = _read_attr(obj, "numpy")
    if callable(to_numpy):
        converted = cast(Callable[[], object], to_numpy)()
        if isinstance(converted, np.ndarray):
            return converted

    return cast(NDArray[np.generic], np.asarray(obj))
def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array:
    """Fill background regions of a binary mask that are enclosed by foreground.

    Any non-zero pixel counts as foreground. Before flood filling, the mask is
    padded with a one-pixel background border, so that every background region
    touching the image edge is reached from one seed in the padding.

    Without that padding, a foreground that touches two opposite edges (e.g. a
    person spanning the full frame height) splits the edge background into
    several regions. Seeding from a single corner would then misclassify all
    but one of those regions as holes and fill them.

    Args:
        mask_u8: 2-D mask; any non-zero value is treated as foreground.

    Returns:
        A uint8 mask of the same shape with values in {0, 255}. Background
        pixels that are not 4-connected to the image border are set to 255.
    """
    mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    h, w = cast(tuple[int, int], mask_bin.shape)
    # In a mask this thin every pixel lies on the border, so no holes can exist.
    if h <= 2 or w <= 2:
        return mask_bin

    # Background padding joins all edge-touching background into one region.
    flood = np.pad(mask_bin, 1, mode="constant", constant_values=0)
    # floodFill needs its mask to be 2 px larger than the image it fills.
    flood_mask = np.zeros((h + 4, w + 4), dtype=np.uint8)
    _ = cv2.floodFill(flood, flood_mask, (0, 0), 255)

    # Pixels still 0 after the flood were unreachable from the border: holes.
    holes = flood[1:-1, 1:-1] == 0
    filled = np.where(holes, np.uint8(255), mask_bin).astype(np.uint8)
    return cast(UInt8Array, filled)
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
    """Compute the tight XYXY bounding box of the non-zero pixels of a mask.

    Args:
        mask: 2-D mask. It is converted with ``np.asarray(..., dtype=np.uint8)``,
            and any non-zero value after conversion counts as foreground.

    Returns:
        ``(x1, y1, x2, y2)`` where ``x2``/``y2`` are exclusive (one past the
        last foreground column/row), or None if the mask has no foreground.
    """
    foreground = np.asarray(mask, dtype=np.uint8) > 0
    occupied_rows = np.flatnonzero(foreground.any(axis=1))
    if occupied_rows.size == 0:
        return None
    # At least one row is occupied, so at least one column is as well.
    occupied_cols = np.flatnonzero(foreground.any(axis=0))
    return (
        int(occupied_cols[0]),
        int(occupied_rows[0]),
        int(occupied_cols[-1]) + 1,
        int(occupied_rows[-1]) + 1,
    )
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
    """Clip an XYXY bounding box to the image and reject empty results.

    The box is half-open (``x2``/``y2`` exclusive), so every coordinate is
    clamped to ``[0, width]`` / ``[0, height]``. The start coordinates use the
    same range as the end coordinates. As a result, a box lying entirely right
    of or below the image collapses to zero size and is rejected, instead of
    being turned into a one-pixel sliver along the image edge.

    Args:
        bbox: Bounding box ``(x1, y1, x2, y2)``.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        The clipped box, or None if it has no positive width and height
        inside the image.
    """
    x1, y1, x2, y2 = bbox
    x1c = max(0, min(int(x1), width))
    y1c = max(0, min(int(y1), height))
    x2c = max(0, min(int(x2), width))
    y2c = max(0, min(int(y2), height))
    if x2c <= x1c or y2c <= y1c:
        return None
    return (x1c, y1c, x2c, y2c)
def _extract_box_values(
    result: object,
) -> list[tuple[float, float, float, float]] | None:
    """Read per-detection XYXY boxes from ``result.boxes.xyxy``, if present.

    Accepts a single box (1-D with at least 4 values) or a 2-D array with at
    least 4 columns; extra values/columns beyond the first four are ignored.

    Returns:
        One ``(x1, y1, x2, y2)`` float tuple per row (possibly an empty list),
        or None when ``boxes``/``xyxy`` is missing or has an unsupported shape.
    """
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is None:
        return None
    xyxy_obj = _read_attr(boxes_obj, "xyxy")
    if xyxy_obj is None:
        return None
    xyxy = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
    if xyxy.ndim == 1:
        if int(xyxy.size) < 4:
            return None
        xyxy = xyxy[:4].reshape(1, 4)
    elif xyxy.ndim != 2 or int(cast(tuple[int, int], xyxy.shape)[1]) < 4:
        return None
    # float32 -> float64 -> Python float, same values as element-wise float().
    rows = cast(list[list[float]], xyxy[:, :4].astype(np.float64).tolist())
    return [(row[0], row[1], row[2], row[3]) for row in rows]


@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
    """Pick the largest person mask from a segmentation result, with its bbox.

    ``result.masks.data`` is read as a stack of per-detection soft masks (a
    single 2-D mask is also accepted), and each mask is binarized at ``> 0.5``.
    Masks with fewer than ``min_area`` foreground pixels are ignored. Among the
    rest, the one with the most foreground pixels is returned; on ties, the
    first one wins.

    How each mask's bbox is chosen:

    * When ``result.boxes.xyxy`` is available, the matching row is used. It is
      floored/ceiled to integers and clipped to the mask size.
    * If that box is empty after clipping, the tight bbox of the mask itself is
      used instead.
    * If boxes are present but have no row for a given mask, that mask is
      skipped.

    Args:
        result: Detection result exposing ``masks.data`` and optionally
            ``boxes.xyxy``, either as attributes or as dict keys. Tensor-like
            values are converted to NumPy (``detach``/``cpu``/``numpy``).
        min_area: Minimum number of foreground pixels for a mask to be kept.

    Returns:
        ``(mask, bbox)`` where ``mask`` is a uint8 array with values {0, 255}
        and ``bbox`` is XYXY ``(x1, y1, x2, y2)`` with exclusive ``x2``/``y2``.
        Returns None if no mask qualifies.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None
    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None
    masks_float = np.asarray(_to_numpy_array(masks_data_obj), dtype=np.float32)
    if masks_float.ndim == 2:
        # A single (H, W) mask: treat it as a stack of one.
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3 or int(masks_float.shape[0]) == 0:
        return None

    box_values = _extract_box_values(result)

    best_area = -1
    best: tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None = None
    for idx in range(int(masks_float.shape[0])):
        mask_u8 = cast(
            UInt8[ndarray, "h w"],
            np.where(masks_float[idx] > 0.5, np.uint8(255), np.uint8(0)).astype(
                np.uint8
            ),
        )
        area = int(np.count_nonzero(mask_u8))
        # Too small, or cannot beat the current best: skip before any bbox work.
        if area < min_area or area <= best_area:
            continue

        bbox: BBoxXYXY | None = None
        if box_values is not None:
            if idx >= len(box_values):
                # Boxes exist, but none corresponds to this mask.
                continue
            bx1, by1, bx2, by2 = box_values[idx]
            mask_h, mask_w = cast(tuple[int, int], mask_u8.shape)
            # NOTE(review): the box is clipped against the mask's own (H, W),
            # which assumes boxes and masks share one pixel coordinate frame.
            # Confirm this for the detector in use.
            bbox = _sanitize_bbox(
                (math.floor(bx1), math.floor(by1), math.ceil(bx2), math.ceil(by2)),
                mask_h,
                mask_w,
            )
        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)
        if bbox is None:
            continue

        best_area = area
        best = (mask_u8, bbox)
    return best
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: BBoxXYXY,
    min_area: int = MIN_MASK_AREA,
) -> Float[ndarray, "64 44"] | None:
    """Convert a person mask into a normalized 64x44 float silhouette.

    Steps:

    1. Binarize the mask (non-zero -> 255) and fill enclosed holes.
    2. Require at least ``min_area`` foreground pixels in the whole mask.
    3. Crop to ``bbox`` (clipped to the mask) and trim empty rows above and
       below the foreground.
    4. Resize to ``SIL_HEIGHT`` rows, preserving aspect ratio.
    5. Center-crop or zero-pad the width to ``SIL_FULL_WIDTH``, then drop
       ``SIDE_CUT`` columns on each side.
    6. Scale values to [0, 1].

    Args:
        mask: 2-D uint8 mask; any non-zero value is foreground.
        bbox: Bounding box in XYXY format ``(x1, y1, x2, y2)``, ``x2``/``y2``
            exclusive.
        min_area: Minimum number of foreground pixels (counted after hole
            filling, over the whole mask). Defaults to ``MIN_MASK_AREA``; pass
            the same value given to ``frame_to_person_mask`` so both stages
            agree.

    Returns:
        A float32 array of shape ``(SIL_HEIGHT, SIL_WIDTH)`` with values in
        [0, 1], or None if the mask is too small, the bbox is empty after
        clipping, or the cropped region has no foreground.
    """
    binary = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    filled = _fill_binary_holes(binary)
    if int(np.count_nonzero(filled)) < min_area:
        return None

    img_h, img_w = cast(tuple[int, int], filled.shape)
    clipped = _sanitize_bbox(bbox, img_h, img_w)
    if clipped is None:
        return None
    x1, y1, x2, y2 = clipped
    # Non-empty: a clipped bbox always has x2 > x1 and y2 > y1.
    cropped = filled[y1:y2, x1:x2]

    # Trim empty rows so the foreground spans the full output height.
    occupied_rows = np.flatnonzero(np.any(cropped > 0, axis=1))
    if int(occupied_rows.size) == 0:
        return None
    top = int(occupied_rows[0])
    bottom = int(occupied_rows[-1]) + 1
    tightened = cropped[top:bottom, :]
    tight_h, tight_w = cast(tuple[int, int], tightened.shape)

    # Scale to SIL_HEIGHT rows, keeping aspect ratio (width truncated, >= 1 px).
    resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )

    # Center-crop or symmetrically zero-pad to SIL_FULL_WIDTH columns.
    if resized_w >= SIL_FULL_WIDTH:
        start = (resized_w - SIL_FULL_WIDTH) // 2
        normalized = resized[:, start : start + SIL_FULL_WIDTH]
    else:
        pad_left = (SIL_FULL_WIDTH - resized_w) // 2
        pad_right = SIL_FULL_WIDTH - resized_w - pad_left
        normalized = np.pad(
            resized,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        )

    # Drop SIDE_CUT columns on each side to get the final silhouette width.
    silhouette = np.asarray(
        normalized[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    # Defensive: only fails if the SIL_* / SIDE_CUT constants are inconsistent.
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None
    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)