Add proxy eval and skeleton experiment tooling
This commit is contained in:
@@ -517,7 +517,15 @@ class GatherTransform(object):
|
||||
"""
|
||||
Gather the different transforms.
|
||||
"""
|
||||
def __init__(self, base_transform, transform_bone, transform_joint, align_transform=None):
|
||||
def __init__(
|
||||
self,
|
||||
base_transform,
|
||||
transform_bone,
|
||||
transform_joint,
|
||||
align_transform=None,
|
||||
limb_gain: float = 1.0,
|
||||
joint_gain: float = 1.0,
|
||||
) -> None:
|
||||
|
||||
"""
|
||||
base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
|
||||
@@ -528,6 +536,22 @@ class GatherTransform(object):
|
||||
self.transform_bone = transform_bone
|
||||
self.transform_joint = transform_joint
|
||||
self.align_transform = align_transform
|
||||
self.limb_gain = limb_gain
|
||||
self.joint_gain = joint_gain
|
||||
|
||||
def _apply_channel_gains(self, heatmap: np.ndarray) -> np.ndarray:
    """Scale channel 0 by ``limb_gain`` and channel 1 by ``joint_gain``.

    Channels are indexed along axis 1 of ``heatmap``. When both gains are
    exactly 1.0 the input array itself is returned (no copy). Otherwise the
    scaling is done on a float32 copy, the result is clipped to [0, 255],
    and it is cast back to the dtype of the input.
    """
    is_identity = self.limb_gain == 1.0 and self.joint_gain == 1.0
    if is_identity:
        return heatmap

    source_dtype = heatmap.dtype
    # Work on a float32 copy so the caller's array is never modified.
    work = heatmap.astype(np.float32)
    for channel_index, gain in ((0, self.limb_gain), (1, self.joint_gain)):
        work[:, channel_index] *= gain
    np.clip(work, 0.0, 255.0, out=work)

    # Integer and floating inputs are both restored with a plain cast.
    return work.astype(source_dtype)
|
||||
|
||||
def __call__(self, pose_data):
|
||||
x = self.base_transform(pose_data)
|
||||
@@ -536,38 +560,109 @@ class GatherTransform(object):
|
||||
heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
|
||||
if self.align_transform is not None:
|
||||
heatmap = self.align_transform(heatmap)
|
||||
return heatmap
|
||||
return self._apply_channel_gains(heatmap)
|
||||
|
||||
# Where crop bounds come from: "sequence" computes one set of bounds from the
# max over all frames; any other value crops each frame independently.
AlignmentScope = Literal["frame", "sequence"]
|
||||
# Horizontal crop strategy: "bbox_pad" uses the non-zero column extent (plus
# offset); otherwise a window centered on the image, as wide as the crop is tall.
AlignmentCropMode = Literal["square_center", "bbox_pad"]
|
||||
|
||||
|
||||
class HeatmapAlignment():
|
||||
def __init__(
    self,
    align: bool = True,
    final_img_size: int = 64,
    offset: int = 0,
    heatmap_image_size: int = 128,
    scope: AlignmentScope = "frame",
    crop_mode: AlignmentCropMode = "square_center",
    preserve_aspect_ratio: bool = False,
) -> None:
    """Store the alignment configuration.

    align: when False no crop bounds are computed and heatmaps are only resized.
    final_img_size: side length of the square output heatmaps.
    offset: margin (in pixels) added around the non-zero support before
        clamping to the image.
    heatmap_image_size: side length of the input heatmaps; used to clamp the
        crop bounds and to locate the horizontal center.
    scope: "sequence" shares one crop across all frames, otherwise each
        frame is cropped on its own.
    crop_mode: "bbox_pad" crops to the non-zero column extent, otherwise a
        centered window as wide as the crop is tall is used.
    preserve_aspect_ratio: when True the crop is resized to fit and
        zero-padded instead of being stretched to a square.
    """
    self.align = align
    self.final_img_size = final_img_size
    self.offset = offset
    self.heatmap_image_size = heatmap_image_size
    self.scope = scope
    self.crop_mode = crop_mode
    self.preserve_aspect_ratio = preserve_aspect_ratio
|
||||
|
||||
def _compute_crop_bounds(
|
||||
self,
|
||||
heatmap: np.ndarray,
|
||||
) -> tuple[int, int, int, int] | None:
|
||||
support_map = heatmap.max(axis=0)
|
||||
y_sum = support_map.sum(axis=1)
|
||||
x_sum = support_map.sum(axis=0)
|
||||
nonzero_rows = np.flatnonzero(y_sum != 0)
|
||||
nonzero_cols = np.flatnonzero(x_sum != 0)
|
||||
if nonzero_rows.size == 0:
|
||||
return None
|
||||
if nonzero_cols.size == 0:
|
||||
return None
|
||||
|
||||
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
|
||||
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
|
||||
|
||||
if self.crop_mode == "bbox_pad":
|
||||
x_left = max(int(nonzero_cols[0]) - self.offset, 0)
|
||||
x_right = min(int(nonzero_cols[-1]) + self.offset + 1, self.heatmap_image_size)
|
||||
return y_top, y_btm, x_left, x_right
|
||||
|
||||
height = y_btm - y_top + 1
|
||||
x_center = self.heatmap_image_size // 2
|
||||
x_left = max(x_center - (height // 2), 0)
|
||||
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
|
||||
return y_top, y_btm, x_left, x_right
|
||||
|
||||
def _resize_and_pad(self, cropped_heatmap: np.ndarray) -> np.ndarray:
    """Fit ``cropped_heatmap`` into a square canvas without stretching it.

    Each channel is resized so that its longer side equals
    ``final_img_size`` and is then placed in the middle of a zero-filled
    float32 canvas of shape ``[C, final_img_size, final_img_size]``. An
    input with an empty spatial dimension yields the all-zero canvas.
    """
    num_channels, src_h, src_w = cropped_heatmap.shape
    target = self.final_img_size
    canvas_shape = (num_channels, target, target)
    if src_h <= 0 or src_w <= 0:
        return np.zeros(canvas_shape, dtype=np.float32)

    # A single scale factor for both axes keeps the aspect ratio intact.
    scale = float(target) / float(max(src_h, src_w))
    new_h = max(1, int(round(src_h * scale)))
    new_w = max(1, int(round(src_w * scale)))

    resized_channels = [
        cv2.resize(channel, (new_w, new_h), interpolation=cv2.INTER_AREA)
        for channel in cropped_heatmap
    ]
    resized = np.stack(resized_channels, axis=0)

    canvas = np.zeros(canvas_shape, dtype=np.float32)
    top = (target - new_h) // 2
    left = (target - new_w) // 2
    canvas[:, top:top + new_h, left:left + new_w] = resized
    return canvas
|
||||
|
||||
def _crop_and_resize(
|
||||
self,
|
||||
heatmap: np.ndarray,
|
||||
crop_bounds: tuple[int, int, int, int] | None,
|
||||
) -> np.ndarray:
|
||||
raw_heatmap = heatmap
|
||||
if crop_bounds is not None:
|
||||
y_top, y_btm, x_left, x_right = crop_bounds
|
||||
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
|
||||
if self.preserve_aspect_ratio:
|
||||
return self._resize_and_pad(raw_heatmap)
|
||||
|
||||
return np.stack([
|
||||
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
||||
for channel in raw_heatmap
|
||||
], axis=0)
|
||||
|
||||
def center_crop(self, heatmap):
    """
    Input: [C, heatmap_image_size, heatmap_image_size]
    Output: [C, final_img_size, final_img_size]

    Crop bounds are computed from the heatmap only when ``self.align`` is
    set; otherwise the heatmap is resized without cropping. Cropping and
    resizing are delegated to ``_compute_crop_bounds`` and
    ``_crop_and_resize`` so that ``crop_mode`` and ``preserve_aspect_ratio``
    are honored here as well.
    """
    crop_bounds = self._compute_crop_bounds(heatmap) if self.align else None
    return self._crop_and_resize(heatmap, crop_bounds)  # [C, final_img_size, final_img_size]
|
||||
|
||||
def __call__(self, heatmap_imgs):
|
||||
"""
|
||||
@@ -576,7 +671,14 @@ class HeatmapAlignment():
|
||||
"""
|
||||
original_dtype = heatmap_imgs.dtype
|
||||
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
||||
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
|
||||
if self.align and self.scope == "sequence":
|
||||
sequence_crop_bounds = self._compute_crop_bounds(heatmap_imgs.max(axis=0))
|
||||
heatmap_imgs = np.array(
|
||||
[self._crop_and_resize(heatmap_img, sequence_crop_bounds) for heatmap_img in heatmap_imgs],
|
||||
dtype=np.float32,
|
||||
)
|
||||
else:
|
||||
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
|
||||
heatmap_imgs = heatmap_imgs * 255.0
|
||||
if np.issubdtype(original_dtype, np.integer):
|
||||
return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype)
|
||||
@@ -591,6 +693,8 @@ def GenerateHeatmapTransform(
|
||||
reduction: Literal["upstream", "max", "sum"] = "upstream",
|
||||
sigma_limb: float | None = None,
|
||||
sigma_joint: float | None = None,
|
||||
channel_gain_limb: float | None = None,
|
||||
channel_gain_joint: float | None = None,
|
||||
):
|
||||
|
||||
base_transform = T.Compose([
|
||||
@@ -636,6 +740,8 @@ def GenerateHeatmapTransform(
|
||||
transform_bone,
|
||||
transform_joint,
|
||||
HeatmapAlignment(**align_args),
|
||||
limb_gain=1.0 if channel_gain_limb is None else channel_gain_limb,
|
||||
joint_gain=1.0 if channel_gain_joint is None else channel_gain_joint,
|
||||
) # [T, 2, H, W]
|
||||
])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user