Add resumable ScoNet skeleton training diagnostics
This commit is contained in:
@@ -75,6 +75,7 @@ The silhouette and skeleton-map pipelines are different experiments and should n
|
||||
|
||||
* `Scoliosis1K-sil-pkl` is the silhouette modality used by the standard ScoNet configs.
|
||||
* pose-derived heatmap roots such as `Scoliosis1K_sigma_8.0/pkl` or DRF exports are skeleton-map inputs and require `in_channel: 2`.
|
||||
* DRF does **not** use the silhouette stream as an input. It uses `0_heatmap.pkl` plus `1_pav.pkl`.
|
||||
|
||||
Naming note:
|
||||
|
||||
@@ -89,6 +90,18 @@ A strong silhouette checkpoint does not validate the skeleton-map path. In parti
|
||||
|
||||
So if you are debugging DRF or `ScoNet-MT-ske` reproduction, do not use `ScoNet-20000-better.pt` as evidence that the heatmap preprocessing is correct.
|
||||
|
||||
### Overlay caveat
|
||||
|
||||
Do not treat a direct overlay between `Scoliosis1K-sil-pkl` and pose-derived skeleton maps as a valid alignment test.
|
||||
|
||||
Reason:
|
||||
|
||||
* the released silhouette modality is an estimated segmentation output from `PP-HumanSeg v2`
|
||||
* the released pose modality is an estimated keypoint output from `ViTPose`
|
||||
* the two modalities are normalized by different preprocessing pipelines before they reach OpenGait
|
||||
|
||||
So a silhouette-vs-skeleton mismatch in a debug figure is usually a cross-modality frame-of-reference issue, not proof that the raw dataset is bad. The more important check for skeleton-map debugging is whether the **limb and joint channels align with each other** inside `0_heatmap.pkl`.
|
||||
|
||||
---
|
||||
|
||||
## Pose-to-Heatmap Conversion
|
||||
@@ -146,6 +159,21 @@ If you explicitly want train-only PAV min-max statistics, add:
|
||||
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||
```
|
||||
|
||||
### Heatmap debugging notes
|
||||
|
||||
Current confirmed findings from local debugging:
|
||||
|
||||
* the raw pose dataset itself looks healthy; poor `ScoNet-MT-ske` results are not explained by obvious missing-joint collapse
|
||||
* a larger heatmap sigma can materially blur away the articulated structure; `sigma=8` was much broader than the silhouette geometry, while smaller sigma values recovered more structure
|
||||
* an earlier bug aligned the limb and joint channels separately; that made the two channels of `0_heatmap.pkl` slightly misregistered
|
||||
* the heatmap path is now patched so limb and joint channels share one alignment crop
|
||||
|
||||
Remaining caution:
|
||||
|
||||
* the exported skeleton map is stored as `64x64`
|
||||
* if the runtime config uses `BaseSilCuttingTransform`, the network actually sees `64x44`
|
||||
* that symmetric left/right crop is not automatically wrong, but it is still a meaningful ablation point for skeleton-map experiments
|
||||
|
||||
The output layout is:
|
||||
|
||||
```text
|
||||
|
||||
@@ -8,7 +8,8 @@ import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Literal
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from tqdm import tqdm
|
||||
import matplotlib.cm as cm
|
||||
import torch.distributed as dist
|
||||
@@ -516,7 +517,7 @@ class GatherTransform(object):
|
||||
"""
|
||||
Gather the different transforms.
|
||||
"""
|
||||
def __init__(self, base_transform, transform_bone, transform_joint):
|
||||
def __init__(self, base_transform, transform_bone, transform_joint, align_transform=None):
|
||||
|
||||
"""
|
||||
base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
|
||||
@@ -526,12 +527,15 @@ class GatherTransform(object):
|
||||
self.base_transform = base_transform
|
||||
self.transform_bone = transform_bone
|
||||
self.transform_joint = transform_joint
|
||||
self.align_transform = align_transform
|
||||
|
||||
def __call__(self, pose_data):
|
||||
x = self.base_transform(pose_data)
|
||||
heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
|
||||
heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
|
||||
heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
|
||||
if self.align_transform is not None:
|
||||
heatmap = self.align_transform(heatmap)
|
||||
return heatmap
|
||||
|
||||
class HeatmapAlignment():
|
||||
@@ -543,23 +547,32 @@ class HeatmapAlignment():
|
||||
|
||||
def center_crop(self, heatmap):
|
||||
"""
|
||||
Input: [1, heatmap_image_size, heatmap_image_size]
|
||||
Output: [1, final_img_size, final_img_size]
|
||||
Input: [C, heatmap_image_size, heatmap_image_size]
|
||||
Output: [C, final_img_size, final_img_size]
|
||||
"""
|
||||
raw_heatmap = heatmap[0]
|
||||
if self.align:
|
||||
y_sum = raw_heatmap.sum(axis=1)
|
||||
y_top = (y_sum != 0).argmax(axis=0)
|
||||
y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
|
||||
height = y_btm - y_top + 1
|
||||
raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
|
||||
raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
||||
return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
|
||||
raw_heatmap = heatmap
|
||||
if self.align:
|
||||
support_map = raw_heatmap.max(axis=0)
|
||||
y_sum = support_map.sum(axis=1)
|
||||
nonzero_rows = np.flatnonzero(y_sum != 0)
|
||||
if nonzero_rows.size != 0:
|
||||
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
|
||||
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
|
||||
height = y_btm - y_top + 1
|
||||
x_center = self.heatmap_image_size // 2
|
||||
x_left = max(x_center - (height // 2), 0)
|
||||
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
|
||||
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
|
||||
resized = np.stack([
|
||||
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
||||
for channel in raw_heatmap
|
||||
], axis=0)
|
||||
return resized # [C, final_img_size, final_img_size]
|
||||
|
||||
def __call__(self, heatmap_imgs):
|
||||
"""
|
||||
heatmap_imgs: (T, 1, raw_size, raw_size)
|
||||
return (T, 1, final_img_size, final_img_size)
|
||||
heatmap_imgs: (T, C, raw_size, raw_size)
|
||||
return (T, C, final_img_size, final_img_size)
|
||||
"""
|
||||
original_dtype = heatmap_imgs.dtype
|
||||
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
||||
@@ -570,12 +583,14 @@ class HeatmapAlignment():
|
||||
return heatmap_imgs.astype(original_dtype)
|
||||
|
||||
def GenerateHeatmapTransform(
|
||||
coco18tococo17_args,
|
||||
padkeypoints_args,
|
||||
norm_args,
|
||||
heatmap_generator_args,
|
||||
align_args,
|
||||
coco18tococo17_args: dict[str, Any],
|
||||
padkeypoints_args: dict[str, Any],
|
||||
norm_args: dict[str, Any],
|
||||
heatmap_generator_args: dict[str, Any],
|
||||
align_args: dict[str, Any],
|
||||
reduction: Literal["upstream", "max", "sum"] = "upstream",
|
||||
sigma_limb: float | None = None,
|
||||
sigma_joint: float | None = None,
|
||||
):
|
||||
|
||||
base_transform = T.Compose([
|
||||
@@ -584,34 +599,44 @@ def GenerateHeatmapTransform(
|
||||
CenterAndScaleNormalizer(**norm_args),
|
||||
])
|
||||
|
||||
heatmap_generator_args["with_limb"] = True
|
||||
heatmap_generator_args["with_kp"] = False
|
||||
bone_generator_args = deepcopy(heatmap_generator_args)
|
||||
joint_generator_args = deepcopy(heatmap_generator_args)
|
||||
|
||||
bone_generator_args["with_limb"] = True
|
||||
bone_generator_args["with_kp"] = False
|
||||
if sigma_limb is not None:
|
||||
bone_generator_args["sigma"] = sigma_limb
|
||||
bone_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_bone = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
GeneratePoseTarget(**bone_generator_args),
|
||||
bone_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
heatmap_generator_args["with_limb"] = False
|
||||
heatmap_generator_args["with_kp"] = True
|
||||
joint_generator_args["with_limb"] = False
|
||||
joint_generator_args["with_kp"] = True
|
||||
if sigma_joint is not None:
|
||||
joint_generator_args["sigma"] = sigma_joint
|
||||
joint_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_joint = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
GeneratePoseTarget(**joint_generator_args),
|
||||
joint_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
transform = T.Compose([
|
||||
GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
|
||||
GatherTransform(
|
||||
base_transform,
|
||||
transform_bone,
|
||||
transform_joint,
|
||||
HeatmapAlignment(**align_args),
|
||||
) # [T, 2, H, W]
|
||||
])
|
||||
|
||||
return transform
|
||||
|
||||
@@ -98,6 +98,15 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], replaced)
|
||||
|
||||
|
||||
def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None:
|
||||
value = cfg.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, (int, float)):
|
||||
raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}")
|
||||
return float(value)
|
||||
|
||||
|
||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
||||
return T.Compose([
|
||||
heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
|
||||
@@ -192,6 +201,8 @@ def main() -> None:
|
||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||
align_args=heatmap_cfg["align_args"],
|
||||
reduction=cast(HeatmapReduction, args.heatmap_reduction),
|
||||
sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"),
|
||||
sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"),
|
||||
)
|
||||
|
||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||
|
||||
Reference in New Issue
Block a user