From f6859cfa79a134cbed976c50653824c4278afc59 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Mon, 2 Mar 2026 16:45:20 +0800 Subject: [PATCH] fix(demo): harden mask hole-filling for border-touching cases --- opengait/demo/preprocess.py | 24 ++++++++++++++++++++++++ tests/demo/test_preprocess.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/opengait/demo/preprocess.py b/opengait/demo/preprocess.py index 9801cc7..4977d2a 100644 --- a/opengait/demo/preprocess.py +++ b/opengait/demo/preprocess.py @@ -64,6 +64,29 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]: return cast(NDArray[np.generic], np.asarray(current)) +def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array: + mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8) + h, w = cast(tuple[int, int], mask_bin.shape) + if h <= 2 or w <= 2: + return mask_bin + + seed_candidates = [(0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1)] + seed: tuple[int, int] | None = None + for x, y in seed_candidates: + if int(mask_bin[y, x]) == 0: + seed = (x, y) + break + if seed is None: + return mask_bin + + flood = mask_bin.copy() + flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8) + _ = cv2.floodFill(flood, flood_mask, seed, 255) + holes = cv2.bitwise_not(flood) + filled = cv2.bitwise_or(mask_bin, holes) + return cast(UInt8Array, filled) + + def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None: """Extract bounding box from binary mask in XYXY format. @@ -248,6 +271,7 @@ def mask_to_silhouette( or None if conversion fails. """ mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8) + mask_u8 = _fill_binary_holes(mask_u8) if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA: return None diff --git a/tests/demo/test_preprocess.py b/tests/demo/test_preprocess.py index 9d58c39..7290db6 100644 --- a/tests/demo/test_preprocess.py +++ b/tests/demo/test_preprocess.py @@ -95,6 +95,34 @@ class TestMaskToSilhouette: np.testing.assert_array_equal(result1, result2) np.testing.assert_array_equal(result2, result3) + def test_hole_inside_mask_is_filled(self) -> None: + h, w = 200, 160 + mask = np.zeros((h, w), dtype=np.uint8) + mask[30:170, 40:120] = 255 + mask[80:120, 70:90] = 0 + bbox = (40, 30, 120, 170) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + hole_patch = result_arr[26:38, 18:26] + assert float(np.mean(hole_patch)) > 0.8 + + def test_hole_fill_works_when_mask_touches_corner(self) -> None: + h, w = 220, 180 + mask = np.zeros((h, w), dtype=np.uint8) + mask[0:180, 0:130] = 255 + mask[70:120, 55:95] = 0 + bbox = (0, 0, 130, 180) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + hole_patch = result_arr[24:40, 16:28] + assert float(np.mean(hole_patch)) > 0.75 + def test_tall_narrow_mask_valid_output(self) -> None: """Tall narrow mask should produce valid silhouette.""" h, w = 400, 50