Refine DRF preprocessing and body-prior pipeline
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
data_cfg:
|
data_cfg:
|
||||||
dataset_name: Scoliosis1K
|
dataset_name: Scoliosis1K
|
||||||
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118
|
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
|
||||||
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
|
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||||
num_workers: 1
|
num_workers: 1
|
||||||
remove_no_gallery: false
|
remove_no_gallery: false
|
||||||
@@ -10,7 +10,7 @@ evaluator_cfg:
|
|||||||
enable_float16: true
|
enable_float16: true
|
||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 20000
|
restore_hint: 20000
|
||||||
save_name: DRF
|
save_name: DRF_paper
|
||||||
eval_func: evaluate_scoliosis
|
eval_func: evaluate_scoliosis
|
||||||
sampler:
|
sampler:
|
||||||
batch_shuffle: false
|
batch_shuffle: false
|
||||||
@@ -19,7 +19,7 @@ evaluator_cfg:
|
|||||||
frames_all_limit: 720
|
frames_all_limit: 720
|
||||||
metric: euc
|
metric: euc
|
||||||
transform:
|
transform:
|
||||||
- type: BaseSilCuttingTransform
|
- type: BaseSilTransform
|
||||||
- type: NoOperation
|
- type: NoOperation
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
@@ -90,7 +90,7 @@ trainer_cfg:
|
|||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
save_iter: 20000
|
save_iter: 20000
|
||||||
save_name: DRF
|
save_name: DRF_paper
|
||||||
sync_BN: true
|
sync_BN: true
|
||||||
total_iter: 20000
|
total_iter: 20000
|
||||||
sampler:
|
sampler:
|
||||||
@@ -102,5 +102,5 @@ trainer_cfg:
|
|||||||
sample_type: fixed_unordered
|
sample_type: fixed_unordered
|
||||||
type: TripletSampler
|
type: TripletSampler
|
||||||
transform:
|
transform:
|
||||||
- type: BaseSilCuttingTransform
|
- type: BaseSilTransform
|
||||||
- type: NoOperation
|
- type: NoOperation
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
data_cfg:
|
||||||
|
dataset_name: Scoliosis1K
|
||||||
|
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
|
||||||
|
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||||
|
num_workers: 1
|
||||||
|
remove_no_gallery: false
|
||||||
|
test_dataset_name: Scoliosis1K
|
||||||
|
|
||||||
|
evaluator_cfg:
|
||||||
|
enable_float16: true
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 20000
|
||||||
|
save_name: DRF_paper
|
||||||
|
eval_func: evaluate_scoliosis
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: false
|
||||||
|
batch_size: 1
|
||||||
|
sample_type: all_ordered
|
||||||
|
frames_all_limit: 720
|
||||||
|
metric: euc
|
||||||
|
transform:
|
||||||
|
- type: BaseSilTransform
|
||||||
|
- type: NoOperation
|
||||||
|
|
||||||
|
loss_cfg:
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
margin: 0.2
|
||||||
|
type: TripletLoss
|
||||||
|
log_prefix: triplet
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
scale: 16
|
||||||
|
type: CrossEntropyLoss
|
||||||
|
log_prefix: softmax
|
||||||
|
log_accuracy: true
|
||||||
|
|
||||||
|
model_cfg:
|
||||||
|
model: DRF
|
||||||
|
num_pairs: 8
|
||||||
|
num_metrics: 3
|
||||||
|
backbone_cfg:
|
||||||
|
type: ResNet9
|
||||||
|
block: BasicBlock
|
||||||
|
in_channel: 2
|
||||||
|
channels:
|
||||||
|
- 64
|
||||||
|
- 128
|
||||||
|
- 256
|
||||||
|
- 512
|
||||||
|
layers:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
strides:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
maxpool: false
|
||||||
|
SeparateFCs:
|
||||||
|
in_channels: 512
|
||||||
|
out_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
SeparateBNNecks:
|
||||||
|
class_num: 3
|
||||||
|
in_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
bin_num:
|
||||||
|
- 16
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
lr: 0.1
|
||||||
|
momentum: 0.9
|
||||||
|
solver: SGD
|
||||||
|
weight_decay: 0.0005
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
gamma: 0.1
|
||||||
|
milestones:
|
||||||
|
- 10000
|
||||||
|
- 14000
|
||||||
|
- 18000
|
||||||
|
scheduler: MultiStepLR
|
||||||
|
|
||||||
|
trainer_cfg:
|
||||||
|
enable_float16: true
|
||||||
|
fix_BN: false
|
||||||
|
with_test: false
|
||||||
|
log_iter: 100
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 0
|
||||||
|
save_iter: 20000
|
||||||
|
save_name: DRF_paper
|
||||||
|
sync_BN: true
|
||||||
|
total_iter: 20000
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: true
|
||||||
|
batch_size:
|
||||||
|
- 8
|
||||||
|
- 8
|
||||||
|
frames_num_fixed: 30
|
||||||
|
sample_type: fixed_unordered
|
||||||
|
type: TripletSampler
|
||||||
|
transform:
|
||||||
|
- type: BaseSilTransform
|
||||||
|
- type: NoOperation
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
data_cfg:
|
data_cfg:
|
||||||
dataset_name: Scoliosis1K
|
dataset_name: Scoliosis1K
|
||||||
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118
|
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
|
||||||
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
|
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||||
num_workers: 1
|
num_workers: 1
|
||||||
remove_no_gallery: false
|
remove_no_gallery: false
|
||||||
@@ -14,7 +14,7 @@ evaluator_cfg:
|
|||||||
enable_float16: true
|
enable_float16: true
|
||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
save_name: DRF_smoke
|
save_name: DRF_paper_smoke
|
||||||
eval_func: evaluate_scoliosis
|
eval_func: evaluate_scoliosis
|
||||||
sampler:
|
sampler:
|
||||||
batch_shuffle: false
|
batch_shuffle: false
|
||||||
@@ -23,7 +23,7 @@ evaluator_cfg:
|
|||||||
frames_all_limit: 720
|
frames_all_limit: 720
|
||||||
metric: euc
|
metric: euc
|
||||||
transform:
|
transform:
|
||||||
- type: BaseSilCuttingTransform
|
- type: BaseSilTransform
|
||||||
- type: NoOperation
|
- type: NoOperation
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
@@ -97,7 +97,7 @@ trainer_cfg:
|
|||||||
scheduler_reset: false
|
scheduler_reset: false
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
save_iter: 1
|
save_iter: 1
|
||||||
save_name: DRF_smoke
|
save_name: DRF_paper_smoke
|
||||||
sync_BN: true
|
sync_BN: true
|
||||||
total_iter: 1
|
total_iter: 1
|
||||||
sampler:
|
sampler:
|
||||||
@@ -109,5 +109,5 @@ trainer_cfg:
|
|||||||
sample_type: fixed_unordered
|
sample_type: fixed_unordered
|
||||||
type: TripletSampler
|
type: TripletSampler
|
||||||
transform:
|
transform:
|
||||||
- type: BaseSilCuttingTransform
|
- type: BaseSilTransform
|
||||||
- type: NoOperation
|
- type: NoOperation
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
coco18tococo17_args:
|
||||||
|
transfer_to_coco17: False
|
||||||
|
|
||||||
|
padkeypoints_args:
|
||||||
|
pad_method: knn
|
||||||
|
use_conf: True
|
||||||
|
|
||||||
|
norm_args:
|
||||||
|
pose_format: coco
|
||||||
|
use_conf: ${padkeypoints_args.use_conf}
|
||||||
|
heatmap_image_height: 128
|
||||||
|
target_body_height: ${norm_args.heatmap_image_height}
|
||||||
|
|
||||||
|
heatmap_generator_args:
|
||||||
|
sigma: 8.0
|
||||||
|
use_score: ${padkeypoints_args.use_conf}
|
||||||
|
img_h: ${norm_args.heatmap_image_height}
|
||||||
|
img_w: ${norm_args.heatmap_image_height}
|
||||||
|
with_limb: null
|
||||||
|
with_kp: null
|
||||||
|
|
||||||
|
align_args:
|
||||||
|
align: True
|
||||||
|
final_img_size: 64
|
||||||
|
offset: 0
|
||||||
|
heatmap_image_size: ${norm_args.heatmap_image_height}
|
||||||
@@ -99,14 +99,22 @@ The PAV pass is implemented from the paper:
|
|||||||
4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs
|
4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs
|
||||||
5. apply IQR filtering per metric
|
5. apply IQR filtering per metric
|
||||||
6. average over time
|
6. average over time
|
||||||
7. min-max normalize across the dataset, or across `TRAIN_SET` when `--stats_partition` is provided
|
7. min-max normalize across the full dataset (paper default), or across `TRAIN_SET` when `--stats_partition` is provided as an anti-leakage variant
|
||||||
|
|
||||||
Run:
|
Run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run python datasets/pretreatment_scoliosis_drf.py \
|
uv run python datasets/pretreatment_scoliosis_drf.py \
|
||||||
--pose_data_path=<path_to_pose_pkl> \
|
--pose_data_path=<path_to_pose_pkl> \
|
||||||
--output_path=<path_to_drf_pkl> \
|
--output_path=<path_to_drf_pkl>
|
||||||
|
```
|
||||||
|
|
||||||
|
To reproduce the paper defaults more closely, the script now uses
|
||||||
|
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables
|
||||||
|
summed two-channel skeleton maps and a literal 128-pixel height normalization.
|
||||||
|
If you explicitly want train-only PAV min-max statistics, add:
|
||||||
|
|
||||||
|
```bash
|
||||||
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
|
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -118,8 +118,8 @@ class GeneratePoseTarget:
|
|||||||
ed_x = min(tmp_ed_x + 1, img_w)
|
ed_x = min(tmp_ed_x + 1, img_w)
|
||||||
st_y = max(tmp_st_y, 0)
|
st_y = max(tmp_st_y, 0)
|
||||||
ed_y = min(tmp_ed_y + 1, img_h)
|
ed_y = min(tmp_ed_y + 1, img_h)
|
||||||
x = np.arange(st_x, ed_x, 1, np.float32)
|
x = np.arange(st_x, ed_x, dtype=np.float32)
|
||||||
y = np.arange(st_y, ed_y, 1, np.float32)
|
y = np.arange(st_y, ed_y, dtype=np.float32)
|
||||||
|
|
||||||
# if the keypoint not in the heatmap coordinate system
|
# if the keypoint not in the heatmap coordinate system
|
||||||
if not (len(x) and len(y)):
|
if not (len(x) and len(y)):
|
||||||
@@ -166,8 +166,8 @@ class GeneratePoseTarget:
|
|||||||
min_y = max(tmp_min_y, 0)
|
min_y = max(tmp_min_y, 0)
|
||||||
max_y = min(tmp_max_y + 1, img_h)
|
max_y = min(tmp_max_y + 1, img_h)
|
||||||
|
|
||||||
x = np.arange(min_x, max_x, 1, np.float32)
|
x = np.arange(min_x, max_x, dtype=np.float32)
|
||||||
y = np.arange(min_y, max_y, 1, np.float32)
|
y = np.arange(min_y, max_y, dtype=np.float32)
|
||||||
|
|
||||||
if not (len(x) and len(y)):
|
if not (len(x) and len(y)):
|
||||||
continue
|
continue
|
||||||
@@ -324,9 +324,37 @@ class HeatmapToImage:
|
|||||||
heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
|
heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
|
||||||
return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
|
return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
|
||||||
|
|
||||||
|
|
||||||
|
class HeatmapReducer:
    """Collapse the channel axis of stacked heatmaps into one channel.

    Two modes are supported:

    * ``"max"`` - channel-wise maximum, clipped to ``[0, 1]``, scaled by 255
      and returned as ``uint8``.
    * ``"sum"`` - channel-wise sum scaled by 255 and returned as ``float32``
      (no clipping is applied in this mode).
    """

    def __init__(self, reduction: str = "max") -> None:
        """Store the reduction mode.

        Raises:
            ValueError: if ``reduction`` is neither ``"max"`` nor ``"sum"``.
        """
        if reduction in {"max", "sum"}:
            self.reduction = reduction
            return
        raise ValueError(f"Unsupported heatmap reduction: {reduction}")

    def __call__(self, heatmaps: np.ndarray) -> np.ndarray:
        """Reduce ``heatmaps`` along axis 1, keeping that axis with size 1.

        heatmaps: (T, C, H, W)
        return: (T, 1, H, W)
        """
        if self.reduction != "max":
            # "sum" mode: values may exceed 255, hence the float32 output.
            summed = np.sum(heatmaps, axis=1, keepdims=True)
            return (summed * 255.0).astype(np.float32)

        peak = np.clip(np.max(heatmaps, axis=1, keepdims=True), 0.0, 1.0)
        return (peak * 255).astype(np.uint8)
|
||||||
|
|
||||||
class CenterAndScaleNormalizer:
|
class CenterAndScaleNormalizer:
|
||||||
|
|
||||||
def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
pose_format="coco",
|
||||||
|
use_conf=True,
|
||||||
|
heatmap_image_height=128,
|
||||||
|
target_body_height=None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
- pose_format (str): Specifies the format of the keypoints.
|
- pose_format (str): Specifies the format of the keypoints.
|
||||||
@@ -334,10 +362,13 @@ class CenterAndScaleNormalizer:
|
|||||||
The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model.
|
The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model.
|
||||||
- use_conf (bool): Indicates whether confidence scores.
|
- use_conf (bool): Whether keypoint confidence scores are used.
|
||||||
- heatmap_image_height (int): Sets the height (in pixels) for the heatmap images that will be normlization.
|
- heatmap_image_height (int): Height (in pixels) of the heatmap images that the poses are normalized to.
|
||||||
|
- target_body_height (float | None): Optional normalized body height. When omitted,
|
||||||
|
preserve the historical SkeletonGait scaling heuristic.
|
||||||
"""
|
"""
|
||||||
self.pose_format = pose_format
|
self.pose_format = pose_format
|
||||||
self.use_conf = use_conf
|
self.use_conf = use_conf
|
||||||
self.heatmap_image_height = heatmap_image_height
|
self.heatmap_image_height = heatmap_image_height
|
||||||
|
self.target_body_height = target_body_height
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
"""
|
"""
|
||||||
@@ -369,7 +400,13 @@ class CenterAndScaleNormalizer:
|
|||||||
# Scale-normalization
|
# Scale-normalization
|
||||||
y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
|
y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
|
||||||
y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
|
y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
|
||||||
pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis]) # [t, v, 2]
|
target_body_height = (
|
||||||
|
float(self.target_body_height)
|
||||||
|
if self.target_body_height is not None
|
||||||
|
else float(self.heatmap_image_height // 1.5)
|
||||||
|
)
|
||||||
|
body_height = np.maximum(y_max - y_min, 1e-6)
|
||||||
|
pose_seq *= (target_body_height / body_height)[:, np.newaxis, np.newaxis] # [t, v, 2]
|
||||||
|
|
||||||
pose_seq += self.heatmap_image_height // 2
|
pose_seq += self.heatmap_image_height // 2
|
||||||
|
|
||||||
@@ -523,16 +560,21 @@ class HeatmapAlignment():
|
|||||||
heatmap_imgs: (T, 1, raw_size, raw_size)
|
heatmap_imgs: (T, 1, raw_size, raw_size)
|
||||||
return (T, 1, final_img_size, final_img_size)
|
return (T, 1, final_img_size, final_img_size)
|
||||||
"""
|
"""
|
||||||
heatmap_imgs = heatmap_imgs / 255.
|
original_dtype = heatmap_imgs.dtype
|
||||||
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs])
|
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
||||||
return (heatmap_imgs * 255).astype('uint8')
|
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
|
||||||
|
heatmap_imgs = heatmap_imgs * 255.0
|
||||||
|
if np.issubdtype(original_dtype, np.integer):
|
||||||
|
return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype)
|
||||||
|
return heatmap_imgs.astype(original_dtype)
|
||||||
|
|
||||||
def GenerateHeatmapTransform(
|
def GenerateHeatmapTransform(
|
||||||
coco18tococo17_args,
|
coco18tococo17_args,
|
||||||
padkeypoints_args,
|
padkeypoints_args,
|
||||||
norm_args,
|
norm_args,
|
||||||
heatmap_generator_args,
|
heatmap_generator_args,
|
||||||
align_args
|
align_args,
|
||||||
|
reduction="max",
|
||||||
):
|
):
|
||||||
|
|
||||||
base_transform = T.Compose([
|
base_transform = T.Compose([
|
||||||
@@ -545,7 +587,7 @@ def GenerateHeatmapTransform(
|
|||||||
heatmap_generator_args["with_kp"] = False
|
heatmap_generator_args["with_kp"] = False
|
||||||
transform_bone = T.Compose([
|
transform_bone = T.Compose([
|
||||||
GeneratePoseTarget(**heatmap_generator_args),
|
GeneratePoseTarget(**heatmap_generator_args),
|
||||||
HeatmapToImage(),
|
HeatmapReducer(reduction=reduction),
|
||||||
HeatmapAlignment(**align_args)
|
HeatmapAlignment(**align_args)
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -553,7 +595,7 @@ def GenerateHeatmapTransform(
|
|||||||
heatmap_generator_args["with_kp"] = True
|
heatmap_generator_args["with_kp"] = True
|
||||||
transform_joint = T.Compose([
|
transform_joint = T.Compose([
|
||||||
GeneratePoseTarget(**heatmap_generator_args),
|
GeneratePoseTarget(**heatmap_generator_args),
|
||||||
HeatmapToImage(),
|
HeatmapReducer(reduction=reduction),
|
||||||
HeatmapAlignment(**align_args)
|
HeatmapAlignment(**align_args)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pickle
|
|||||||
import sys
|
import sys
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
@@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace:
|
|||||||
_ = parser.add_argument(
|
_ = parser.add_argument(
|
||||||
"--heatmap_cfg_path",
|
"--heatmap_cfg_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="configs/skeletongait/pretreatment_heatmap.yaml",
|
default="configs/drf/pretreatment_heatmap_drf.yaml",
|
||||||
help="Heatmap preprocessing config used to build the skeleton map branch.",
|
help="Heatmap preprocessing config used to build the skeleton map branch.",
|
||||||
)
|
)
|
||||||
_ = parser.add_argument(
|
_ = parser.add_argument(
|
||||||
"--stats_partition",
|
"--stats_partition",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.",
|
help=(
|
||||||
|
"Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
|
||||||
|
"Omit it to match the paper's dataset-level min-max normalization."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
|
|||||||
with open(cfg_path, "r", encoding="utf-8") as stream:
|
with open(cfg_path, "r", encoding="utf-8") as stream:
|
||||||
cfg = yaml.safe_load(stream)
|
cfg = yaml.safe_load(stream)
|
||||||
replaced = heatmap_prep.replace_variables(cfg, cfg)
|
replaced = heatmap_prep.replace_variables(cfg, cfg)
|
||||||
return dict(replaced)
|
if not isinstance(replaced, dict):
|
||||||
|
raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}")
|
||||||
|
return cast(dict[str, Any], replaced)
|
||||||
|
|
||||||
|
|
||||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
||||||
@@ -175,6 +180,7 @@ def main() -> None:
|
|||||||
norm_args=heatmap_cfg["norm_args"],
|
norm_args=heatmap_cfg["norm_args"],
|
||||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||||
align_args=heatmap_cfg["align_args"],
|
align_args=heatmap_cfg["align_args"],
|
||||||
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Callable, cast
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float, Int
|
||||||
|
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from opengait.utils.common import list2var, np2var
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelBody(BaseModel):
    """BaseModel flavour whose last input stream is a sequence-level body prior.

    All input streams except the last are treated as visual inputs; the last
    one is reduced to one feature tensor per sequence via
    ``aggregate_body_features`` and returned as an extra trailing element.
    """

    def inputs_pretreament(
        self,
        inputs: tuple[list[np.ndarray], list[int], list[str], list[str], np.ndarray | None],
    ) -> Any:
        """Transform a raw batch and split off the body-prior stream.

        Parameters:
            inputs: ``(seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch)``.
                ``seqs_batch`` holds one entry per input stream and must line
                up one-to-one with the active transform list.

        Returns:
            ``(ipts, labs, typs_batch, vies_batch, seqL, body_features)`` where
            ``ipts`` are the visual streams (cut to ``seqL.sum()`` along dim 1
            when ``seqL`` is given) and ``body_features`` is the aggregated
            last stream.

        Raises:
            ValueError: if the stream and transform counts differ, or fewer
                than two streams are supplied.
        """
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
        transforms = cast(
            list[Callable[[Any], Any]],
            self.trainer_trfs if self.training else self.evaluator_trfs,
        )
        stream_count, transform_count = len(seqs_batch), len(transforms)
        if stream_count != transform_count:
            raise ValueError(
                "The number of types of input data and transform should be same. "
                f"But got {stream_count} and {transform_count}"
            )
        if stream_count < 2:
            raise ValueError("BaseModelBody expects one visual input and one body-prior input.")

        track_grad = bool(self.training)

        def to_float_var(transform: Callable[[Any], Any], frames: Any) -> Any:
            # Apply the per-item transform, stack, then hand over to np2var
            # (presumably a numpy -> torch conversion; defined in opengait.utils.common).
            stacked = np.asarray([transform(frame) for frame in frames])
            return np2var(stacked, requires_grad=track_grad).float()

        visual_seqs = [
            to_float_var(transform, frames)
            for transform, frames in zip(transforms[:-1], seqs_batch[:-1])
        ]
        body_seq = to_float_var(transforms[-1], seqs_batch[-1])

        labs = list2var(labs_batch).long()
        seqL = None if seqL_batch is None else np2var(seqL_batch).int()

        body_features = aggregate_body_features(body_seq, seqL)

        if seqL is None:
            ipts = visual_seqs
        else:
            # Keep only the first sum(seqL) positions along dim 1 of every visual stream.
            valid_frames = int(seqL.sum().data.cpu().numpy())
            ipts = [stream[:, :valid_frames] for stream in visual_seqs]
        return ipts, labs, typs_batch, vies_batch, seqL, body_features
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_body_features(
    sequence_features: Float[torch.Tensor, "..."],
    seqL: Int[torch.Tensor, "1 batch"] | None,
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Collapse a sampled body-prior sequence back to one vector per sequence.

    Parameters:
        sequence_features: body-prior tensor. Without ``seqL`` it is averaged
            over dim 1 and needs at least 3 dims. With ``seqL`` it is treated
            as packed: a leading dim of size 1 followed by the concatenated
            frames of all sequences, so it needs at least 4 dims.
        seqL: per-sequence frame counts; only row 0 is read. ``None`` selects
            the fixed-length path.

    Returns:
        One averaged feature tensor per sequence, stacked along dim 0.

    Raises:
        ValueError: if the tensor has too few dims, if a packed tensor's
            leading dim is not 1, or if the lengths in ``seqL`` add up to more
            frames than the packed tensor holds.
    """
    if seqL is None:
        if sequence_features.ndim < 3:
            raise ValueError(f"Expected body prior with >=3 dims, got shape {tuple(sequence_features.shape)}")
        return sequence_features.mean(dim=1)

    if sequence_features.ndim < 4:
        raise ValueError(f"Expected packed body prior with >=4 dims, got shape {tuple(sequence_features.shape)}")
    # squeeze(0) is a silent no-op when dim 0 is not 1; the slices below would
    # then cut along the wrong axis, so reject that layout explicitly.
    if sequence_features.shape[0] != 1:
        raise ValueError(
            "Expected packed body prior with a leading dimension of 1, "
            f"got shape {tuple(sequence_features.shape)}"
        )

    lengths = [int(length) for length in seqL[0].tolist()]
    flattened = sequence_features.squeeze(0)
    # Out-of-range slices are truncated (or empty -> NaN mean) without any
    # error, so make sure every requested frame actually exists.
    if sum(lengths) > flattened.shape[0]:
        raise ValueError(
            f"seqL requests {sum(lengths)} frames but packed body prior only has {flattened.shape[0]}"
        )

    aggregated: list[torch.Tensor] = []
    start = 0
    for length in lengths:
        end = start + length
        aggregated.append(flattened[start:end].mean(dim=0))
        start = end
    return torch.stack(aggregated, dim=0)
|
||||||
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
|
|||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from ..base_model import BaseModel
|
from ..base_model_body import BaseModelBody
|
||||||
from ..modules import (
|
from ..modules import (
|
||||||
HorizontalPoolingPyramid,
|
HorizontalPoolingPyramid,
|
||||||
PackSequenceWrapper,
|
PackSequenceWrapper,
|
||||||
@@ -18,7 +18,7 @@ from ..modules import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DRF(BaseModel):
|
class DRF(BaseModelBody):
|
||||||
"""Dual Representation Framework from arXiv:2509.00872v1."""
|
"""Dual Representation Framework from arXiv:2509.00872v1."""
|
||||||
|
|
||||||
def build_network(self, model_cfg: dict[str, Any]) -> None:
|
def build_network(self, model_cfg: dict[str, Any]) -> None:
|
||||||
@@ -43,9 +43,10 @@ class DRF(BaseModel):
|
|||||||
list[str],
|
list[str],
|
||||||
list[str],
|
list[str],
|
||||||
Int[torch.Tensor, "1 batch"] | None,
|
Int[torch.Tensor, "1 batch"] | None,
|
||||||
|
Float[torch.Tensor, "batch pairs metrics"],
|
||||||
],
|
],
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
ipts, pids, labels, _, seqL = inputs
|
ipts, pids, labels, _, seqL, key_features = inputs
|
||||||
label_ids = torch.as_tensor(
|
label_ids = torch.as_tensor(
|
||||||
[LABEL_MAP[str(label).lower()] for label in labels],
|
[LABEL_MAP[str(label).lower()] for label in labels],
|
||||||
device=pids.device,
|
device=pids.device,
|
||||||
@@ -58,15 +59,12 @@ class DRF(BaseModel):
|
|||||||
else:
|
else:
|
||||||
heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")
|
heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")
|
||||||
|
|
||||||
pav_seq = ipts[1]
|
|
||||||
pav = aggregate_sequence_features(pav_seq, seqL)
|
|
||||||
|
|
||||||
outs = self.Backbone(heatmaps)
|
outs = self.Backbone(heatmaps)
|
||||||
outs = self.TP(outs, seqL, options={"dim": 2})[0]
|
outs = self.TP(outs, seqL, options={"dim": 2})[0]
|
||||||
|
|
||||||
feat = self.HPP(outs)
|
feat = self.HPP(outs)
|
||||||
embed_1 = self.FCs(feat)
|
embed_1 = self.FCs(feat)
|
||||||
embed_1 = self.PGA(embed_1, pav)
|
embed_1 = self.PGA(embed_1, key_features)
|
||||||
|
|
||||||
embed_2, logits = self.BNNecks(embed_1)
|
embed_2, logits = self.BNNecks(embed_1)
|
||||||
del embed_2
|
del embed_2
|
||||||
@@ -120,24 +118,6 @@ class PAVGuidedAttention(nn.Module):
|
|||||||
return embeddings * channel_att * spatial_att
|
return embeddings * channel_att * spatial_att
|
||||||
|
|
||||||
|
|
||||||
def aggregate_sequence_features(
|
|
||||||
sequence_features: Float[torch.Tensor, "batch seq pairs metrics"],
|
|
||||||
seqL: Int[torch.Tensor, "1 batch"] | None,
|
|
||||||
) -> Float[torch.Tensor, "batch pairs metrics"]:
|
|
||||||
if seqL is None:
|
|
||||||
return sequence_features.mean(dim=1)
|
|
||||||
|
|
||||||
lengths = seqL[0].tolist()
|
|
||||||
flattened = sequence_features.squeeze(0)
|
|
||||||
aggregated = []
|
|
||||||
start = 0
|
|
||||||
for length in lengths:
|
|
||||||
end = start + int(length)
|
|
||||||
aggregated.append(flattened[start:end].mean(dim=0))
|
|
||||||
start = end
|
|
||||||
return torch.stack(aggregated, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
LABEL_MAP: dict[str, int] = {
|
LABEL_MAP: dict[str, int] = {
|
||||||
"negative": 0,
|
"negative": 0,
|
||||||
"neutral": 1,
|
"neutral": 1,
|
||||||
|
|||||||
@@ -0,0 +1,145 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
# Substrings that mark a log tail as failed. detect_error() lower-cases the
# log text before the substring test, so every entry here must be lower case.
# The first entry (in this order) found in the text is the one reported.
ERROR_PATTERNS: Final[tuple[str, ...]] = (
    "traceback",
    "runtimeerror",
    "error:",
    "exception",
    "failed",
    "segmentation fault",
    "killed",
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class JobSpec:
    """Immutable description of one monitored job."""

    # Label used in status-log and stderr messages.
    name: str
    # Process id whose /proc entry is polled by pid_alive().
    pid: int
    # Log file whose tail is scanned for error markers.
    log_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse the monitor's command line.

    Six options are mandatory (two PIDs as ``int``, four paths as ``Path``);
    ``--poll-seconds`` is optional and defaults to 30.0.
    """
    parser = argparse.ArgumentParser(description="Monitor long-running DRF preprocess/train jobs.")
    # Registered in this order so the generated usage/help text stays the same.
    required_options = (
        ("--preprocess-pid", int),
        ("--preprocess-log", Path),
        ("--launcher-pid", int),
        ("--launcher-log", Path),
        ("--sentinel-path", Path),
        ("--status-log", Path),
    )
    for flag, converter in required_options:
        parser.add_argument(flag, type=converter, required=True)
    parser.add_argument("--poll-seconds", type=float, default=30.0)
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def pid_alive(pid: int) -> bool:
    """Return True when a ``/proc/<pid>`` entry exists for ``pid``."""
    proc_entry = Path("/proc/" + format(pid))
    return proc_entry.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def read_tail(path: Path, limit: int = 8192) -> str:
    """Return up to the last ``limit`` bytes of ``path`` decoded as UTF-8.

    Parameters:
        path: file to read.
        limit: maximum number of trailing bytes to read.

    Returns:
        The decoded tail, or ``""`` when the file does not exist. Bytes that
        do not decode (for example a multi-byte character cut by the seek)
        become U+FFFD instead of raising.
    """
    # Open first and handle the miss, rather than exists()-then-open(): the
    # log can be removed or rotated between the two calls, which would
    # otherwise crash the monitor with FileNotFoundError.
    try:
        with path.open("rb") as handle:
            size = handle.seek(0, os.SEEK_END)
            handle.seek(max(size - limit, 0), os.SEEK_SET)
            data = handle.read()
    except FileNotFoundError:
        return ""
    return data.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
def detect_error(log_text: str) -> str | None:
    """Return the first entry of ``ERROR_PATTERNS`` found in ``log_text``.

    Matching is a case-insensitive substring test (the text is lower-cased
    first). Returns ``None`` when no pattern occurs.
    """
    haystack = log_text.lower()
    return next((marker for marker in ERROR_PATTERNS if marker in haystack), None)
|
||||||
|
|
||||||
|
|
||||||
|
def append_status(path: Path, line: str) -> None:
    """Append ``line`` to the status log, prefixed with a local timestamp.

    Missing parent directories are created. The timestamp is
    ``datetime.now()`` in ISO format with second resolution.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().isoformat(timespec="seconds")
    with path.open("a", encoding="utf-8") as handle:
        handle.write(f"{stamp} {line}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def monitor_job(job: JobSpec, status_log: Path) -> str | None:
    """Scan the tail of ``job``'s log for an error marker.

    When a marker is found, an ``[alert]`` line is appended to ``status_log``
    and a short description is returned; otherwise nothing is written and
    ``None`` is returned.
    """
    matched = detect_error(read_tail(job.log_path))
    if matched is None:
        return None
    append_status(status_log, f"[alert] {job.name}: detected `{matched}` in {job.log_path}")
    return f"{job.name} log shows `{matched}`"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Poll the preprocess and launcher jobs until one fails or the launcher exits.

    Each iteration checks both PIDs, scans both log tails for error markers,
    records an ``[ok]`` line in the status log and then sleeps for
    ``--poll-seconds``.

    Returns:
        1 when a log tail contains an error marker, when preprocess exits
        before the sentinel file exists, or when the launcher exits with
        neither a ``[start]`` marker in its log tail nor the sentinel present;
        0 when the launcher exits otherwise.
    """
    args = parse_args()
    preprocess = JobSpec("preprocess", args.preprocess_pid, args.preprocess_log)
    launcher = JobSpec("launcher", args.launcher_pid, args.launcher_log)

    append_status(
        args.status_log,
        (
            "[start] monitoring "
            f"preprocess_pid={preprocess.pid} launcher_pid={launcher.pid} "
            f"sentinel={args.sentinel_path}"
        ),
    )

    # Exit conditions below only fire for a job that was observed alive at
    # least once.
    # NOTE(review): a job that had already exited before the first poll is
    # never reported as exited; if that applies to the launcher, the loop
    # does not terminate on its own - confirm this is intended.
    preprocess_seen_alive = False
    launcher_seen_alive = False

    while True:
        preprocess_alive = pid_alive(preprocess.pid)
        launcher_alive = pid_alive(launcher.pid)
        preprocess_seen_alive = preprocess_seen_alive or preprocess_alive
        launcher_seen_alive = launcher_seen_alive or launcher_alive

        # Error markers take priority over every liveness check.
        preprocess_error = monitor_job(preprocess, args.status_log)
        if preprocess_error is not None:
            print(preprocess_error, file=sys.stderr)
            return 1

        launcher_error = monitor_job(launcher, args.status_log)
        if launcher_error is not None:
            print(launcher_error, file=sys.stderr)
            return 1

        sentinel_ready = args.sentinel_path.exists()
        append_status(
            args.status_log,
            (
                "[ok] "
                f"preprocess_alive={preprocess_alive} "
                f"launcher_alive={launcher_alive} "
                f"sentinel_ready={sentinel_ready}"
            ),
        )

        if preprocess_seen_alive and not preprocess_alive and not sentinel_ready:
            append_status(args.status_log, "[alert] preprocess exited before sentinel was written")
            print("preprocess exited before sentinel was written", file=sys.stderr)
            return 1

        # The launcher log is read a second time here (monitor_job already
        # read it above); only the last read_tail() window is searched, so a
        # "[start]" line that has scrolled out of the tail is not seen.
        launcher_tail = read_tail(launcher.log_path)
        train_started = "[start]" in launcher_tail

        if launcher_seen_alive and not launcher_alive:
            if not train_started and not sentinel_ready:
                append_status(args.status_log, "[alert] launcher exited before training started")
                print("launcher exited before training started", file=sys.stderr)
                return 1
            append_status(args.status_log, "[done] launcher exited; monitoring complete")
            return 0

        time.sleep(args.poll_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
Reference in New Issue
Block a user