From bbb41e8dd9cb8a68351ca41caf4dbb754260ab83 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Sun, 8 Mar 2026 04:04:15 +0800 Subject: [PATCH] Refine DRF preprocessing and body-prior pipeline --- configs/drf/drf_scoliosis1k.yaml | 10 +- configs/drf/drf_scoliosis1k_eval_1gpu.yaml | 106 +++++++++++++++ configs/drf/drf_scoliosis1k_smoke.yaml | 10 +- configs/drf/pretreatment_heatmap_drf.yaml | 26 ++++ datasets/Scoliosis1K/README.md | 12 +- datasets/pretreatment_heatmap.py | 66 ++++++++-- datasets/pretreatment_scoliosis_drf.py | 14 +- opengait/modeling/base_model_body.py | 82 ++++++++++++ opengait/modeling/models/drf.py | 30 +---- scripts/monitor_drf_jobs.py | 145 +++++++++++++++++++++ 10 files changed, 448 insertions(+), 53 deletions(-) create mode 100644 configs/drf/drf_scoliosis1k_eval_1gpu.yaml create mode 100644 configs/drf/pretreatment_heatmap_drf.yaml create mode 100644 opengait/modeling/base_model_body.py create mode 100644 scripts/monitor_drf_jobs.py diff --git a/configs/drf/drf_scoliosis1k.yaml b/configs/drf/drf_scoliosis1k.yaml index 221f75f..cb8cdcf 100644 --- a/configs/drf/drf_scoliosis1k.yaml +++ b/configs/drf/drf_scoliosis1k.yaml @@ -1,6 +1,6 @@ data_cfg: dataset_name: Scoliosis1K - dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118 + dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json num_workers: 1 remove_no_gallery: false @@ -10,7 +10,7 @@ evaluator_cfg: enable_float16: true restore_ckpt_strict: true restore_hint: 20000 - save_name: DRF + save_name: DRF_paper eval_func: evaluate_scoliosis sampler: batch_shuffle: false @@ -19,7 +19,7 @@ evaluator_cfg: frames_all_limit: 720 metric: euc transform: - - type: BaseSilCuttingTransform + - type: BaseSilTransform - type: NoOperation loss_cfg: @@ -90,7 +90,7 @@ trainer_cfg: restore_ckpt_strict: true restore_hint: 0 save_iter: 20000 - save_name: DRF + save_name: DRF_paper sync_BN: true total_iter: 20000 sampler: @@ -102,5 +102,5 @@ trainer_cfg: sample_type: fixed_unordered type: TripletSampler transform: - - type: BaseSilCuttingTransform + - type: BaseSilTransform - type: NoOperation diff --git a/configs/drf/drf_scoliosis1k_eval_1gpu.yaml b/configs/drf/drf_scoliosis1k_eval_1gpu.yaml new file mode 100644 index 0000000..83b1950 --- /dev/null +++ b/configs/drf/drf_scoliosis1k_eval_1gpu.yaml @@ -0,0 +1,106 @@ +data_cfg: + dataset_name: Scoliosis1K + dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper + dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json + num_workers: 1 + remove_no_gallery: false + test_dataset_name: Scoliosis1K + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 20000 + save_name: DRF_paper + eval_func: evaluate_scoliosis + sampler: + batch_shuffle: false + batch_size: 1 + sample_type: all_ordered + frames_all_limit: 720 + metric: euc + transform: + - type: BaseSilTransform + - type: NoOperation + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DRF + num_pairs: 8 + num_metrics: 3 + backbone_cfg: + type: ResNet9 + block: BasicBlock + in_channel: 2 + channels: + - 64 + - 128 + - 256 + - 512 + layers: + - 1 + - 1 + - 1 + - 1 + strides: + - 1 + - 2 + - 2 + - 1 + maxpool: false + SeparateFCs: + in_channels: 512 + out_channels: 256 + parts_num: 16 + SeparateBNNecks: + class_num: 3 + in_channels: 256 + parts_num: 16 + bin_num: + - 16 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: + - 10000 + - 14000 + - 18000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true + fix_BN: false + with_test: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 20000 + save_name: DRF_paper + sync_BN: true + total_iter: 20000 + sampler: + batch_shuffle: true + batch_size: + - 8 + - 8 + frames_num_fixed: 30 + sample_type: fixed_unordered + type: TripletSampler + transform: + - type: BaseSilTransform + - type: NoOperation diff --git a/configs/drf/drf_scoliosis1k_smoke.yaml b/configs/drf/drf_scoliosis1k_smoke.yaml index 9aeae2f..d6b035d 100644 --- a/configs/drf/drf_scoliosis1k_smoke.yaml +++ b/configs/drf/drf_scoliosis1k_smoke.yaml @@ -1,6 +1,6 @@ data_cfg: dataset_name: Scoliosis1K - dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118 + dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json num_workers: 1 remove_no_gallery: false @@ -14,7 +14,7 @@ evaluator_cfg: enable_float16: true restore_ckpt_strict: true restore_hint: 0 - save_name: DRF_smoke + save_name: DRF_paper_smoke eval_func: evaluate_scoliosis sampler: batch_shuffle: false @@ -23,7 +23,7 @@ evaluator_cfg: frames_all_limit: 720 metric: euc transform: - - type: BaseSilCuttingTransform + - type: BaseSilTransform - type: NoOperation loss_cfg: @@ -97,7 +97,7 @@ trainer_cfg: scheduler_reset: false restore_hint: 0 save_iter: 1 - save_name: DRF_smoke + save_name: DRF_paper_smoke sync_BN: true total_iter: 1 sampler: @@ -109,5 +109,5 @@ trainer_cfg: sample_type: fixed_unordered type: TripletSampler transform: - - type: BaseSilCuttingTransform + - type: BaseSilTransform - type: NoOperation diff --git a/configs/drf/pretreatment_heatmap_drf.yaml b/configs/drf/pretreatment_heatmap_drf.yaml new file mode 100644 index 0000000..fbddaf9 --- /dev/null +++ b/configs/drf/pretreatment_heatmap_drf.yaml @@ -0,0 +1,26 @@ +coco18tococo17_args: + transfer_to_coco17: False + +padkeypoints_args: + pad_method: knn + use_conf: True + +norm_args: + pose_format: coco + use_conf: ${padkeypoints_args.use_conf} + heatmap_image_height: 128 + target_body_height: ${norm_args.heatmap_image_height} + +heatmap_generator_args: + sigma: 8.0 + use_score: ${padkeypoints_args.use_conf} + img_h: ${norm_args.heatmap_image_height} + img_w: ${norm_args.heatmap_image_height} + with_limb: null + with_kp: null + +align_args: + align: True + final_img_size: 64 + offset: 0 + heatmap_image_size: ${norm_args.heatmap_image_height} diff --git a/datasets/Scoliosis1K/README.md b/datasets/Scoliosis1K/README.md index 8e22845..cdd756f 100644 --- a/datasets/Scoliosis1K/README.md +++ b/datasets/Scoliosis1K/README.md @@ -99,14 +99,22 @@ The PAV pass is implemented from the paper: 4. compute vertical, midline, and angular deviations for the 8 symmetric joint pairs 5. apply IQR filtering per metric 6. average over time -7. min-max normalize across the dataset, or across `TRAIN_SET` when `--stats_partition` is provided +7. min-max normalize across the full dataset (paper default), or across `TRAIN_SET` when `--stats_partition` is provided as an anti-leakage variant Run: ```bash uv run python datasets/pretreatment_scoliosis_drf.py \ --pose_data_path= \ - --output_path= \ + --output_path= +``` + +To reproduce the paper defaults more closely, the script now uses +`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables +summed two-channel skeleton maps and a literal 128-pixel height normalization. +If you explicitly want train-only PAV min-max statistics, add: + +```bash --stats_partition=./datasets/Scoliosis1K/Scoliosis1K_118.json ``` diff --git a/datasets/pretreatment_heatmap.py b/datasets/pretreatment_heatmap.py index effa922..d3cb35c 100644 --- a/datasets/pretreatment_heatmap.py +++ b/datasets/pretreatment_heatmap.py @@ -118,8 +118,8 @@ class GeneratePoseTarget: ed_x = min(tmp_ed_x + 1, img_w) st_y = max(tmp_st_y, 0) ed_y = min(tmp_ed_y + 1, img_h) - x = np.arange(st_x, ed_x, 1, np.float32) - y = np.arange(st_y, ed_y, 1, np.float32) + x = np.arange(st_x, ed_x, dtype=np.float32) + y = np.arange(st_y, ed_y, dtype=np.float32) # if the keypoint not in the heatmap coordinate system if not (len(x) and len(y)): @@ -166,8 +166,8 @@ class GeneratePoseTarget: min_y = max(tmp_min_y, 0) max_y = min(tmp_max_y + 1, img_h) - x = np.arange(min_x, max_x, 1, np.float32) - y = np.arange(min_y, max_y, 1, np.float32) + x = np.arange(min_x, max_x, dtype=np.float32) + y = np.arange(min_y, max_y, dtype=np.float32) if not (len(x) and len(y)): continue @@ -324,9 +324,37 @@ class HeatmapToImage: heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps] return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2)) + +class HeatmapReducer: + """Reduce stacked joint/limb heatmaps to a single grayscale channel.""" + + def __init__(self, reduction: str = "max") -> None: + if reduction not in {"max", "sum"}: + raise ValueError(f"Unsupported heatmap reduction: {reduction}") + self.reduction = reduction + + def __call__(self, heatmaps: np.ndarray) -> np.ndarray: + """ + heatmaps: (T, C, H, W) + return: (T, 1, H, W) + """ + if self.reduction == "max": + reduced = np.max(heatmaps, axis=1, keepdims=True) + reduced = np.clip(reduced, 0.0, 1.0) + return (reduced * 255).astype(np.uint8) + + reduced = np.sum(heatmaps, axis=1, keepdims=True) + return (reduced * 255.0).astype(np.float32) + class CenterAndScaleNormalizer: - def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None: + def __init__( + self, + pose_format="coco", + use_conf=True, + heatmap_image_height=128, + target_body_height=None, + ) -> None: """ Parameters: - pose_format (str): Specifies the format of the keypoints. @@ -334,10 +362,13 @@ class CenterAndScaleNormalizer: The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model. - use_conf (bool): Indicates whether confidence scores. - heatmap_image_height (int): Sets the height (in pixels) for the heatmap images that will be normlization. + - target_body_height (float | None): Optional normalized body height. When omitted, + preserve the historical SkeletonGait scaling heuristic. """ self.pose_format = pose_format self.use_conf = use_conf self.heatmap_image_height = heatmap_image_height + self.target_body_height = target_body_height def __call__(self, data): """ @@ -369,7 +400,13 @@ class CenterAndScaleNormalizer: # Scale-normalization y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t] y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t] - pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis]) # [t, v, 2] + target_body_height = ( + float(self.target_body_height) + if self.target_body_height is not None + else float(self.heatmap_image_height // 1.5) + ) + body_height = np.maximum(y_max - y_min, 1e-6) + pose_seq *= (target_body_height / body_height)[:, np.newaxis, np.newaxis] # [t, v, 2] pose_seq += self.heatmap_image_height // 2 @@ -523,16 +560,21 @@ class HeatmapAlignment(): heatmap_imgs: (T, 1, raw_size, raw_size) return (T, 1, final_img_size, final_img_size) """ - heatmap_imgs = heatmap_imgs / 255. - heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs]) - return (heatmap_imgs * 255).astype('uint8') + original_dtype = heatmap_imgs.dtype + heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0 + heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32) + heatmap_imgs = heatmap_imgs * 255.0 + if np.issubdtype(original_dtype, np.integer): + return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype) + return heatmap_imgs.astype(original_dtype) def GenerateHeatmapTransform( coco18tococo17_args, padkeypoints_args, norm_args, heatmap_generator_args, - align_args + align_args, + reduction="max", ): base_transform = T.Compose([ @@ -545,7 +587,7 @@ def GenerateHeatmapTransform( heatmap_generator_args["with_kp"] = False transform_bone = T.Compose([ GeneratePoseTarget(**heatmap_generator_args), - HeatmapToImage(), + HeatmapReducer(reduction=reduction), HeatmapAlignment(**align_args) ]) @@ -553,7 +595,7 @@ def GenerateHeatmapTransform( heatmap_generator_args["with_kp"] = True transform_joint = T.Compose([ GeneratePoseTarget(**heatmap_generator_args), - HeatmapToImage(), + HeatmapReducer(reduction=reduction), HeatmapAlignment(**align_args) ]) diff --git a/datasets/pretreatment_scoliosis_drf.py b/datasets/pretreatment_scoliosis_drf.py index ebc313c..7a230f8 100644 --- a/datasets/pretreatment_scoliosis_drf.py +++ b/datasets/pretreatment_scoliosis_drf.py @@ -7,7 +7,7 @@ import pickle import sys from glob import glob from pathlib import Path -from typing import Any, TypedDict +from typing import Any, TypedDict, cast import numpy as np import yaml @@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace: _ = parser.add_argument( "--heatmap_cfg_path", type=str, - default="configs/skeletongait/pretreatment_heatmap.yaml", + default="configs/drf/pretreatment_heatmap_drf.yaml", help="Heatmap preprocessing config used to build the skeleton map branch.", ) _ = parser.add_argument( "--stats_partition", type=str, default=None, - help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.", + help=( + "Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. " + "Omit it to match the paper's dataset-level min-max normalization." + ), ) return parser.parse_args() @@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]: with open(cfg_path, "r", encoding="utf-8") as stream: cfg = yaml.safe_load(stream) replaced = heatmap_prep.replace_variables(cfg, cfg) - return dict(replaced) + if not isinstance(replaced, dict): + raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}") + return cast(dict[str, Any], replaced) def build_pose_transform(cfg: dict[str, Any]) -> T.Compose: @@ -175,6 +180,7 @@ def main() -> None: norm_args=heatmap_cfg["norm_args"], heatmap_generator_args=heatmap_cfg["heatmap_generator_args"], align_args=heatmap_cfg["align_args"], + reduction="sum", ) pose_paths = iter_pose_paths(args.pose_data_path) diff --git a/opengait/modeling/base_model_body.py b/opengait/modeling/base_model_body.py new file mode 100644 index 0000000..0a59033 --- /dev/null +++ b/opengait/modeling/base_model_body.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any, Callable, cast + +import numpy as np +import torch +from jaxtyping import Float, Int + +from .base_model import BaseModel +from opengait.utils.common import list2var, np2var + + +class BaseModelBody(BaseModel): + """Base model variant with a separate sequence-level body-prior input.""" + + def inputs_pretreament( + self, + inputs: tuple[list[np.ndarray], list[int], list[str], list[str], np.ndarray | None], + ) -> Any: + seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs + seq_trfs = cast( + list[Callable[[Any], Any]], + self.trainer_trfs if self.training else self.evaluator_trfs, + ) + if len(seqs_batch) != len(seq_trfs): + raise ValueError( + "The number of types of input data and transform should be same. " + f"But got {len(seqs_batch)} and {len(seq_trfs)}" + ) + if len(seqs_batch) < 2: + raise ValueError("BaseModelBody expects one visual input and one body-prior input.") + + requires_grad = bool(self.training) + visual_seqs = [ + np2var( + np.asarray([trf(fra) for fra in seq]), + requires_grad=requires_grad, + ).float() + for trf, seq in zip(seq_trfs[:-1], seqs_batch[:-1]) + ] + body_trf = seq_trfs[-1] + body_seq = np2var( + np.asarray([body_trf(fra) for fra in seqs_batch[-1]]), + requires_grad=requires_grad, + ).float() + + labs = list2var(labs_batch).long() + seqL = np2var(seqL_batch).int() if seqL_batch is not None else None + + body_features = aggregate_body_features(body_seq, seqL) + + if seqL is not None: + seqL_sum = int(seqL.sum().data.cpu().numpy()) + ipts = [_[:, :seqL_sum] for _ in visual_seqs] + else: + ipts = visual_seqs + return ipts, labs, typs_batch, vies_batch, seqL, body_features + + +def aggregate_body_features( + sequence_features: Float[torch.Tensor, "..."], + seqL: Int[torch.Tensor, "1 batch"] | None, +) -> Float[torch.Tensor, "batch pairs metrics"]: + """Collapse a sampled body-prior sequence back to one vector per sequence.""" + + if seqL is None: + if sequence_features.ndim < 3: + raise ValueError(f"Expected body prior with >=3 dims, got shape {tuple(sequence_features.shape)}") + return sequence_features.mean(dim=1) + + if sequence_features.ndim < 4: + raise ValueError(f"Expected packed body prior with >=4 dims, got shape {tuple(sequence_features.shape)}") + + lengths = seqL[0].tolist() + flattened = sequence_features.squeeze(0) + aggregated: list[torch.Tensor] = [] + start = 0 + for length in lengths: + end = start + int(length) + aggregated.append(flattened[start:end].mean(dim=0)) + start = end + return torch.stack(aggregated, dim=0) diff --git a/opengait/modeling/models/drf.py b/opengait/modeling/models/drf.py index 887249a..bcf9573 100644 --- a/opengait/modeling/models/drf.py +++ b/opengait/modeling/models/drf.py @@ -8,7 +8,7 @@ from jaxtyping import Float, Int from einops import rearrange -from ..base_model import BaseModel +from ..base_model_body import BaseModelBody from ..modules import ( HorizontalPoolingPyramid, PackSequenceWrapper, @@ -18,7 +18,7 @@ from ..modules import ( ) -class DRF(BaseModel): +class DRF(BaseModelBody): """Dual Representation Framework from arXiv:2509.00872v1.""" def build_network(self, model_cfg: dict[str, Any]) -> None: @@ -43,9 +43,10 @@ class DRF(BaseModel): list[str], list[str], Int[torch.Tensor, "1 batch"] | None, + Float[torch.Tensor, "batch pairs metrics"], ], ) -> dict[str, dict[str, Any]]: - ipts, pids, labels, _, seqL = inputs + ipts, pids, labels, _, seqL, key_features = inputs label_ids = torch.as_tensor( [LABEL_MAP[str(label).lower()] for label in labels], device=pids.device, @@ -58,15 +59,12 @@ class DRF(BaseModel): else: heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w") - pav_seq = ipts[1] - pav = aggregate_sequence_features(pav_seq, seqL) - outs = self.Backbone(heatmaps) outs = self.TP(outs, seqL, options={"dim": 2})[0] feat = self.HPP(outs) embed_1 = self.FCs(feat) - embed_1 = self.PGA(embed_1, pav) + embed_1 = self.PGA(embed_1, key_features) embed_2, logits = self.BNNecks(embed_1) del embed_2 @@ -120,24 +118,6 @@ class PAVGuidedAttention(nn.Module): return embeddings * channel_att * spatial_att -def aggregate_sequence_features( - sequence_features: Float[torch.Tensor, "batch seq pairs metrics"], - seqL: Int[torch.Tensor, "1 batch"] | None, -) -> Float[torch.Tensor, "batch pairs metrics"]: - if seqL is None: - return sequence_features.mean(dim=1) - - lengths = seqL[0].tolist() - flattened = sequence_features.squeeze(0) - aggregated = [] - start = 0 - for length in lengths: - end = start + int(length) - aggregated.append(flattened[start:end].mean(dim=0)) - start = end - return torch.stack(aggregated, dim=0) - - LABEL_MAP: dict[str, int] = { "negative": 0, "neutral": 1, diff --git a/scripts/monitor_drf_jobs.py b/scripts/monitor_drf_jobs.py new file mode 100644 index 0000000..e1ec527 --- /dev/null +++ b/scripts/monitor_drf_jobs.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import argparse +import os +import sys +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Final + + +ERROR_PATTERNS: Final[tuple[str, ...]] = ( + "traceback", + "runtimeerror", + "error:", + "exception", + "failed", + "segmentation fault", + "killed", +) + + +@dataclass(frozen=True) +class JobSpec: + name: str + pid: int + log_path: Path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Monitor long-running DRF preprocess/train jobs.") + parser.add_argument("--preprocess-pid", type=int, required=True) + parser.add_argument("--preprocess-log", type=Path, required=True) + parser.add_argument("--launcher-pid", type=int, required=True) + parser.add_argument("--launcher-log", type=Path, required=True) + parser.add_argument("--sentinel-path", type=Path, required=True) + parser.add_argument("--status-log", type=Path, required=True) + parser.add_argument("--poll-seconds", type=float, default=30.0) + return parser.parse_args() + + +def pid_alive(pid: int) -> bool: + return Path(f"/proc/{pid}").exists() + + +def read_tail(path: Path, limit: int = 8192) -> str: + if not path.exists(): + return "" + with path.open("rb") as handle: + handle.seek(0, os.SEEK_END) + size = handle.tell() + handle.seek(max(size - limit, 0), os.SEEK_SET) + data = handle.read() + return data.decode("utf-8", errors="replace") + + +def detect_error(log_text: str) -> str | None: + lowered = log_text.lower() + for pattern in ERROR_PATTERNS: + if pattern in lowered: + return pattern + return None + + +def append_status(path: Path, line: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(f"{datetime.now().isoformat(timespec='seconds')} {line}\n") + + +def monitor_job(job: JobSpec, status_log: Path) -> str | None: + tail = read_tail(job.log_path) + error = detect_error(tail) + if error is not None: + append_status(status_log, f"[alert] {job.name}: detected `{error}` in {job.log_path}") + return f"{job.name} log shows `{error}`" + return None + + +def main() -> int: + args = parse_args() + preprocess = JobSpec("preprocess", args.preprocess_pid, args.preprocess_log) + launcher = JobSpec("launcher", args.launcher_pid, args.launcher_log) + + append_status( + args.status_log, + ( + "[start] monitoring " + f"preprocess_pid={preprocess.pid} launcher_pid={launcher.pid} " + f"sentinel={args.sentinel_path}" + ), + ) + + preprocess_seen_alive = False + launcher_seen_alive = False + + while True: + preprocess_alive = pid_alive(preprocess.pid) + launcher_alive = pid_alive(launcher.pid) + preprocess_seen_alive = preprocess_seen_alive or preprocess_alive + launcher_seen_alive = launcher_seen_alive or launcher_alive + + preprocess_error = monitor_job(preprocess, args.status_log) + if preprocess_error is not None: + print(preprocess_error, file=sys.stderr) + return 1 + + launcher_error = monitor_job(launcher, args.status_log) + if launcher_error is not None: + print(launcher_error, file=sys.stderr) + return 1 + + sentinel_ready = args.sentinel_path.exists() + append_status( + args.status_log, + ( + "[ok] " + f"preprocess_alive={preprocess_alive} " + f"launcher_alive={launcher_alive} " + f"sentinel_ready={sentinel_ready}" + ), + ) + + if preprocess_seen_alive and not preprocess_alive and not sentinel_ready: + append_status(args.status_log, "[alert] preprocess exited before sentinel was written") + print("preprocess exited before sentinel was written", file=sys.stderr) + return 1 + + launcher_tail = read_tail(launcher.log_path) + train_started = "[start]" in launcher_tail + + if launcher_seen_alive and not launcher_alive: + if not train_started and not sentinel_ready: + append_status(args.status_log, "[alert] launcher exited before training started") + print("launcher exited before training started", file=sys.stderr) + return 1 + append_status(args.status_log, "[done] launcher exited; monitoring complete") + return 0 + + time.sleep(args.poll_seconds) + + +if __name__ == "__main__": + raise SystemExit(main())