Refine DRF preprocessing and body-prior pipeline

This commit is contained in:
2026-03-08 04:04:15 +08:00
parent fddbf6eeda
commit bbb41e8dd9
10 changed files with 448 additions and 53 deletions
+5 -5
View File
@@ -1,6 +1,6 @@
data_cfg: data_cfg:
dataset_name: Scoliosis1K dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118 dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1 num_workers: 1
remove_no_gallery: false remove_no_gallery: false
@@ -10,7 +10,7 @@ evaluator_cfg:
enable_float16: true enable_float16: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 20000 restore_hint: 20000
save_name: DRF save_name: DRF_paper
eval_func: evaluate_scoliosis eval_func: evaluate_scoliosis
sampler: sampler:
batch_shuffle: false batch_shuffle: false
@@ -19,7 +19,7 @@ evaluator_cfg:
frames_all_limit: 720 frames_all_limit: 720
metric: euc metric: euc
transform: transform:
- type: BaseSilCuttingTransform - type: BaseSilTransform
- type: NoOperation - type: NoOperation
loss_cfg: loss_cfg:
@@ -90,7 +90,7 @@ trainer_cfg:
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 20000 save_iter: 20000
save_name: DRF save_name: DRF_paper
sync_BN: true sync_BN: true
total_iter: 20000 total_iter: 20000
sampler: sampler:
@@ -102,5 +102,5 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- type: BaseSilCuttingTransform - type: BaseSilTransform
- type: NoOperation - type: NoOperation
+106
View File
@@ -0,0 +1,106 @@
# DRF (Dual Representation Framework) train/eval config for Scoliosis1K,
# paper-faithful preprocessing variant ("-paper" pkl root, BaseSilTransform).
data_cfg:
  dataset_name: Scoliosis1K
  dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
  dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
  num_workers: 1
  remove_no_gallery: false
  test_dataset_name: Scoliosis1K

evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: true
  restore_hint: 20000  # evaluate the final checkpoint (matches total_iter below)
  save_name: DRF_paper
  eval_func: evaluate_scoliosis
  sampler:
    batch_shuffle: false
    batch_size: 1
    sample_type: all_ordered  # keep all frames, in order, at eval time
    frames_all_limit: 720
  metric: euc
  transform:
    - type: BaseSilTransform  # no cutting — paper-style full skeleton maps
    - type: NoOperation       # body-prior branch is passed through untouched

loss_cfg:
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true

model_cfg:
  model: DRF
  num_pairs: 8    # symmetric joint pairs used by the PAV body prior
  num_metrics: 3  # vertical / midline / angular deviation per pair
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    in_channel: 2  # two-channel skeleton maps (joint + limb branches)
    channels:
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    class_num: 3  # three classes — matches LABEL_MAP in the DRF model
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16

optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005

scheduler_cfg:
  gamma: 0.1
  milestones:
    - 10000
    - 14000
    - 18000
  scheduler: MultiStepLR

trainer_cfg:
  enable_float16: true
  fix_BN: false
  with_test: false
  log_iter: 100
  restore_ckpt_strict: true
  restore_hint: 0
  save_iter: 20000
  save_name: DRF_paper
  sync_BN: true
  total_iter: 20000
  sampler:
    batch_shuffle: true
    batch_size:  # (P, K) for the triplet sampler — presumably 8 ids x 8 seqs; TODO confirm
      - 8
      - 8
    frames_num_fixed: 30
    sample_type: fixed_unordered
    type: TripletSampler
  transform:
    - type: BaseSilTransform
    - type: NoOperation
+5 -5
View File
@@ -1,6 +1,6 @@
data_cfg: data_cfg:
dataset_name: Scoliosis1K dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118 dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1 num_workers: 1
remove_no_gallery: false remove_no_gallery: false
@@ -14,7 +14,7 @@ evaluator_cfg:
enable_float16: true enable_float16: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_name: DRF_smoke save_name: DRF_paper_smoke
eval_func: evaluate_scoliosis eval_func: evaluate_scoliosis
sampler: sampler:
batch_shuffle: false batch_shuffle: false
@@ -23,7 +23,7 @@ evaluator_cfg:
frames_all_limit: 720 frames_all_limit: 720
metric: euc metric: euc
transform: transform:
- type: BaseSilCuttingTransform - type: BaseSilTransform
- type: NoOperation - type: NoOperation
loss_cfg: loss_cfg:
@@ -97,7 +97,7 @@ trainer_cfg:
scheduler_reset: false scheduler_reset: false
restore_hint: 0 restore_hint: 0
save_iter: 1 save_iter: 1
save_name: DRF_smoke save_name: DRF_paper_smoke
sync_BN: true sync_BN: true
total_iter: 1 total_iter: 1
sampler: sampler:
@@ -109,5 +109,5 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- type: BaseSilCuttingTransform - type: BaseSilTransform
- type: NoOperation - type: NoOperation
+26
View File
@@ -0,0 +1,26 @@
# Heatmap preprocessing config for the DRF paper pipeline.
# ${...} placeholders are resolved by replace_variables() before use.
coco18tococo17_args:
  transfer_to_coco17: False  # keep the native keypoint layout
padkeypoints_args:
  pad_method: knn
  use_conf: True
norm_args:
  pose_format: coco
  use_conf: ${padkeypoints_args.use_conf}
  heatmap_image_height: 128
  # Literal 128-px normalized body height (paper default) instead of the
  # historical height // 1.5 heuristic used when target_body_height is omitted.
  target_body_height: ${norm_args.heatmap_image_height}
heatmap_generator_args:
  sigma: 8.0
  use_score: ${padkeypoints_args.use_conf}
  img_h: ${norm_args.heatmap_image_height}
  img_w: ${norm_args.heatmap_image_height}  # square canvas: width tied to height
  with_limb: null  # overwritten per branch (bone vs joint) by GenerateHeatmapTransform
  with_kp: null
align_args:
  align: True
  final_img_size: 64
  offset: 0
  heatmap_image_size: ${norm_args.heatmap_image_height}
+10 -2
View File
@@ -99,14 +99,22 @@ The PAV pass is implemented from the paper:
4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs 4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs
5. apply IQR filtering per metric 5. apply IQR filtering per metric
6. average over time 6. average over time
7. min-max normalize across the dataset, or across `TRAIN_SET` when `--stats_partition` is provided 7. min-max normalize across the full dataset (paper default), or across `TRAIN_SET` when `--stats_partition` is provided as an anti-leakage variant
Run: Run:
```bash ```bash
uv run python datasets/pretreatment_scoliosis_drf.py \ uv run python datasets/pretreatment_scoliosis_drf.py \
--pose_data_path=<path_to_pose_pkl> \ --pose_data_path=<path_to_pose_pkl> \
--output_path=<path_to_drf_pkl> \ --output_path=<path_to_drf_pkl>
```
To reproduce the paper defaults more closely, the script now uses
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables
summed two-channel skeleton maps and a literal 128-pixel height normalization.
If you explicitly want train-only PAV min-max statistics, add:
```bash
--stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json --stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json
``` ```
+54 -12
View File
@@ -118,8 +118,8 @@ class GeneratePoseTarget:
ed_x = min(tmp_ed_x + 1, img_w) ed_x = min(tmp_ed_x + 1, img_w)
st_y = max(tmp_st_y, 0) st_y = max(tmp_st_y, 0)
ed_y = min(tmp_ed_y + 1, img_h) ed_y = min(tmp_ed_y + 1, img_h)
x = np.arange(st_x, ed_x, 1, np.float32) x = np.arange(st_x, ed_x, dtype=np.float32)
y = np.arange(st_y, ed_y, 1, np.float32) y = np.arange(st_y, ed_y, dtype=np.float32)
# if the keypoint not in the heatmap coordinate system # if the keypoint not in the heatmap coordinate system
if not (len(x) and len(y)): if not (len(x) and len(y)):
@@ -166,8 +166,8 @@ class GeneratePoseTarget:
min_y = max(tmp_min_y, 0) min_y = max(tmp_min_y, 0)
max_y = min(tmp_max_y + 1, img_h) max_y = min(tmp_max_y + 1, img_h)
x = np.arange(min_x, max_x, 1, np.float32) x = np.arange(min_x, max_x, dtype=np.float32)
y = np.arange(min_y, max_y, 1, np.float32) y = np.arange(min_y, max_y, dtype=np.float32)
if not (len(x) and len(y)): if not (len(x) and len(y)):
continue continue
@@ -324,9 +324,37 @@ class HeatmapToImage:
heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps] heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2)) return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
class HeatmapReducer:
"""Reduce stacked joint/limb heatmaps to a single grayscale channel."""
def __init__(self, reduction: str = "max") -> None:
if reduction not in {"max", "sum"}:
raise ValueError(f"Unsupported heatmap reduction: {reduction}")
self.reduction = reduction
def __call__(self, heatmaps: np.ndarray) -> np.ndarray:
"""
heatmaps: (T, C, H, W)
return: (T, 1, H, W)
"""
if self.reduction == "max":
reduced = np.max(heatmaps, axis=1, keepdims=True)
reduced = np.clip(reduced, 0.0, 1.0)
return (reduced * 255).astype(np.uint8)
reduced = np.sum(heatmaps, axis=1, keepdims=True)
return (reduced * 255.0).astype(np.float32)
class CenterAndScaleNormalizer: class CenterAndScaleNormalizer:
def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None: def __init__(
self,
pose_format="coco",
use_conf=True,
heatmap_image_height=128,
target_body_height=None,
) -> None:
""" """
Parameters: Parameters:
- pose_format (str): Specifies the format of the keypoints. - pose_format (str): Specifies the format of the keypoints.
@@ -334,10 +362,13 @@ class CenterAndScaleNormalizer:
The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model. The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model.
- use_conf (bool): Indicates whether confidence scores. - use_conf (bool): Indicates whether confidence scores.
- heatmap_image_height (int): Sets the height (in pixels) for the heatmap images that will be normlization. - heatmap_image_height (int): Sets the height (in pixels) for the heatmap images that will be normlization.
- target_body_height (float | None): Optional normalized body height. When omitted,
preserve the historical SkeletonGait scaling heuristic.
""" """
self.pose_format = pose_format self.pose_format = pose_format
self.use_conf = use_conf self.use_conf = use_conf
self.heatmap_image_height = heatmap_image_height self.heatmap_image_height = heatmap_image_height
self.target_body_height = target_body_height
def __call__(self, data): def __call__(self, data):
""" """
@@ -369,7 +400,13 @@ class CenterAndScaleNormalizer:
# Scale-normalization # Scale-normalization
y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t] y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t] y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis]) # [t, v, 2] target_body_height = (
float(self.target_body_height)
if self.target_body_height is not None
else float(self.heatmap_image_height // 1.5)
)
body_height = np.maximum(y_max - y_min, 1e-6)
pose_seq *= (target_body_height / body_height)[:, np.newaxis, np.newaxis] # [t, v, 2]
pose_seq += self.heatmap_image_height // 2 pose_seq += self.heatmap_image_height // 2
@@ -523,16 +560,21 @@ class HeatmapAlignment():
heatmap_imgs: (T, 1, raw_size, raw_size) heatmap_imgs: (T, 1, raw_size, raw_size)
return (T, 1, final_img_size, final_img_size) return (T, 1, final_img_size, final_img_size)
""" """
heatmap_imgs = heatmap_imgs / 255. original_dtype = heatmap_imgs.dtype
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs]) heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
return (heatmap_imgs * 255).astype('uint8') heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
heatmap_imgs = heatmap_imgs * 255.0
if np.issubdtype(original_dtype, np.integer):
return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype)
return heatmap_imgs.astype(original_dtype)
def GenerateHeatmapTransform( def GenerateHeatmapTransform(
coco18tococo17_args, coco18tococo17_args,
padkeypoints_args, padkeypoints_args,
norm_args, norm_args,
heatmap_generator_args, heatmap_generator_args,
align_args align_args,
reduction="max",
): ):
base_transform = T.Compose([ base_transform = T.Compose([
@@ -545,7 +587,7 @@ def GenerateHeatmapTransform(
heatmap_generator_args["with_kp"] = False heatmap_generator_args["with_kp"] = False
transform_bone = T.Compose([ transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args), GeneratePoseTarget(**heatmap_generator_args),
HeatmapToImage(), HeatmapReducer(reduction=reduction),
HeatmapAlignment(**align_args) HeatmapAlignment(**align_args)
]) ])
@@ -553,7 +595,7 @@ def GenerateHeatmapTransform(
heatmap_generator_args["with_kp"] = True heatmap_generator_args["with_kp"] = True
transform_joint = T.Compose([ transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args), GeneratePoseTarget(**heatmap_generator_args),
HeatmapToImage(), HeatmapReducer(reduction=reduction),
HeatmapAlignment(**align_args) HeatmapAlignment(**align_args)
]) ])
+10 -4
View File
@@ -7,7 +7,7 @@ import pickle
import sys import sys
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Any, TypedDict from typing import Any, TypedDict, cast
import numpy as np import numpy as np
import yaml import yaml
@@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace:
_ = parser.add_argument( _ = parser.add_argument(
"--heatmap_cfg_path", "--heatmap_cfg_path",
type=str, type=str,
default="configs/skeletongait/pretreatment_heatmap.yaml", default="configs/drf/pretreatment_heatmap_drf.yaml",
help="Heatmap preprocessing config used to build the skeleton map branch.", help="Heatmap preprocessing config used to build the skeleton map branch.",
) )
_ = parser.add_argument( _ = parser.add_argument(
"--stats_partition", "--stats_partition",
type=str, type=str,
default=None, default=None,
help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.", help=(
"Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
"Omit it to match the paper's dataset-level min-max normalization."
),
) )
return parser.parse_args() return parser.parse_args()
@@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
with open(cfg_path, "r", encoding="utf-8") as stream: with open(cfg_path, "r", encoding="utf-8") as stream:
cfg = yaml.safe_load(stream) cfg = yaml.safe_load(stream)
replaced = heatmap_prep.replace_variables(cfg, cfg) replaced = heatmap_prep.replace_variables(cfg, cfg)
return dict(replaced) if not isinstance(replaced, dict):
raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}")
return cast(dict[str, Any], replaced)
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose: def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
@@ -175,6 +180,7 @@ def main() -> None:
norm_args=heatmap_cfg["norm_args"], norm_args=heatmap_cfg["norm_args"],
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"], heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
align_args=heatmap_cfg["align_args"], align_args=heatmap_cfg["align_args"],
reduction="sum",
) )
pose_paths = iter_pose_paths(args.pose_data_path) pose_paths = iter_pose_paths(args.pose_data_path)
+82
View File
@@ -0,0 +1,82 @@
from __future__ import annotations
from typing import Any, Callable, cast
import numpy as np
import torch
from jaxtyping import Float, Int
from .base_model import BaseModel
from opengait.utils.common import list2var, np2var
class BaseModelBody(BaseModel):
    """Base model variant with a separate sequence-level body-prior input.

    The last entry of each batch is treated as a per-frame body-prior
    sequence; it is collapsed to one vector per sequence via
    ``aggregate_body_features`` and returned as an extra output alongside
    the usual pretreatment results.
    """

    def inputs_pretreament(
        self,
        inputs: tuple[list[np.ndarray], list[int], list[str], list[str], np.ndarray | None],
    ) -> Any:
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs

        # Pick the transform list matching the current train/eval mode.
        active_trfs = cast(
            list[Callable[[Any], Any]],
            self.trainer_trfs if self.training else self.evaluator_trfs,
        )
        if len(seqs_batch) != len(active_trfs):
            raise ValueError(
                "The number of types of input data and transform should be same. "
                f"But got {len(seqs_batch)} and {len(active_trfs)}"
            )
        if len(seqs_batch) < 2:
            raise ValueError("BaseModelBody expects one visual input and one body-prior input.")

        needs_grad = bool(self.training)

        def _as_float_tensor(trf: Callable[[Any], Any], seq: Any) -> Any:
            # Apply the per-frame transform, stack, and wrap as a float tensor.
            stacked = np.asarray([trf(frame) for frame in seq])
            return np2var(stacked, requires_grad=needs_grad).float()

        # All inputs except the last are visual streams; the last is the body prior.
        visual_seqs = [
            _as_float_tensor(trf, seq)
            for trf, seq in zip(active_trfs[:-1], seqs_batch[:-1])
        ]
        body_seq = _as_float_tensor(active_trfs[-1], seqs_batch[-1])

        labs = list2var(labs_batch).long()
        seqL = None if seqL_batch is None else np2var(seqL_batch).int()
        body_features = aggregate_body_features(body_seq, seqL)

        ipts = visual_seqs
        if seqL is not None:
            # Trim visual streams to the packed total frame count.
            total_frames = int(seqL.sum().data.cpu().numpy())
            ipts = [stream[:, :total_frames] for stream in visual_seqs]
        return ipts, labs, typs_batch, vies_batch, seqL, body_features
def aggregate_body_features(
    sequence_features: Float[torch.Tensor, "..."],
    seqL: Int[torch.Tensor, "1 batch"] | None,
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Collapse a sampled body-prior sequence back to one vector per sequence.

    Two layouts are supported:
    - unpacked (``seqL is None``): ``(batch, seq, ...)`` — temporal mean over dim 1;
    - packed (``seqL`` given): ``(1, sum(seqL), ...)`` — each contiguous chunk of
      length ``seqL[0][i]`` is averaged into one vector for sequence ``i``.

    Raises ValueError when the tensor rank does not match the chosen layout.
    """
    if seqL is None:
        if sequence_features.ndim < 3:
            raise ValueError(f"Expected body prior with >=3 dims, got shape {tuple(sequence_features.shape)}")
        return sequence_features.mean(dim=1)

    if sequence_features.ndim < 4:
        raise ValueError(f"Expected packed body prior with >=4 dims, got shape {tuple(sequence_features.shape)}")

    packed = sequence_features.squeeze(0)
    chunk_sizes = [int(n) for n in seqL[0].tolist()]

    # Walk the packed frame axis chunk by chunk, averaging each sequence's span.
    chunk_means: list[torch.Tensor] = []
    offset = 0
    for size in chunk_sizes:
        chunk_means.append(packed[offset:offset + size].mean(dim=0))
        offset += size
    return torch.stack(chunk_means, dim=0)
+5 -25
View File
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
from einops import rearrange from einops import rearrange
from ..base_model import BaseModel from ..base_model_body import BaseModelBody
from ..modules import ( from ..modules import (
HorizontalPoolingPyramid, HorizontalPoolingPyramid,
PackSequenceWrapper, PackSequenceWrapper,
@@ -18,7 +18,7 @@ from ..modules import (
) )
class DRF(BaseModel): class DRF(BaseModelBody):
"""Dual Representation Framework from arXiv:2509.00872v1.""" """Dual Representation Framework from arXiv:2509.00872v1."""
def build_network(self, model_cfg: dict[str, Any]) -> None: def build_network(self, model_cfg: dict[str, Any]) -> None:
@@ -43,9 +43,10 @@ class DRF(BaseModel):
list[str], list[str],
list[str], list[str],
Int[torch.Tensor, "1 batch"] | None, Int[torch.Tensor, "1 batch"] | None,
Float[torch.Tensor, "batch pairs metrics"],
], ],
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL = inputs ipts, pids, labels, _, seqL, key_features = inputs
label_ids = torch.as_tensor( label_ids = torch.as_tensor(
[LABEL_MAP[str(label).lower()] for label in labels], [LABEL_MAP[str(label).lower()] for label in labels],
device=pids.device, device=pids.device,
@@ -58,15 +59,12 @@ class DRF(BaseModel):
else: else:
heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w") heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")
pav_seq = ipts[1]
pav = aggregate_sequence_features(pav_seq, seqL)
outs = self.Backbone(heatmaps) outs = self.Backbone(heatmaps)
outs = self.TP(outs, seqL, options={"dim": 2})[0] outs = self.TP(outs, seqL, options={"dim": 2})[0]
feat = self.HPP(outs) feat = self.HPP(outs)
embed_1 = self.FCs(feat) embed_1 = self.FCs(feat)
embed_1 = self.PGA(embed_1, pav) embed_1 = self.PGA(embed_1, key_features)
embed_2, logits = self.BNNecks(embed_1) embed_2, logits = self.BNNecks(embed_1)
del embed_2 del embed_2
@@ -120,24 +118,6 @@ class PAVGuidedAttention(nn.Module):
return embeddings * channel_att * spatial_att return embeddings * channel_att * spatial_att
def aggregate_sequence_features(
sequence_features: Float[torch.Tensor, "batch seq pairs metrics"],
seqL: Int[torch.Tensor, "1 batch"] | None,
) -> Float[torch.Tensor, "batch pairs metrics"]:
if seqL is None:
return sequence_features.mean(dim=1)
lengths = seqL[0].tolist()
flattened = sequence_features.squeeze(0)
aggregated = []
start = 0
for length in lengths:
end = start + int(length)
aggregated.append(flattened[start:end].mean(dim=0))
start = end
return torch.stack(aggregated, dim=0)
LABEL_MAP: dict[str, int] = { LABEL_MAP: dict[str, int] = {
"negative": 0, "negative": 0,
"neutral": 1, "neutral": 1,
+145
View File
@@ -0,0 +1,145 @@
from __future__ import annotations
import argparse
import os
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Final
# Lowercased substrings that mark a job log as failed; matched against the
# lowercased log tail (see detect_error). NOTE(review): matching is plain
# substring search, so benign occurrences of e.g. "failed" in progress text
# would also trigger an alert — confirm the monitored logs avoid these words.
ERROR_PATTERNS: Final[tuple[str, ...]] = (
    "traceback",
    "runtimeerror",
    "error:",
    "exception",
    "failed",
    "segmentation fault",
    "killed",
)
@dataclass(frozen=True)
class JobSpec:
    """Immutable handle for one monitored job: a label, its PID, and its log file."""

    name: str  # human-readable job label used in status/alert lines
    pid: int  # OS process id polled for liveness
    log_path: Path  # log file whose tail is scanned for error patterns
def parse_args() -> argparse.Namespace:
    """Build and parse the monitor's command-line options."""
    parser = argparse.ArgumentParser(description="Monitor long-running DRF preprocess/train jobs.")
    # Required options, registered in the original declaration order.
    required_options: tuple[tuple[str, type], ...] = (
        ("--preprocess-pid", int),
        ("--preprocess-log", Path),
        ("--launcher-pid", int),
        ("--launcher-log", Path),
        ("--sentinel-path", Path),
        ("--status-log", Path),
    )
    for flag, converter in required_options:
        parser.add_argument(flag, type=converter, required=True)
    parser.add_argument("--poll-seconds", type=float, default=30.0)
    return parser.parse_args()
def pid_alive(pid: int) -> bool:
    """Return True if a process with *pid* currently exists.

    Uses ``os.kill(pid, 0)`` (signal 0 performs only the existence/permission
    check, sending nothing) instead of probing ``/proc/<pid>``, so the check
    also works on non-Linux POSIX hosts. ``PermissionError`` means the process
    exists but belongs to another user, so it counts as alive.
    """
    if pid <= 0:
        # Never probe zero/negative pids: os.kill would signal process
        # groups rather than a single process.
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True
def read_tail(path: Path, limit: int = 8192) -> str:
    """Return up to the last *limit* bytes of *path*, decoded as UTF-8.

    A missing file yields an empty string; undecodable bytes are replaced
    rather than raising.
    """
    if not path.exists():
        return ""
    with path.open("rb") as fh:
        # seek() returns the new absolute position, i.e. the file size here.
        size = fh.seek(0, os.SEEK_END)
        fh.seek(max(size - limit, 0), os.SEEK_SET)
        return fh.read().decode("utf-8", errors="replace")
def detect_error(log_text: str) -> str | None:
    """Return the first ERROR_PATTERNS entry found in *log_text* (case-insensitive), else None.

    Patterns are checked in declaration order, so the earliest-declared match
    wins even if several patterns occur in the text.
    """
    haystack = log_text.lower()
    return next((pattern for pattern in ERROR_PATTERNS if pattern in haystack), None)
def append_status(path: Path, line: str) -> None:
    """Append *line* to *path*, prefixed with a second-resolution ISO-8601 timestamp.

    Parent directories are created on demand so the first write cannot fail
    on a missing log directory.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().isoformat(timespec="seconds")
    with path.open("a", encoding="utf-8") as sink:
        sink.write(f"{stamp} {line}\n")
def monitor_job(job: JobSpec, status_log: Path) -> str | None:
    """Scan the tail of *job*'s log for error patterns.

    On a match, record an alert in *status_log* and return a short failure
    message for the caller to report; otherwise return None.
    """
    matched = detect_error(read_tail(job.log_path))
    if matched is None:
        return None
    append_status(status_log, f"[alert] {job.name}: detected `{matched}` in {job.log_path}")
    return f"{job.name} log shows `{matched}`"
def main() -> int:
    """Poll the preprocess and launcher jobs until completion or failure.

    Returns 0 when the launcher exits after training started (or after the
    sentinel file appeared), 1 on any detected error condition; otherwise
    loops forever, sleeping ``--poll-seconds`` between polls.
    """
    args = parse_args()
    preprocess = JobSpec("preprocess", args.preprocess_pid, args.preprocess_log)
    launcher = JobSpec("launcher", args.launcher_pid, args.launcher_log)
    append_status(
        args.status_log,
        (
            "[start] monitoring "
            f"preprocess_pid={preprocess.pid} launcher_pid={launcher.pid} "
            f"sentinel={args.sentinel_path}"
        ),
    )
    # Track whether each PID was ever observed alive, so "exited" conclusions
    # are drawn only for processes that were actually seen running.
    preprocess_seen_alive = False
    launcher_seen_alive = False
    while True:
        preprocess_alive = pid_alive(preprocess.pid)
        launcher_alive = pid_alive(launcher.pid)
        preprocess_seen_alive = preprocess_seen_alive or preprocess_alive
        launcher_seen_alive = launcher_seen_alive or launcher_alive

        # Log-content errors take priority over liveness conclusions.
        preprocess_error = monitor_job(preprocess, args.status_log)
        if preprocess_error is not None:
            print(preprocess_error, file=sys.stderr)
            return 1
        launcher_error = monitor_job(launcher, args.status_log)
        if launcher_error is not None:
            print(launcher_error, file=sys.stderr)
            return 1

        # The sentinel file marks successful preprocess completion.
        sentinel_ready = args.sentinel_path.exists()
        append_status(
            args.status_log,
            (
                "[ok] "
                f"preprocess_alive={preprocess_alive} "
                f"launcher_alive={launcher_alive} "
                f"sentinel_ready={sentinel_ready}"
            ),
        )
        # Preprocess must write the sentinel before it exits.
        if preprocess_seen_alive and not preprocess_alive and not sentinel_ready:
            append_status(args.status_log, "[alert] preprocess exited before sentinel was written")
            print("preprocess exited before sentinel was written", file=sys.stderr)
            return 1

        launcher_tail = read_tail(launcher.log_path)
        # NOTE(review): "[start]" is assumed to be the launcher's training-start
        # log marker — confirm against the launcher's actual logging format.
        train_started = "[start]" in launcher_tail
        if launcher_seen_alive and not launcher_alive:
            if not train_started and not sentinel_ready:
                append_status(args.status_log, "[alert] launcher exited before training started")
                print("launcher exited before training started", file=sys.stderr)
                return 1
            append_status(args.status_log, "[done] launcher exited; monitoring complete")
            return 0
        time.sleep(args.poll_seconds)
# Script entry point: surface the monitor's exit code (0 ok, 1 failure) to the shell.
if __name__ == "__main__":
    raise SystemExit(main())