174 lines
5.9 KiB
Python
174 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import pickle
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import yaml
|
|
|
|
# Render off-screen: the script only writes PNG files, so no GUI backend is needed.
matplotlib.use("Agg")

# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]

# When run as a plain script (no package context), make the repository's
# top-level packages importable. Prepend rather than append: appending would
# let an installed distribution with the same name (e.g. Hugging Face
# ``datasets``) shadow the repository's ``datasets`` package, making the
# ``from datasets import pretreatment_heatmap`` import fail.
if __package__ in {None, ""}:
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
from datasets import pretreatment_heatmap as prep
|
|
|
|
# Array type returned by the loaders and compositing helpers in this script.
FloatArray = npt.NDArray[np.float32]

# Scoliosis1K sequence to visualise: subject id, class label, camera view.
SAMPLE_ID, LABEL, VIEW = "00000", "positive", "000_180"

# Directory that receives the rendered preview PNGs.
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"

# Per-sequence pickles on the shared dataset mount (absolute, machine-specific paths).
POSE_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
SIL_PATH = Path(
    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)

# YAML config from which _load_mixed_sharedalign builds its heatmap transform.
CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"

# Overlay colours (RGB in [0, 1]) and per-layer opacities used by _overlay_frame.
LIMB_RGB, JOINT_RGB, SIL_RGB = (
    np.asarray(rgb, dtype=np.float32)
    for rgb in ([0.18, 0.85, 0.96], [1.00, 0.45, 0.12], [0.85, 0.97, 0.78])
)
LIMB_ALPHA = JOINT_ALPHA = 0.8
SIL_ALPHA = 0.3
|
|
|
|
|
|
def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
    """Convert *array* to float32, rescaling 0-255 style data to the unit range.

    This is a heuristic. If any value exceeds 1.0, the whole array is treated as
    8-bit intensities and divided by 255. Otherwise the values are only cast to
    float32.

    Zero-size inputs are returned as an empty float32 array instead of raising.
    ``ndarray.max`` has no identity element for empty arrays.

    Args:
        array: array-like data to convert.

    Returns:
        A float32 array with the same shape as *array*.
    """
    image = np.asarray(array, dtype=np.float32)
    # Guard on size first: ``max()`` of an empty array raises ValueError.
    if image.size and image.max() > 1.0:
        image = image / 255.0
    return image
|
|
|
|
|
|
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the object stored at *path* and pass it through ``_normalize``.

    NOTE: ``pickle`` can execute arbitrary code while loading. Only call this on
    trusted files, such as the fixed dataset paths used by this script.
    """
    payload = pickle.loads(path.read_bytes())
    return _normalize(payload)
|
|
|
|
|
|
def _load_mixed_sharedalign() -> FloatArray:
    """Render the limb/joint heatmaps for POSE_PATH using the settings in CFG_PATH.

    Steps:

    1. Read the YAML config and resolve its variables with ``prep.replace_variables``.
    2. Build a ``prep.GenerateHeatmapTransform`` with ``reduction="upstream"``.
    3. Apply the transform to the float32 keypoint array unpickled from POSE_PATH.
    4. Normalise the result with ``_normalize``.

    ``main`` indexes the result as ``[frame, channel]``, with channel 0 as the
    limb map and channel 1 as the joint map. Presumably the shape is
    ``(T, 2, H, W)``; confirm against ``GenerateHeatmapTransform``.
    """
    with CFG_PATH.open("r", encoding="utf-8") as cfg_file:
        raw_cfg = yaml.safe_load(cfg_file)
    cfg = prep.replace_variables(raw_cfg, raw_cfg)

    # Config sections forwarded verbatim to the transform constructor.
    forwarded_keys = (
        "coco18tococo17_args",
        "padkeypoints_args",
        "norm_args",
        "heatmap_generator_args",
        "align_args",
    )
    transform = prep.GenerateHeatmapTransform(
        **{key: cfg[key] for key in forwarded_keys},
        reduction="upstream",
        sigma_limb=float(cfg["sigma_limb"]),
        sigma_joint=float(cfg["sigma_joint"]),
    )

    with POSE_PATH.open("rb") as pose_file:
        keypoints = np.asarray(pickle.load(pose_file), dtype=np.float32)
    heatmaps = transform(keypoints)
    return _normalize(heatmaps)
|
|
|
|
|
|
def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
    """Composite silhouette, limb and joint maps into one straight-alpha RGBA image.

    Each layer is clipped to [0, 1] and contributes ``mask * colour * opacity``
    to an additive colour accumulator. The output alpha is the per-pixel
    maximum of the layer opacities (``mask * opacity``).

    matplotlib's ``imshow`` interprets RGBA arrays as straight (not
    premultiplied) alpha. The accumulated colour is therefore divided by the
    output alpha before it is returned. Without that step every layer would be
    attenuated twice, by roughly ``opacity ** 2``, when drawn over the black
    axes background.

    Args:
        silhouette: 2-D silhouette map. Only its first two dimensions size the canvas.
        limb: limb heatmap, assumed to have the same shape as ``silhouette``.
        joint: joint heatmap, assumed to have the same shape as ``silhouette``.

    Returns:
        A float32 array of shape ``(H, W, 4)`` with RGB and alpha in [0, 1].
    """
    layers = (
        (silhouette, SIL_RGB, SIL_ALPHA),
        (limb, LIMB_RGB, LIMB_ALPHA),
        (joint, JOINT_RGB, JOINT_ALPHA),
    )
    height, width = silhouette.shape[0], silhouette.shape[1]
    premultiplied = np.zeros((height, width, 3), dtype=np.float32)
    alpha = np.zeros((height, width), dtype=np.float32)
    for layer, colour, opacity in layers:
        mask = np.clip(layer, 0.0, 1.0)
        premultiplied += mask[..., None] * colour[None, None, :] * opacity
        alpha = np.maximum(alpha, mask * opacity)
    alpha = np.clip(alpha, 0.0, 1.0)

    # Un-premultiply; fully transparent pixels keep black RGB instead of 0/0.
    rgb = np.divide(
        premultiplied,
        alpha[..., None],
        out=np.zeros_like(premultiplied),
        where=alpha[..., None] > 0.0,
    )

    canvas = np.empty((height, width, 4), dtype=np.float32)
    # Overlapping layers can sum above 1 after un-premultiplying; clip for display.
    canvas[..., :3] = np.clip(rgb, 0.0, 1.0)
    canvas[..., 3] = alpha
    return canvas
|
|
|
|
|
|
def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
    """Return the inclusive bounding box of pixels strictly above *threshold*.

    Args:
        mask: 2-D array to scan.
        threshold: values must be greater than this to count as foreground.

    Returns:
        ``(row_min, row_max, col_min, col_max)`` as plain ints, or ``None`` when
        no pixel exceeds *threshold*.
    """
    row_idx, col_idx = np.nonzero(mask > threshold)
    if not row_idx.size:
        return None
    return (
        int(row_idx.min()),
        int(row_idx.max()),
        int(col_idx.min()),
        int(col_idx.max()),
    )
|
|
|
|
|
|
def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
    """Rescale and re-centre the silhouette so that it spans the heatmap's height.

    Steps:

    1. Crop the silhouette to its ``_bbox`` (default threshold).
    2. Resize the crop with ``cv2.INTER_AREA`` to the height of the heatmap's
       bounding box, keeping the aspect ratio.
    3. Paste it into a zero canvas shaped like ``silhouette``. The top and bottom
       rows match the heatmap box, and the horizontal centre matches the heatmap
       box centre. Columns that land outside the canvas are cropped.

    If either input has no pixel above the threshold, ``silhouette`` itself is
    returned unchanged.

    NOTE(review): the heatmap box rows are used directly as canvas rows. This
    presumably assumes both maps share one frame size; confirm.
    """
    sil_box = _bbox(silhouette)
    heat_box = _bbox(heatmap)
    if sil_box is None or heat_box is None:
        return silhouette

    top, bottom, left, right = sil_box
    heat_top, heat_bottom, heat_left, heat_right = heat_box

    crop = silhouette[top : bottom + 1, left : right + 1]
    crop_h, crop_w = crop.shape[0], crop.shape[1]
    new_h = heat_bottom - heat_top + 1
    ratio = new_h / crop_h
    new_w = max(int(round(crop_w * ratio)), 1)
    resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    canvas_w = silhouette.shape[1]
    centre_x = (heat_left + heat_right) / 2.0
    paste_left = int(round(centre_x - new_w / 2.0))
    paste_right = paste_left + new_w

    # How much of the resized strip hangs off the left/right canvas edges.
    overhang_left = max(0, -paste_left)
    overhang_right = max(0, paste_right - canvas_w)
    src_cols = slice(overhang_left, new_w - overhang_right)
    dst_cols = slice(max(0, paste_left), min(canvas_w, paste_right))

    aligned = np.zeros_like(silhouette)
    if src_cols.start < src_cols.stop and dst_cols.start < dst_cols.stop:
        aligned[heat_top : heat_bottom + 1, dst_cols] = resized[:, src_cols]
    return aligned
|
|
|
|
|
|
def _save_preview(
    output_path: Path,
    silhouette: FloatArray,
    mixed: FloatArray,
    frames: npt.NDArray[np.int_],
    *,
    align_silhouette: bool,
) -> None:
    """Save one row of overlay panels, one per index in *frames*, to *output_path*.

    Args:
        output_path: PNG file to write.
        silhouette: silhouette sequence, indexed by frame along its first axis.
        mixed: heatmap sequence indexed as ``mixed[frame, channel]``, with
            channel 0 as the limb map and channel 1 as the joint map.
        frames: frame indices to draw, one panel each.
        align_silhouette: if True, rescale and re-centre each silhouette onto the
            combined limb/joint heatmap before compositing.
    """
    # squeeze=False keeps ``axes`` 2-D even when there is a single panel.
    fig, axes = plt.subplots(
        1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220, squeeze=False
    )
    try:
        for axis, frame_idx in zip(axes[0], frames, strict=True):
            limb = mixed[frame_idx, 0]
            joint = mixed[frame_idx, 1]
            sil_frame = silhouette[frame_idx]
            if align_silhouette:
                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
            axis.set_facecolor("black")
            axis.imshow(_overlay_frame(sil_frame, limb, joint))
            axis.set_xticks([])
            axis.set_yticks([])
            # The figure is saved on a black background, so the default black
            # text would be invisible.
            axis.set_title(f"Frame {frame_idx}", fontsize=9, color="white")

        mode = "height+center aligned silhouette" if align_silhouette else "raw silhouette"
        fig.suptitle(
            (
                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
            ),
            fontsize=12,
            color="white",
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main() -> None:
    """Render raw and silhouette-aligned overlay previews for the configured sample.

    The function:

    - Loads the silhouette sequence from SIL_PATH and the heatmaps rendered from
      POSE_PATH.
    - Picks six evenly spaced frames that exist in both sequences.
    - Writes two PNG strips into RESEARCH_ROOT: one with the silhouette as stored,
      and one with it aligned via ``_align_silhouette_to_heatmap``.
    - Prints each written path.

    Raises:
        ValueError: if either sequence contains no frames.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)

    silhouette = _load_pickle(SIL_PATH)
    mixed = _load_mixed_sharedalign()

    # Only frames present in both sequences can be overlaid. Sampling from the
    # silhouette length alone would index past a shorter heatmap sequence.
    num_frames = min(silhouette.shape[0], mixed.shape[0])
    if num_frames == 0:
        raise ValueError(
            f"No frames to preview (silhouette: {silhouette.shape[0]}, heatmap: {mixed.shape[0]})"
        )
    frames = np.linspace(0, num_frames - 1, num=6, dtype=int)

    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"

    for output_path, align_silhouette in ((raw_path, False), (aligned_path, True)):
        _save_preview(output_path, silhouette, mixed, frames, align_silhouette=align_silhouette)
        print(output_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Render the previews only when executed directly; importing the module does not draw anything.
    main()
|