fix(demo): harden mask hole-filling for border-touching cases
This commit is contained in:
@@ -64,6 +64,29 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]:
|
|||||||
return cast(NDArray[np.generic], np.asarray(current))
|
return cast(NDArray[np.generic], np.asarray(current))
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array:
|
||||||
|
mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||||
|
h, w = cast(tuple[int, int], mask_bin.shape)
|
||||||
|
if h <= 2 or w <= 2:
|
||||||
|
return mask_bin
|
||||||
|
|
||||||
|
seed_candidates = [(0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1)]
|
||||||
|
seed: tuple[int, int] | None = None
|
||||||
|
for x, y in seed_candidates:
|
||||||
|
if int(mask_bin[y, x]) == 0:
|
||||||
|
seed = (x, y)
|
||||||
|
break
|
||||||
|
if seed is None:
|
||||||
|
return mask_bin
|
||||||
|
|
||||||
|
flood = mask_bin.copy()
|
||||||
|
flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
|
||||||
|
_ = cv2.floodFill(flood, flood_mask, seed, 255)
|
||||||
|
holes = cv2.bitwise_not(flood)
|
||||||
|
filled = cv2.bitwise_or(mask_bin, holes)
|
||||||
|
return cast(UInt8Array, filled)
|
||||||
|
|
||||||
|
|
||||||
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
|
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
|
||||||
"""Extract bounding box from binary mask in XYXY format.
|
"""Extract bounding box from binary mask in XYXY format.
|
||||||
|
|
||||||
@@ -248,6 +271,7 @@ def mask_to_silhouette(
|
|||||||
or None if conversion fails.
|
or None if conversion fails.
|
||||||
"""
|
"""
|
||||||
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||||
|
mask_u8 = _fill_binary_holes(mask_u8)
|
||||||
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,34 @@ class TestMaskToSilhouette:
|
|||||||
np.testing.assert_array_equal(result1, result2)
|
np.testing.assert_array_equal(result1, result2)
|
||||||
np.testing.assert_array_equal(result2, result3)
|
np.testing.assert_array_equal(result2, result3)
|
||||||
|
|
||||||
|
def test_hole_inside_mask_is_filled(self) -> None:
|
||||||
|
h, w = 200, 160
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
mask[30:170, 40:120] = 255
|
||||||
|
mask[80:120, 70:90] = 0
|
||||||
|
bbox = (40, 30, 120, 170)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
hole_patch = result_arr[26:38, 18:26]
|
||||||
|
assert float(np.mean(hole_patch)) > 0.8
|
||||||
|
|
||||||
|
def test_hole_fill_works_when_mask_touches_corner(self) -> None:
|
||||||
|
h, w = 220, 180
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
mask[0:180, 0:130] = 255
|
||||||
|
mask[70:120, 55:95] = 0
|
||||||
|
bbox = (0, 0, 130, 180)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
hole_patch = result_arr[24:40, 16:28]
|
||||||
|
assert float(np.mean(hole_patch)) > 0.75
|
||||||
|
|
||||||
def test_tall_narrow_mask_valid_output(self) -> None:
|
def test_tall_narrow_mask_valid_output(self) -> None:
|
||||||
"""Tall narrow mask should produce valid silhouette."""
|
"""Tall narrow mask should produce valid silhouette."""
|
||||||
h, w = 400, 50
|
h, w = 400, 50
|
||||||
|
|||||||
Reference in New Issue
Block a user