diff --git a/AGENTS.md b/AGENTS.md index 292f8a6..6a61a29 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,7 +11,7 @@ Use it as the default playbook for commands, conventions, and safety checks. - Package/runtime tool: `uv` Critical source-of-truth rule: -- `opengait/demo` is an implementation layer and may contain project-specific behavior. +- `opengait-studio/opengait_studio` is an implementation layer and may contain project-specific behavior. - When asked to “refer to the paper” or verify methodology, use the paper and official citations as ground truth. - Do not treat demo/runtime behavior as proof of paper method unless explicitly cited by the paper. @@ -20,7 +20,7 @@ Critical source-of-truth rule: Install dependencies with uv: ```bash -uv sync --extra torch +uv sync ``` Notes from `pyproject.toml`: diff --git a/README.md b/README.md index b537045..473acf5 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ See [here](https://github.com/jdyjjj/All-in-One-Gait) for details. ### Quick Start (uv) ```bash # Install dependencies -uv sync --extra torch +uv sync # Train CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch --nproc_per_node=2 opengait/main.py --cfgs ./configs/baseline/baseline.yaml --phase train diff --git a/opengait-studio/opengait_studio/output.py b/opengait-studio/opengait_studio/output.py index 5917a1f..50b415f 100644 --- a/opengait-studio/opengait_studio/output.py +++ b/opengait-studio/opengait_studio/output.py @@ -9,6 +9,7 @@ Provides pluggable result publishing: from __future__ import annotations import asyncio +from concurrent.futures import CancelledError, Future import json import logging import nats @@ -181,10 +182,38 @@ class NatsPublisher: def _stop_background_loop(self) -> None: """Stop the background event loop and thread.""" with self._lock: - if self._loop is not None and self._loop.is_running(): - _ = self._loop.call_soon_threadsafe(self._loop.stop) - if self._thread is not None and self._thread.is_alive(): - self._thread.join(timeout=2.0) + loop = self._loop + thread = self._thread + + if loop is not None and loop.is_running(): + try: + + async def _cancel_pending_tasks() -> None: + current = asyncio.current_task() + pending = [ + task + for task in asyncio.all_tasks() + if task is not current and not task.done() + ] + for task in pending: + _ = task.cancel() + if pending: + _ = await asyncio.gather(*pending, return_exceptions=True) + + cancel_future = asyncio.run_coroutine_threadsafe( + _cancel_pending_tasks(), + loop, + ) + cancel_future.result(timeout=2.0) + except (RuntimeError, OSError, TimeoutError, CancelledError): + pass + finally: + _ = loop.call_soon_threadsafe(loop.stop) + + if thread is not None and thread.is_alive(): + thread.join(timeout=2.0) + + with self._lock: self._loop = None self._thread = None @@ -204,7 +233,12 @@ class NatsPublisher: if not self._start_background_loop(): return False + future: Future[_NatsClient] | None = None try: + loop = self._loop + if loop is None: + logger.warning("Background event loop unavailable for NATS connection") + return False async def _connect() -> _NatsClient: nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType] @@ -213,12 +247,21 @@ class NatsPublisher: # Run connection in background loop future = asyncio.run_coroutine_threadsafe( _connect(), - self._loop, # pyright: ignore[reportArgumentType] + loop, ) self._nc = future.result(timeout=10.0) self._connected = True logger.info(f"Connected to NATS at {self._nats_url}") return True + except TimeoutError as e: + if future is not None: + _ = future.cancel() + try: + _ = future.result(timeout=1.0) + except (TimeoutError, CancelledError, RuntimeError, OSError): + pass + logger.warning("Timed out connecting to NATS at %s: %s", self._nats_url, e) + return False except (RuntimeError, OSError, TimeoutError) as e: logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") return False @@ -260,16 +303,19 @@ class NatsPublisher: exc = fut.exception() except (RuntimeError, OSError) as callback_error: logger.warning(f"NATS publish callback failed: {callback_error}") - self._connected = False + with self._lock: + self._connected = False return if exc is not None: logger.warning(f"Failed to publish to NATS: {exc}") - self._connected = False + with self._lock: + self._connected = False future.add_done_callback(_on_done) except (RuntimeError, OSError, ValueError, TypeError) as e: logger.warning(f"Failed to schedule NATS publish: {e}") - self._connected = False + with self._lock: + self._connected = False def close(self) -> None: """Close NATS connection.""" diff --git a/opengait-studio/opengait_studio/pipeline.py b/opengait-studio/opengait_studio/pipeline.py index 62da000..35ad124 100644 --- a/opengait-studio/opengait_studio/pipeline.py +++ b/opengait-studio/opengait_studio/pipeline.py @@ -1,7 +1,9 @@ from __future__ import annotations from collections.abc import Callable +import copy from contextlib import suppress +import inspect import logging from pathlib import Path import time @@ -146,6 +148,8 @@ class ScoliosisPipeline: _visualizer: OpenCVVisualizer | None _last_viz_payload: _VizPayload | None _frame_pacer: _FramePacer | None + _visualizer_accepts_pose_data: bool | None + _visualizer_signature_owner: object | None def __init__( self, @@ -205,6 +209,31 @@ class ScoliosisPipeline: self._visualizer = None self._last_viz_payload = None self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None + self._visualizer_accepts_pose_data = None + self._visualizer_signature_owner = None + + def _detect_visualizer_pose_kwarg(self) -> bool: + visualizer = self._visualizer + if visualizer is None: + return False + if ( + self._visualizer_signature_owner is visualizer + and self._visualizer_accepts_pose_data is not None + ): + return self._visualizer_accepts_pose_data + update_fn = getattr(visualizer, "update", None) + if update_fn is None or not callable(update_fn): + self._visualizer_signature_owner = visualizer + self._visualizer_accepts_pose_data = False + return False + try: + signature = inspect.signature(update_fn) + accepts_pose_data = "pose_data" in signature.parameters + except (ValueError, TypeError): + accepts_pose_data = False + self._visualizer_signature_owner = visualizer + self._visualizer_accepts_pose_data = accepts_pose_data + return accepts_pose_data @staticmethod def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: @@ -459,7 +488,7 @@ class ScoliosisPipeline: viz_payload = None try: viz_payload = self.process_frame(frame_u8, metadata) - except (RuntimeError, ValueError, TypeError, OSError) as frame_error: + except (RuntimeError, ValueError, OSError) as frame_error: logger.warning( "Skipping frame %d due to processing error: %s", frame_idx, @@ -474,6 +503,9 @@ class ScoliosisPipeline: viz_payload_dict = cast(_VizPayload, viz_payload) cached: _VizPayload = {} for k, v in viz_payload_dict.items(): + if k == "pose" and isinstance(v, dict): + cached[k] = cast(dict[str, object], copy.deepcopy(v)) + continue copy_method = cast( Callable[[], object] | None, getattr(v, "copy", None) ) @@ -531,8 +563,7 @@ class ScoliosisPipeline: confidence = None pose_data = None - # Try keyword arg for pose_data (backward compatible with old signatures) - try: + if self._detect_visualizer_pose_kwarg(): keep_running = self._visualizer.update( frame_u8, bbox, @@ -546,8 +577,7 @@ class ScoliosisPipeline: ema_fps, pose_data=pose_data, ) - except TypeError: - # Fallback for legacy visualizers that don't accept pose_data + else: keep_running = self._visualizer.update( frame_u8, bbox, @@ -676,7 +706,7 @@ class ScoliosisPipeline: "frame": pa.array(frames, type=pa.int64()), "track_id": pa.array(track_ids, type=pa.int64()), "timestamp_ns": pa.array(timestamps, type=pa.int64()), - "silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())), + "silhouette": pa.array(silhouettes, type=pa.list_(pa.float32())), } ) @@ -746,7 +776,7 @@ class ScoliosisPipeline: track_ids.append(result["track_id"]) labels.append(result["label"]) confidences.append(result["confidence"]) - windows.append(result["window"]) + windows.append(int(result["window"])) timestamps.append(result["timestamp_ns"]) table = pa.table( diff --git a/opengait-studio/opengait_studio/visualizer.py b/opengait-studio/opengait_studio/visualizer.py index 49bb3bc..0ef4a1f 100644 --- a/opengait-studio/opengait_studio/visualizer.py +++ b/opengait-studio/opengait_studio/visualizer.py @@ -50,45 +50,24 @@ ImageArray = NDArray[np.uint8] # COCO-format skeleton connections (17 keypoints) # Connections are pairs of keypoint indices SKELETON_CONNECTIONS: list[tuple[int, int]] = [ - (0, 1), # nose -> left_eye - (0, 2), # nose -> right_eye - (1, 3), # left_eye -> left_ear - (2, 4), # right_eye -> right_ear - (5, 6), # left_shoulder -> right_shoulder - (5, 7), # left_shoulder -> left_elbow - (7, 9), # left_elbow -> left_wrist - (6, 8), # right_shoulder -> right_elbow + (0, 1), # nose -> left_eye + (0, 2), # nose -> right_eye + (1, 3), # left_eye -> left_ear + (2, 4), # right_eye -> right_ear + (5, 6), # left_shoulder -> right_shoulder + (5, 7), # left_shoulder -> left_elbow + (7, 9), # left_elbow -> left_wrist + (6, 8), # right_shoulder -> right_elbow (8, 10), # right_elbow -> right_wrist - (11, 12), # left_hip -> right_hip + (11, 12), # left_hip -> right_hip (5, 11), # left_shoulder -> left_hip (6, 12), # right_shoulder -> right_hip - (11, 13), # left_hip -> left_knee - (13, 15), # left_knee -> left_ankle - (12, 14), # right_hip -> right_knee - (14, 16), # right_knee -> right_ankle + (11, 13), # left_hip -> left_knee + (13, 15), # left_knee -> left_ankle + (12, 14), # right_hip -> right_knee + (14, 16), # right_knee -> right_ankle ] -# Keypoint names for COCO format (17 keypoints) -KEYPOINT_NAMES: list[str] = [ - "nose", "left_eye", "right_eye", "left_ear", "right_ear", - "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", - "left_wrist", "right_wrist", "left_hip", "right_hip", - "left_knee", "right_knee", "left_ankle", "right_ankle" -] - -# Joints where angles are typically calculated (for scoliosis/ gait analysis) -ANGLE_JOINTS: list[tuple[int, int, int]] = [ - (5, 7, 9), # left_shoulder -> left_elbow -> left_wrist - (6, 8, 10), # right_shoulder -> right_elbow -> right_wrist - (7, 5, 11), # left_elbow -> left_shoulder -> left_hip - (8, 6, 12), # right_elbow -> right_shoulder -> right_hip - (5, 11, 13), # left_shoulder -> left_hip -> left_knee - (6, 12, 14), # right_shoulder -> right_hip -> right_knee - (11, 13, 15),# left_hip -> left_knee -> left_ankle - (12, 14, 16),# right_hip -> right_knee -> right_ankle -] - - class OpenCVVisualizer: def __init__(self) -> None: @@ -210,7 +189,7 @@ class OpenCVVisualizer: if pose_data is None: return - keypoints_obj = pose_data.get('keypoints') + keypoints_obj = pose_data.get("keypoints") if keypoints_obj is None: return @@ -222,7 +201,7 @@ class OpenCVVisualizer: h, w = frame.shape[:2] # Get confidence scores if available - confidence_obj = pose_data.get('confidence') + confidence_obj = pose_data.get("confidence") confidences = ( np.asarray(confidence_obj, dtype=np.float32) if confidence_obj is not None @@ -267,7 +246,7 @@ class OpenCVVisualizer: if pose_data is None: return - angles_obj = pose_data.get('angles') + angles_obj = pose_data.get("angles") if angles_obj is None: return @@ -467,12 +446,10 @@ class OpenCVVisualizer: def _prepare_segmentation_view( self, - mask_raw: ImageArray | None, + _mask_raw: ImageArray | None, silhouette: NDArray[np.float32] | None, - bbox: BBoxXYXY | None, + _bbox: BBoxXYXY | None, ) -> ImageArray: - _ = mask_raw - _ = bbox return self._prepare_normalized_view(silhouette) def _fit_gray_to_display( @@ -661,20 +638,7 @@ class OpenCVVisualizer: y_pos = h - 8 y_top = max(0, h - MODE_LABEL_PAD) - _ = cv2.rectangle( - image, - (0, y_top), - (w, h), - COLOR_DARK_GRAY, - -1, - ) - _ = cv2.rectangle( - image, - (x_pos - 6, y_pos - text_height - 6), - (x_pos + text_width + 8, y_pos + 6), - COLOR_DARK_GRAY, - -1, - ) + _ = cv2.rectangle(image, (0, y_top), (w, h), COLOR_DARK_GRAY, -1) # Draw text _ = cv2.putText( @@ -706,9 +670,11 @@ class OpenCVVisualizer: Args: frame: Input frame (H, W, C) uint8 bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None + bbox_mask: Bounding box in mask coordinates (x1, y1, x2, y2) or None track_id: Tracking ID mask_raw: Raw binary mask (H, W) uint8 or None silhouette: Normalized silhouette (64, 44) float32 [0,1] or None + segmentation_input: Windowed silhouette stack for model input visualization label: Classification label or None confidence: Classification confidence [0,1] or None fps: Current FPS diff --git a/pyproject.toml b/pyproject.toml index 2671f02..1f2f90f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,8 @@ name = "opengait" version = "0.0.0" requires-python = ">=3.10" dependencies = [ + "torch>=2.0", + "torchvision", "pyyaml", "tensorboard", "opencv-python", @@ -24,10 +26,6 @@ dependencies = [ ] [project.optional-dependencies] -torch = [ - "torch>=1.10", - "torchvision", -] parquet = [ "pyarrow", ] @@ -45,7 +43,6 @@ include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"] dev = [ "basedpyright>=1.38.1", "pytest", - "nats-py", "ultralytics", "jaxtyping", "beartype", diff --git a/uv.lock b/uv.lock index 977c9a7..0169875 100644 --- a/uv.lock +++ b/uv.lock @@ -1747,6 +1747,8 @@ dependencies = [ { name = "scikit-learn", version = "1.7.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, { name = "tensorboard" }, + { name = "torch" }, + { name = "torchvision" }, { name = "tqdm" }, ] @@ -1754,10 +1756,6 @@ dependencies = [ parquet = [ { name = "pyarrow" }, ] -torch = [ - { name = "torch" }, - { name = "torchvision" }, -] wandb = [ { name = "wandb" }, ] @@ -1769,7 +1767,6 @@ dev = [ { name = "click" }, { name = "jaxtyping", version = "0.3.7", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "jaxtyping", version = "0.3.9", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, - { name = "nats-py" }, { name = "pytest" }, { name = "ultralytics" }, ] @@ -1790,12 +1787,12 @@ requires-dist = [ { name = "pyyaml" }, { name = "scikit-learn" }, { name = "tensorboard" }, - { name = "torch", marker = "extra == 'torch'", specifier = ">=1.10" }, - { name = "torchvision", marker = "extra == 'torch'" }, + { name = "torch", specifier = ">=2.0" }, + { name = "torchvision" }, { name = "tqdm" }, { name = "wandb", marker = "extra == 'wandb'" }, ] -provides-extras = ["torch", "parquet", "wandb"] +provides-extras = ["parquet", "wandb"] [package.metadata.requires-dev] dev = [ @@ -1803,7 +1800,6 @@ dev = [ { name = "beartype" }, { name = "click" }, { name = "jaxtyping" }, - { name = "nats-py" }, { name = "pytest" }, { name = "ultralytics" }, ]