"""Render preview figures overlaying a silhouette sequence with limb/joint heatmaps.

Two PNG files are written below ``RESEARCH_ROOT``: one using the silhouette
frames as loaded, and one where every silhouette frame is first rescaled and
re-centred onto the bounding box of the corresponding heatmap frame.
"""

from __future__ import annotations

import pickle
import sys
from pathlib import Path

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import yaml

matplotlib.use("Agg")

REPO_ROOT = Path(__file__).resolve().parents[1]
if __package__ in {None, ""}:
    # Executed as a plain script: make the repository root importable.
    sys.path.append(str(REPO_ROOT))

from datasets import pretreatment_heatmap as prep

FloatArray = npt.NDArray[np.float32]

SAMPLE_ID = "00000"
LABEL = "positive"
VIEW = "000_180"
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
POSE_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
SIL_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"

# Colour and opacity of each overlay layer.
LIMB_RGB = np.asarray([0.18, 0.85, 0.96], dtype=np.float32)
JOINT_RGB = np.asarray([1.00, 0.45, 0.12], dtype=np.float32)
SIL_RGB = np.asarray([0.85, 0.97, 0.78], dtype=np.float32)
LIMB_ALPHA = 0.8
JOINT_ALPHA = 0.8
SIL_ALPHA = 0.3

# Figures are saved on a black background, so captions must be light to be readable.
_TEXT_COLOR = "white"


def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
    """Return *array* as float32, divided by 255 when its maximum exceeds 1.0.

    An empty array is returned unchanged (as float32) instead of failing on
    the maximum of an empty sequence.
    """
    image = np.asarray(array, dtype=np.float32)
    if image.size and image.max() > 1.0:
        image = image / 255.0
    return image


def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the array stored at *path* and normalise it with ``_normalize``."""
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with path.open("rb") as handle:
        return _normalize(pickle.load(handle))


def _load_mixed_sharedalign() -> FloatArray:
    """Build the heatmap sequence for ``POSE_PATH`` using the config at ``CFG_PATH``.

    The YAML config is loaded, passed through ``prep.replace_variables`` and
    used to construct a ``prep.GenerateHeatmapTransform`` with
    ``reduction="upstream"``; the transform is applied to the unpickled pose
    array and the result is normalised with ``_normalize``.
    """
    with CFG_PATH.open("r", encoding="utf-8") as handle:
        cfg = yaml.safe_load(handle)
    cfg = prep.replace_variables(cfg, cfg)
    transform = prep.GenerateHeatmapTransform(
        coco18tococo17_args=cfg["coco18tococo17_args"],
        padkeypoints_args=cfg["padkeypoints_args"],
        norm_args=cfg["norm_args"],
        heatmap_generator_args=cfg["heatmap_generator_args"],
        align_args=cfg["align_args"],
        reduction="upstream",
        sigma_limb=float(cfg["sigma_limb"]),
        sigma_joint=float(cfg["sigma_joint"]),
    )
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with POSE_PATH.open("rb") as handle:
        pose = np.asarray(pickle.load(handle), dtype=np.float32)
    return _normalize(transform(pose))


def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
    """Compose one RGBA image from a silhouette, a limb map and a joint map.

    Each input is clipped to ``[0, 1]`` and used as a per-pixel weight. Colour
    contributions are summed (weight * layer colour * layer alpha) and the
    alpha channel holds the largest ``weight * layer alpha`` over the layers.
    Both colour and alpha are clipped to ``[0, 1]`` before returning.

    The returned array has shape ``(H, W, 4)`` where ``(H, W)`` is taken from
    the first two dimensions of *silhouette*.
    """
    canvas = np.zeros((silhouette.shape[0], silhouette.shape[1], 4), dtype=np.float32)
    layers = (
        (silhouette, SIL_RGB, SIL_ALPHA),
        (limb, LIMB_RGB, LIMB_ALPHA),
        (joint, JOINT_RGB, JOINT_ALPHA),
    )
    for layer, rgb, alpha in layers:
        mask = np.clip(layer, 0.0, 1.0)
        canvas[..., :3] += mask[..., None] * rgb[None, None, :] * alpha
        canvas[..., 3] = np.maximum(canvas[..., 3], mask * alpha)
    canvas[..., :3] = np.clip(canvas[..., :3], 0.0, 1.0)
    canvas[..., 3] = np.clip(canvas[..., 3], 0.0, 1.0)
    return canvas


def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
    """Return ``(y_min, y_max, x_min, x_max)`` of the pixels above *threshold*.

    The bounds are inclusive. ``None`` is returned when no pixel exceeds the
    threshold.
    """
    ys, xs = np.where(mask > threshold)
    if ys.size == 0:
        return None
    return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())


def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
    """Rescale and re-centre *silhouette* so it matches the heatmap's bounding box.

    The silhouette is cropped to its own bounding box, resized (keeping its
    aspect ratio) so its height equals the height of the heatmap's bounding
    box, and pasted onto an empty canvas of the silhouette's shape: vertically
    at the heatmap box rows, horizontally centred on the heatmap box. Columns
    that would fall outside the canvas are cut off.

    If either input has no pixel above the ``_bbox`` threshold, the silhouette
    is returned unchanged.
    """
    sil_box = _bbox(silhouette)
    heat_box = _bbox(heatmap)
    if sil_box is None or heat_box is None:
        return silhouette

    sy0, sy1, sx0, sx1 = sil_box
    hy0, hy1, hx0, hx1 = heat_box
    sil_crop = silhouette[sy0 : sy1 + 1, sx0 : sx1 + 1]

    sil_h = max(sy1 - sy0 + 1, 1)
    sil_w = max(sx1 - sx0 + 1, 1)
    target_h = max(hy1 - hy0 + 1, 1)
    scale = target_h / sil_h
    target_w = max(int(round(sil_w * scale)), 1)
    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(sil_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)

    canvas = np.zeros_like(silhouette)
    heat_center_x = (hx0 + hx1) / 2.0
    x0 = int(round(heat_center_x - target_w / 2.0))
    x1 = x0 + target_w

    # Clip the horizontal paste window to the canvas and trim the source to match.
    src_x0 = max(0, -x0)
    src_x1 = target_w - max(0, x1 - silhouette.shape[1])
    dst_x0 = max(0, x0)
    dst_x1 = min(silhouette.shape[1], x1)
    if src_x0 < src_x1 and dst_x0 < dst_x1:
        canvas[hy0 : hy1 + 1, dst_x0:dst_x1] = resized[:, src_x0:src_x1]
    return canvas


def main() -> None:
    """Write the raw and aligned overlay preview PNGs and print their paths.

    Six frame indices are sampled evenly over the frames available in both
    the silhouette sequence and the heatmap sequence. In the heatmap sequence
    index 0 of the second axis is drawn as the limb map and index 1 as the
    joint map.

    Raises:
        ValueError: if the silhouette or heatmap sequence contains no frames.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    silhouette = _load_pickle(SIL_PATH)
    mixed = _load_mixed_sharedalign()

    # Both sequences are indexed with the same frame index, so only sample
    # indices that are valid for the shorter of the two.
    frame_count = min(silhouette.shape[0], mixed.shape[0])
    if frame_count == 0:
        raise ValueError("silhouette or heatmap sequence contains no frames")
    frames = np.linspace(0, frame_count - 1, num=6, dtype=int)

    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"

    for output_path, align_silhouette in ((raw_path, False), (aligned_path, True)):
        fig, axes = plt.subplots(1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220)
        for axis, frame_idx in zip(axes, frames, strict=True):
            limb = mixed[frame_idx, 0]
            joint = mixed[frame_idx, 1]
            sil_frame = silhouette[frame_idx]
            if align_silhouette:
                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
            overlay = _overlay_frame(sil_frame, limb, joint)
            axis.set_facecolor("black")
            axis.imshow(overlay)
            axis.set_xticks([])
            axis.set_yticks([])
            # Default (black) text would vanish on the black figure background.
            axis.set_title(f"Frame {frame_idx}", fontsize=9, color=_TEXT_COLOR)

        mode = "height+center aligned silhouette" if align_silhouette else "raw silhouette"
        fig.suptitle(
            (
                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
            ),
            fontsize=12,
            color=_TEXT_COLOR,
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
        plt.close(fig)
        print(output_path)


if __name__ == "__main__":
    main()