Add resumable ScoNet skeleton training diagnostics

This commit is contained in:
2026-03-09 15:57:13 +08:00
parent 4e0b0a18dc
commit 36aef46a0d
15 changed files with 1226 additions and 44 deletions
+28
View File
@@ -75,6 +75,7 @@ The silhouette and skeleton-map pipelines are different experiments and should n
* `Scoliosis1K-sil-pkl` is the silhouette modality used by the standard ScoNet configs.
* pose-derived heatmap roots such as `Scoliosis1K_sigma_8.0/pkl` or DRF exports are skeleton-map inputs and require `in_channel: 2`.
* DRF does **not** use the silhouette stream as an input. It uses `0_heatmap.pkl` plus `1_pav.pkl`.
Naming note:
@@ -89,6 +90,18 @@ A strong silhouette checkpoint does not validate the skeleton-map path. In parti
So if you are debugging DRF or `ScoNet-MT-ske` reproduction, do not use `ScoNet-20000-better.pt` as evidence that the heatmap preprocessing is correct.
### Overlay caveat
Do not treat a direct overlay between `Scoliosis1K-sil-pkl` and pose-derived skeleton maps as a valid alignment test.
Reason:
* the released silhouette modality is an estimated segmentation output from `PP-HumanSeg v2`
* the released pose modality is an estimated keypoint output from `ViTPose`
* the two modalities are normalized by different preprocessing pipelines before they reach OpenGait
So a silhouette-vs-skeleton mismatch in a debug figure is usually a cross-modality frame-of-reference issue, not proof that the raw dataset is bad. The more important check for skeleton-map debugging is whether the **limb and joint channels align with each other** inside `0_heatmap.pkl`.
---
## Pose-to-Heatmap Conversion
@@ -146,6 +159,21 @@ If you explicitly want train-only PAV min-max statistics, add:
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
```
### Heatmap debugging notes
Current confirmed findings from local debugging:
* the raw pose dataset itself looks healthy; poor `ScoNet-MT-ske` results are not explained by obvious missing-joint collapse
* a larger heatmap sigma can materially blur away the articulated structure; `sigma=8` was much broader than the silhouette geometry, while smaller sigma values recovered more structure
* an earlier bug aligned the limb and joint channels separately; that made the two channels of `0_heatmap.pkl` slightly misregistered
* the heatmap path is now patched so limb and joint channels share one alignment crop
Remaining caution:
* the exported skeleton map is stored as `64x64`
* if the runtime config uses `BaseSilCuttingTransform`, the network actually sees `64x44`, because that transform crops the left and right sides of the map symmetrically
* that symmetric left/right crop is not automatically wrong, but it is still a meaningful ablation point for skeleton-map experiments
The output layout is:
```text
+54 -29
View File
@@ -8,7 +8,8 @@ import pickle
import argparse
import numpy as np
from glob import glob
from typing import Literal
from copy import deepcopy
from typing import Any, Literal
from tqdm import tqdm
import matplotlib.cm as cm
import torch.distributed as dist
@@ -516,7 +517,7 @@ class GatherTransform(object):
"""
Gather the different transforms.
"""
def __init__(self, base_transform, transform_bone, transform_joint):
def __init__(self, base_transform, transform_bone, transform_joint, align_transform=None):
"""
base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
@@ -526,12 +527,15 @@ class GatherTransform(object):
self.base_transform = base_transform
self.transform_bone = transform_bone
self.transform_joint = transform_joint
self.align_transform = align_transform
def __call__(self, pose_data):
x = self.base_transform(pose_data)
heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
if self.align_transform is not None:
heatmap = self.align_transform(heatmap)
return heatmap
class HeatmapAlignment():
@@ -543,23 +547,32 @@ class HeatmapAlignment():
def center_crop(self, heatmap):
"""
Input: [1, heatmap_image_size, heatmap_image_size]
Output: [1, final_img_size, final_img_size]
Input: [C, heatmap_image_size, heatmap_image_size]
Output: [C, final_img_size, final_img_size]
"""
raw_heatmap = heatmap[0]
if self.align:
y_sum = raw_heatmap.sum(axis=1)
y_top = (y_sum != 0).argmax(axis=0)
y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
height = y_btm - y_top + 1
raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
raw_heatmap = heatmap
if self.align:
support_map = raw_heatmap.max(axis=0)
y_sum = support_map.sum(axis=1)
nonzero_rows = np.flatnonzero(y_sum != 0)
if nonzero_rows.size != 0:
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
height = y_btm - y_top + 1
x_center = self.heatmap_image_size // 2
x_left = max(x_center - (height // 2), 0)
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
resized = np.stack([
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
for channel in raw_heatmap
], axis=0)
return resized # [C, final_img_size, final_img_size]
def __call__(self, heatmap_imgs):
"""
heatmap_imgs: (T, 1, raw_size, raw_size)
return (T, 1, final_img_size, final_img_size)
heatmap_imgs: (T, C, raw_size, raw_size)
return (T, C, final_img_size, final_img_size)
"""
original_dtype = heatmap_imgs.dtype
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
@@ -570,12 +583,14 @@ class HeatmapAlignment():
return heatmap_imgs.astype(original_dtype)
def GenerateHeatmapTransform(
coco18tococo17_args,
padkeypoints_args,
norm_args,
heatmap_generator_args,
align_args,
coco18tococo17_args: dict[str, Any],
padkeypoints_args: dict[str, Any],
norm_args: dict[str, Any],
heatmap_generator_args: dict[str, Any],
align_args: dict[str, Any],
reduction: Literal["upstream", "max", "sum"] = "upstream",
sigma_limb: float | None = None,
sigma_joint: float | None = None,
):
base_transform = T.Compose([
@@ -584,34 +599,44 @@ def GenerateHeatmapTransform(
CenterAndScaleNormalizer(**norm_args),
])
heatmap_generator_args["with_limb"] = True
heatmap_generator_args["with_kp"] = False
bone_generator_args = deepcopy(heatmap_generator_args)
joint_generator_args = deepcopy(heatmap_generator_args)
bone_generator_args["with_limb"] = True
bone_generator_args["with_kp"] = False
if sigma_limb is not None:
bone_generator_args["sigma"] = sigma_limb
bone_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
GeneratePoseTarget(**bone_generator_args),
bone_image_transform,
HeatmapAlignment(**align_args)
])
heatmap_generator_args["with_limb"] = False
heatmap_generator_args["with_kp"] = True
joint_generator_args["with_limb"] = False
joint_generator_args["with_kp"] = True
if sigma_joint is not None:
joint_generator_args["sigma"] = sigma_joint
joint_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
GeneratePoseTarget(**joint_generator_args),
joint_image_transform,
HeatmapAlignment(**align_args)
])
transform = T.Compose([
GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
GatherTransform(
base_transform,
transform_bone,
transform_joint,
HeatmapAlignment(**align_args),
) # [T, 2, H, W]
])
return transform
+11
View File
@@ -98,6 +98,15 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
return cast(dict[str, Any], replaced)
def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None:
value = cfg.get(key)
if value is None:
return None
if not isinstance(value, (int, float)):
raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}")
return float(value)
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
return T.Compose([
heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
@@ -192,6 +201,8 @@ def main() -> None:
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
align_args=heatmap_cfg["align_args"],
reduction=cast(HeatmapReduction, args.heatmap_reduction),
sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"),
sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"),
)
pose_paths = iter_pose_paths(args.pose_data_path)