Refine DRF preprocessing and body-prior pipeline
This commit is contained in:
@@ -99,14 +99,22 @@ The PAV pass is implemented from the paper:
|
||||
4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs
|
||||
5. apply IQR filtering per metric
|
||||
6. average over time
|
||||
7. min-max normalize across the dataset, or across `TRAIN_SET` when `--stats_partition` is provided
|
||||
7. min-max normalize across the full dataset (paper default), or across `TRAIN_SET` when `--stats_partition` is provided as an anti-leakage variant
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
uv run python datasets/pretreatment_scoliosis_drf.py \
|
||||
--pose_data_path=<path_to_pose_pkl> \
|
||||
--output_path=<path_to_drf_pkl> \
|
||||
--output_path=<path_to_drf_pkl>
|
||||
```
|
||||
|
||||
To reproduce the paper defaults more closely, the script now uses
|
||||
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables
|
||||
summed two-channel skeleton maps and a literal 128-pixel height normalization.
|
||||
If you explicitly want train-only PAV min-max statistics, add:
|
||||
|
||||
```bash
|
||||
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||
```
|
||||
|
||||
|
||||
@@ -118,8 +118,8 @@ class GeneratePoseTarget:
|
||||
ed_x = min(tmp_ed_x + 1, img_w)
|
||||
st_y = max(tmp_st_y, 0)
|
||||
ed_y = min(tmp_ed_y + 1, img_h)
|
||||
x = np.arange(st_x, ed_x, 1, np.float32)
|
||||
y = np.arange(st_y, ed_y, 1, np.float32)
|
||||
x = np.arange(st_x, ed_x, dtype=np.float32)
|
||||
y = np.arange(st_y, ed_y, dtype=np.float32)
|
||||
|
||||
# if the keypoint not in the heatmap coordinate system
|
||||
if not (len(x) and len(y)):
|
||||
@@ -166,8 +166,8 @@ class GeneratePoseTarget:
|
||||
min_y = max(tmp_min_y, 0)
|
||||
max_y = min(tmp_max_y + 1, img_h)
|
||||
|
||||
x = np.arange(min_x, max_x, 1, np.float32)
|
||||
y = np.arange(min_y, max_y, 1, np.float32)
|
||||
x = np.arange(min_x, max_x, dtype=np.float32)
|
||||
y = np.arange(min_y, max_y, dtype=np.float32)
|
||||
|
||||
if not (len(x) and len(y)):
|
||||
continue
|
||||
@@ -324,9 +324,37 @@ class HeatmapToImage:
|
||||
heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
|
||||
return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
|
||||
|
||||
|
||||
class HeatmapReducer:
    """Collapse the channel axis of stacked heatmaps into one grayscale channel.

    Two reduction modes are accepted:

    - ``"max"``: per-pixel maximum over the channel axis, clipped to
      ``[0, 1]``, scaled by 255 and returned as ``uint8``.
    - ``"sum"``: per-pixel sum over the channel axis, scaled by 255 and
      returned as ``float32``. No clipping is applied, so values may
      exceed 255.
    """

    def __init__(self, reduction: str = "max") -> None:
        """Store the reduction mode.

        Raises:
            ValueError: if ``reduction`` is neither ``"max"`` nor ``"sum"``.
        """
        supported_modes = {"max", "sum"}
        if reduction not in supported_modes:
            raise ValueError(f"Unsupported heatmap reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, heatmaps: np.ndarray) -> np.ndarray:
        """Reduce axis 1 of ``heatmaps`` while keeping it as a singleton axis.

        heatmaps: (T, C, H, W)
        return: (T, 1, H, W)
        """
        if self.reduction != "max":
            # "sum" mode: keep the raw accumulated intensity as float32.
            channel_total = np.sum(heatmaps, axis=1, keepdims=True)
            return (channel_total * 255.0).astype(np.float32)

        # "max" mode: clip to the unit interval before quantizing to uint8.
        channel_peak = np.clip(
            np.max(heatmaps, axis=1, keepdims=True),
            0.0,
            1.0,
        )
        scaled_peak = channel_peak * 255
        return scaled_peak.astype(np.uint8)
|
||||
|
||||
class CenterAndScaleNormalizer:
|
||||
|
||||
def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pose_format="coco",
|
||||
use_conf=True,
|
||||
heatmap_image_height=128,
|
||||
target_body_height=None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
- pose_format (str): Specifies the format of the keypoints.
|
||||
@@ -334,10 +362,13 @@ class CenterAndScaleNormalizer:
|
||||
The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model.
|
||||
- use_conf (bool): Indicates whether confidence scores are used.
|
||||
- heatmap_image_height (int): Sets the height (in pixels) of the heatmap images used for normalization.
|
||||
- target_body_height (float | None): Optional normalized body height. When omitted,
|
||||
preserve the historical SkeletonGait scaling heuristic.
|
||||
"""
|
||||
self.pose_format = pose_format
|
||||
self.use_conf = use_conf
|
||||
self.heatmap_image_height = heatmap_image_height
|
||||
self.target_body_height = target_body_height
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
@@ -369,7 +400,13 @@ class CenterAndScaleNormalizer:
|
||||
# Scale-normalization
|
||||
y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
|
||||
y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
|
||||
pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis]) # [t, v, 2]
|
||||
target_body_height = (
|
||||
float(self.target_body_height)
|
||||
if self.target_body_height is not None
|
||||
else float(self.heatmap_image_height // 1.5)
|
||||
)
|
||||
body_height = np.maximum(y_max - y_min, 1e-6)
|
||||
pose_seq *= (target_body_height / body_height)[:, np.newaxis, np.newaxis] # [t, v, 2]
|
||||
|
||||
pose_seq += self.heatmap_image_height // 2
|
||||
|
||||
@@ -523,16 +560,21 @@ class HeatmapAlignment():
|
||||
heatmap_imgs: (T, 1, raw_size, raw_size)
|
||||
return (T, 1, final_img_size, final_img_size)
|
||||
"""
|
||||
heatmap_imgs = heatmap_imgs / 255.
|
||||
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs])
|
||||
return (heatmap_imgs * 255).astype('uint8')
|
||||
original_dtype = heatmap_imgs.dtype
|
||||
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
||||
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
|
||||
heatmap_imgs = heatmap_imgs * 255.0
|
||||
if np.issubdtype(original_dtype, np.integer):
|
||||
return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype)
|
||||
return heatmap_imgs.astype(original_dtype)
|
||||
|
||||
def GenerateHeatmapTransform(
|
||||
coco18tococo17_args,
|
||||
padkeypoints_args,
|
||||
norm_args,
|
||||
heatmap_generator_args,
|
||||
align_args
|
||||
align_args,
|
||||
reduction="max",
|
||||
):
|
||||
|
||||
base_transform = T.Compose([
|
||||
@@ -545,7 +587,7 @@ def GenerateHeatmapTransform(
|
||||
heatmap_generator_args["with_kp"] = False
|
||||
transform_bone = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
HeatmapToImage(),
|
||||
HeatmapReducer(reduction=reduction),
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
@@ -553,7 +595,7 @@ def GenerateHeatmapTransform(
|
||||
heatmap_generator_args["with_kp"] = True
|
||||
transform_joint = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
HeatmapToImage(),
|
||||
HeatmapReducer(reduction=reduction),
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pickle
|
||||
import sys
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
@@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace:
|
||||
_ = parser.add_argument(
|
||||
"--heatmap_cfg_path",
|
||||
type=str,
|
||||
default="configs/skeletongait/pretreatment_heatmap.yaml",
|
||||
default="configs/drf/pretreatment_heatmap_drf.yaml",
|
||||
help="Heatmap preprocessing config used to build the skeleton map branch.",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--stats_partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.",
|
||||
help=(
|
||||
"Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
|
||||
"Omit it to match the paper's dataset-level min-max normalization."
|
||||
),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
|
||||
with open(cfg_path, "r", encoding="utf-8") as stream:
|
||||
cfg = yaml.safe_load(stream)
|
||||
replaced = heatmap_prep.replace_variables(cfg, cfg)
|
||||
return dict(replaced)
|
||||
if not isinstance(replaced, dict):
|
||||
raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}")
|
||||
return cast(dict[str, Any], replaced)
|
||||
|
||||
|
||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
||||
@@ -175,6 +180,7 @@ def main() -> None:
|
||||
norm_args=heatmap_cfg["norm_args"],
|
||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||
align_args=heatmap_cfg["align_args"],
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||
|
||||
Reference in New Issue
Block a user