Files
OpenGait/scripts/visualize_sharedalign_overlay.py

174 lines
5.9 KiB
Python

from __future__ import annotations
import pickle
import sys
from pathlib import Path
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import yaml
# Render off-screen: the script only writes PNG files via savefig.
matplotlib.use("Agg")

REPO_ROOT = Path(__file__).resolve().parents[1]
if __package__ in {None, ""}:
    # Running as a plain script: make the repository's top-level packages importable.
    # Insert at the FRONT rather than appending: entries appended to sys.path come after
    # site-packages, so an installed distribution that is also called ``datasets`` (for
    # example Hugging Face ``datasets``) would shadow the repo's local package and break
    # the import below.
    if str(REPO_ROOT) not in sys.path:
        sys.path.insert(0, str(REPO_ROOT))

from datasets import pretreatment_heatmap as prep  # noqa: E402  (depends on sys.path setup above)

FloatArray = npt.NDArray[np.float32]

# Sample to visualize: SAMPLE_ID/LABEL/VIEW select one sequence in the dataset tree.
SAMPLE_ID = "00000"
LABEL = "positive"
VIEW = "000_180"
# Output directory for the rendered PNG previews.
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
# Pickled pose sequence fed to the heatmap transform.
POSE_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
# Pickled silhouette sequence drawn underneath the heatmaps.
SIL_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
# Pretreatment config whose values build the heatmap transform.
CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
# Per-layer display colours (RGB in [0, 1]) and opacities used by the overlay.
LIMB_RGB = np.asarray([0.18, 0.85, 0.96], dtype=np.float32)
JOINT_RGB = np.asarray([1.00, 0.45, 0.12], dtype=np.float32)
SIL_RGB = np.asarray([0.85, 0.97, 0.78], dtype=np.float32)
LIMB_ALPHA = 0.8
JOINT_ALPHA = 0.8
SIL_ALPHA = 0.3
def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
image = np.asarray(array, dtype=np.float32)
if image.max() > 1.0:
image = image / 255.0
return image
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the object stored at ``path`` and return it through ``_normalize``."""
    # NOTE: unpickling can execute arbitrary code; only point this at trusted dataset files.
    payload = pickle.loads(path.read_bytes())
    return _normalize(payload)
def _load_mixed_sharedalign() -> FloatArray:
    """Build the heatmap transform from ``CFG_PATH`` and apply it to the pose at ``POSE_PATH``.

    The YAML config is resolved with ``prep.replace_variables`` and its sections are
    forwarded to ``prep.GenerateHeatmapTransform`` (with ``reduction="upstream"``).
    The transform output is returned through ``_normalize``.
    """
    raw_cfg = yaml.safe_load(CFG_PATH.read_text(encoding="utf-8"))
    resolved = prep.replace_variables(raw_cfg, raw_cfg)
    section_keys = (
        "coco18tococo17_args",
        "padkeypoints_args",
        "norm_args",
        "heatmap_generator_args",
        "align_args",
    )
    transform = prep.GenerateHeatmapTransform(
        **{key: resolved[key] for key in section_keys},
        reduction="upstream",
        sigma_limb=float(resolved["sigma_limb"]),
        sigma_joint=float(resolved["sigma_joint"]),
    )
    # NOTE: unpickling can execute arbitrary code; POSE_PATH must be a trusted file.
    pose_seq = np.asarray(pickle.loads(POSE_PATH.read_bytes()), dtype=np.float32)
    return _normalize(transform(pose_seq))
def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
    """Composite silhouette, limb and joint maps into one straight-alpha RGBA image.

    Each map is clipped to [0, 1] and treated as per-pixel coverage for its layer.
    The layer opacity is that coverage times the layer's ``*_ALPHA`` constant, and
    its colour is the matching ``*_RGB`` constant. Layers are stacked bottom-to-top
    (silhouette, limb, joint) using the Porter-Duff "over" operator.

    The colour channels are returned un-premultiplied. ``Axes.imshow`` treats RGBA
    input as straight alpha, so premultiplied colour would get each layer's opacity
    applied a second time when drawn.

    Args:
        silhouette: 2-D map; its first two dimensions define the output size.
        limb: map with the same height/width as ``silhouette``.
        joint: map with the same height/width as ``silhouette``.

    Returns:
        float32 array of shape (H, W, 4) with every channel in [0, 1].
    """
    height, width = silhouette.shape[:2]
    premult_rgb = np.zeros((height, width, 3), dtype=np.float32)
    coverage = np.zeros((height, width), dtype=np.float32)
    layers = (
        (silhouette, SIL_RGB, SIL_ALPHA),
        (limb, LIMB_RGB, LIMB_ALPHA),
        (joint, JOINT_RGB, JOINT_ALPHA),
    )
    for mask, color, opacity in layers:
        layer_alpha = (np.clip(mask, 0.0, 1.0) * opacity).astype(np.float32)
        keep = 1.0 - layer_alpha  # fraction of what is underneath that stays visible
        premult_rgb = layer_alpha[..., None] * color + premult_rgb * keep[..., None]
        coverage = layer_alpha + coverage * keep

    canvas = np.zeros((height, width, 4), dtype=np.float32)
    # Un-premultiply; fully transparent pixels keep colour 0 instead of dividing by zero.
    np.divide(
        premult_rgb,
        coverage[..., None],
        out=canvas[..., :3],
        where=coverage[..., None] > 0.0,
    )
    canvas[..., 3] = coverage
    return np.clip(canvas, 0.0, 1.0)
def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
ys, xs = np.where(mask > threshold)
if ys.size == 0:
return None
return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())
def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
    """Rescale and re-centre ``silhouette`` so that it matches the extent of ``heatmap``.

    Steps:
      1. Crop the silhouette to its bounding box (``_bbox``).
      2. Resize it with its aspect ratio preserved so its height equals the
         heatmap bounding-box height.
      3. Paste it with its top edge on the heatmap box's top row and its
         horizontal centre on the heatmap box's centre. Columns that fall
         outside the output are dropped.

    The output canvas has the heatmap's height/width, so the pasted rows always
    fit and the result can be composited directly with the heatmap channels. When
    both inputs share a shape this matches sizing the canvas from the silhouette.
    If either input has an empty bounding box, ``silhouette`` is returned unchanged.
    """
    sil_box = _bbox(silhouette)
    heat_box = _bbox(heatmap)
    if sil_box is None or heat_box is None:
        return silhouette
    sy0, sy1, sx0, sx1 = sil_box
    hy0, hy1, hx0, hx1 = heat_box

    sil_crop = silhouette[sy0 : sy1 + 1, sx0 : sx1 + 1]
    sil_h = sy1 - sy0 + 1  # bbox bounds are inclusive, so extents are >= 1
    sil_w = sx1 - sx0 + 1
    target_h = hy1 - hy0 + 1
    scale = target_h / sil_h
    target_w = max(int(round(sil_w * scale)), 1)
    resized = cv2.resize(sil_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)

    canvas = np.zeros(heatmap.shape[:2], dtype=silhouette.dtype)
    out_w = canvas.shape[1]
    # Horizontal placement of the resized crop; it may extend past either edge.
    x0 = int(round((hx0 + hx1) / 2.0 - target_w / 2.0))
    x1 = x0 + target_w
    dst_x0 = max(0, x0)
    dst_x1 = min(out_w, x1)
    if dst_x0 < dst_x1:
        # Source columns are the destination columns shifted into crop coordinates.
        canvas[hy0 : hy1 + 1, dst_x0:dst_x1] = resized[:, dst_x0 - x0 : dst_x1 - x0]
    return canvas
def _render_preview(
    output_path: Path,
    silhouette: FloatArray,
    mixed: FloatArray,
    frames: npt.NDArray[np.int_],
    align_silhouette: bool,
) -> None:
    """Draw one overlay panel per index in ``frames`` and save the row of panels to ``output_path``.

    ``mixed`` is indexed as ``mixed[frame, channel]``: channel 0 is drawn as the
    limb layer and channel 1 as the joint layer. With ``align_silhouette`` set, each
    silhouette frame is first rescaled/re-centred onto the combined heatmap.
    """
    fig, axes = plt.subplots(
        1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220, squeeze=False
    )
    try:
        for axis, frame_idx in zip(axes[0], frames, strict=True):
            limb = mixed[frame_idx, 0]
            joint = mixed[frame_idx, 1]
            sil_frame = silhouette[frame_idx]
            if align_silhouette:
                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
            axis.set_facecolor("black")
            axis.imshow(_overlay_frame(sil_frame, limb, joint))
            axis.set_xticks([])
            axis.set_yticks([])
            axis.set_title(f"Frame {frame_idx}", fontsize=9)
        mode = "height+center aligned silhouette" if align_silhouette else "raw silhouette"
        fig.suptitle(
            (
                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
            ),
            fontsize=12,
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main() -> None:
    """Render raw and silhouette-aligned overlay previews for the configured sample.

    Writes two PNGs under ``RESEARCH_ROOT`` and prints each output path.

    Raises:
        ValueError: if the silhouette or heatmap sequence has no frames.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    silhouette = _load_pickle(SIL_PATH)
    mixed = _load_mixed_sharedalign()

    # Silhouettes and heatmaps come from separate pickles (SIL_PATH vs POSE_PATH) and
    # nothing checks that their frame counts agree, so only sample indices valid in both.
    num_frames = min(silhouette.shape[0], mixed.shape[0])
    if num_frames == 0:
        raise ValueError(
            f"no frames to render: silhouette has {silhouette.shape[0]}, "
            f"heatmap has {mixed.shape[0]}"
        )
    frames = np.linspace(0, num_frames - 1, num=6, dtype=int)

    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"
    for output_path, align_silhouette in ((raw_path, False), (aligned_path, True)):
        _render_preview(output_path, silhouette, mixed, frames, align_silhouette)
        print(output_path)


if __name__ == "__main__":
    main()