fix(studio): harden runtime integration and dependency defaults

Stabilize studio publish/visualization flow and tighten export behavior while aligning project dependencies with the monorepo runtime expectations.
This commit is contained in:
2026-03-03 17:59:56 +08:00
parent 00fcda4fe3
commit 967a10c10e
7 changed files with 122 additions and 87 deletions
+2 -2
View File
@@ -11,7 +11,7 @@ Use it as the default playbook for commands, conventions, and safety checks.
- Package/runtime tool: `uv` - Package/runtime tool: `uv`
Critical source-of-truth rule: Critical source-of-truth rule:
- `opengait/demo` is an implementation layer and may contain project-specific behavior. - `opengait-studio/opengait_studio` is an implementation layer and may contain project-specific behavior.
- When asked to “refer to the paper” or verify methodology, use the paper and official citations as ground truth. - When asked to “refer to the paper” or verify methodology, use the paper and official citations as ground truth.
- Do not treat demo/runtime behavior as proof of paper method unless explicitly cited by the paper. - Do not treat demo/runtime behavior as proof of paper method unless explicitly cited by the paper.
@@ -20,7 +20,7 @@ Critical source-of-truth rule:
Install dependencies with uv: Install dependencies with uv:
```bash ```bash
uv sync --extra torch uv sync
``` ```
Notes from `pyproject.toml`: Notes from `pyproject.toml`:
+1 -1
View File
@@ -79,7 +79,7 @@ See [here](https://github.com/jdyjjj/All-in-One-Gait) for details.
### Quick Start (uv) ### Quick Start (uv)
```bash ```bash
# Install dependencies # Install dependencies
uv sync --extra torch uv sync
# Train # Train
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch --nproc_per_node=2 opengait/main.py --cfgs ./configs/baseline/baseline.yaml --phase train CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch --nproc_per_node=2 opengait/main.py --cfgs ./configs/baseline/baseline.yaml --phase train
+54 -8
View File
@@ -9,6 +9,7 @@ Provides pluggable result publishing:
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from concurrent.futures import CancelledError, Future
import json import json
import logging import logging
import nats import nats
@@ -181,10 +182,38 @@ class NatsPublisher:
def _stop_background_loop(self) -> None: def _stop_background_loop(self) -> None:
"""Stop the background event loop and thread.""" """Stop the background event loop and thread."""
with self._lock: with self._lock:
if self._loop is not None and self._loop.is_running(): loop = self._loop
_ = self._loop.call_soon_threadsafe(self._loop.stop) thread = self._thread
if self._thread is not None and self._thread.is_alive():
self._thread.join(timeout=2.0) if loop is not None and loop.is_running():
try:
async def _cancel_pending_tasks() -> None:
current = asyncio.current_task()
pending = [
task
for task in asyncio.all_tasks()
if task is not current and not task.done()
]
for task in pending:
_ = task.cancel()
if pending:
_ = await asyncio.gather(*pending, return_exceptions=True)
cancel_future = asyncio.run_coroutine_threadsafe(
_cancel_pending_tasks(),
loop,
)
cancel_future.result(timeout=2.0)
except (RuntimeError, OSError, TimeoutError, CancelledError):
pass
finally:
_ = loop.call_soon_threadsafe(loop.stop)
if thread is not None and thread.is_alive():
thread.join(timeout=2.0)
with self._lock:
self._loop = None self._loop = None
self._thread = None self._thread = None
@@ -204,7 +233,12 @@ class NatsPublisher:
if not self._start_background_loop(): if not self._start_background_loop():
return False return False
future: Future[_NatsClient] | None = None
try: try:
loop = self._loop
if loop is None:
logger.warning("Background event loop unavailable for NATS connection")
return False
async def _connect() -> _NatsClient: async def _connect() -> _NatsClient:
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType] nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
@@ -213,12 +247,21 @@ class NatsPublisher:
# Run connection in background loop # Run connection in background loop
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(
_connect(), _connect(),
self._loop, # pyright: ignore[reportArgumentType] loop,
) )
self._nc = future.result(timeout=10.0) self._nc = future.result(timeout=10.0)
self._connected = True self._connected = True
logger.info(f"Connected to NATS at {self._nats_url}") logger.info(f"Connected to NATS at {self._nats_url}")
return True return True
except TimeoutError as e:
if future is not None:
_ = future.cancel()
try:
_ = future.result(timeout=1.0)
except (TimeoutError, CancelledError, RuntimeError, OSError):
pass
logger.warning("Timed out connecting to NATS at %s: %s", self._nats_url, e)
return False
except (RuntimeError, OSError, TimeoutError) as e: except (RuntimeError, OSError, TimeoutError) as e:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False return False
@@ -260,16 +303,19 @@ class NatsPublisher:
exc = fut.exception() exc = fut.exception()
except (RuntimeError, OSError) as callback_error: except (RuntimeError, OSError) as callback_error:
logger.warning(f"NATS publish callback failed: {callback_error}") logger.warning(f"NATS publish callback failed: {callback_error}")
self._connected = False with self._lock:
self._connected = False
return return
if exc is not None: if exc is not None:
logger.warning(f"Failed to publish to NATS: {exc}") logger.warning(f"Failed to publish to NATS: {exc}")
self._connected = False with self._lock:
self._connected = False
future.add_done_callback(_on_done) future.add_done_callback(_on_done)
except (RuntimeError, OSError, ValueError, TypeError) as e: except (RuntimeError, OSError, ValueError, TypeError) as e:
logger.warning(f"Failed to schedule NATS publish: {e}") logger.warning(f"Failed to schedule NATS publish: {e}")
self._connected = False with self._lock:
self._connected = False
def close(self) -> None: def close(self) -> None:
"""Close NATS connection.""" """Close NATS connection."""
+37 -7
View File
@@ -1,7 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import copy
from contextlib import suppress from contextlib import suppress
import inspect
import logging import logging
from pathlib import Path from pathlib import Path
import time import time
@@ -146,6 +148,8 @@ class ScoliosisPipeline:
_visualizer: OpenCVVisualizer | None _visualizer: OpenCVVisualizer | None
_last_viz_payload: _VizPayload | None _last_viz_payload: _VizPayload | None
_frame_pacer: _FramePacer | None _frame_pacer: _FramePacer | None
_visualizer_accepts_pose_data: bool | None
_visualizer_signature_owner: object | None
def __init__( def __init__(
self, self,
@@ -205,6 +209,31 @@ class ScoliosisPipeline:
self._visualizer = None self._visualizer = None
self._last_viz_payload = None self._last_viz_payload = None
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
self._visualizer_accepts_pose_data = None
self._visualizer_signature_owner = None
def _detect_visualizer_pose_kwarg(self) -> bool:
visualizer = self._visualizer
if visualizer is None:
return False
if (
self._visualizer_signature_owner is visualizer
and self._visualizer_accepts_pose_data is not None
):
return self._visualizer_accepts_pose_data
update_fn = getattr(visualizer, "update", None)
if update_fn is None or not callable(update_fn):
self._visualizer_signature_owner = visualizer
self._visualizer_accepts_pose_data = False
return False
try:
signature = inspect.signature(update_fn)
accepts_pose_data = "pose_data" in signature.parameters
except (ValueError, TypeError):
accepts_pose_data = False
self._visualizer_signature_owner = visualizer
self._visualizer_accepts_pose_data = accepts_pose_data
return accepts_pose_data
@staticmethod @staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -459,7 +488,7 @@ class ScoliosisPipeline:
viz_payload = None viz_payload = None
try: try:
viz_payload = self.process_frame(frame_u8, metadata) viz_payload = self.process_frame(frame_u8, metadata)
except (RuntimeError, ValueError, TypeError, OSError) as frame_error: except (RuntimeError, ValueError, OSError) as frame_error:
logger.warning( logger.warning(
"Skipping frame %d due to processing error: %s", "Skipping frame %d due to processing error: %s",
frame_idx, frame_idx,
@@ -474,6 +503,9 @@ class ScoliosisPipeline:
viz_payload_dict = cast(_VizPayload, viz_payload) viz_payload_dict = cast(_VizPayload, viz_payload)
cached: _VizPayload = {} cached: _VizPayload = {}
for k, v in viz_payload_dict.items(): for k, v in viz_payload_dict.items():
if k == "pose" and isinstance(v, dict):
cached[k] = cast(dict[str, object], copy.deepcopy(v))
continue
copy_method = cast( copy_method = cast(
Callable[[], object] | None, getattr(v, "copy", None) Callable[[], object] | None, getattr(v, "copy", None)
) )
@@ -531,8 +563,7 @@ class ScoliosisPipeline:
confidence = None confidence = None
pose_data = None pose_data = None
# Try keyword arg for pose_data (backward compatible with old signatures) if self._detect_visualizer_pose_kwarg():
try:
keep_running = self._visualizer.update( keep_running = self._visualizer.update(
frame_u8, frame_u8,
bbox, bbox,
@@ -546,8 +577,7 @@ class ScoliosisPipeline:
ema_fps, ema_fps,
pose_data=pose_data, pose_data=pose_data,
) )
except TypeError: else:
# Fallback for legacy visualizers that don't accept pose_data
keep_running = self._visualizer.update( keep_running = self._visualizer.update(
frame_u8, frame_u8,
bbox, bbox,
@@ -676,7 +706,7 @@ class ScoliosisPipeline:
"frame": pa.array(frames, type=pa.int64()), "frame": pa.array(frames, type=pa.int64()),
"track_id": pa.array(track_ids, type=pa.int64()), "track_id": pa.array(track_ids, type=pa.int64()),
"timestamp_ns": pa.array(timestamps, type=pa.int64()), "timestamp_ns": pa.array(timestamps, type=pa.int64()),
"silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())), "silhouette": pa.array(silhouettes, type=pa.list_(pa.float32())),
} }
) )
@@ -746,7 +776,7 @@ class ScoliosisPipeline:
track_ids.append(result["track_id"]) track_ids.append(result["track_id"])
labels.append(result["label"]) labels.append(result["label"])
confidences.append(result["confidence"]) confidences.append(result["confidence"])
windows.append(result["window"]) windows.append(int(result["window"]))
timestamps.append(result["timestamp_ns"]) timestamps.append(result["timestamp_ns"])
table = pa.table( table = pa.table(
+21 -55
View File
@@ -50,45 +50,24 @@ ImageArray = NDArray[np.uint8]
# COCO-format skeleton connections (17 keypoints) # COCO-format skeleton connections (17 keypoints)
# Connections are pairs of keypoint indices # Connections are pairs of keypoint indices
SKELETON_CONNECTIONS: list[tuple[int, int]] = [ SKELETON_CONNECTIONS: list[tuple[int, int]] = [
(0, 1), # nose -> left_eye (0, 1), # nose -> left_eye
(0, 2), # nose -> right_eye (0, 2), # nose -> right_eye
(1, 3), # left_eye -> left_ear (1, 3), # left_eye -> left_ear
(2, 4), # right_eye -> right_ear (2, 4), # right_eye -> right_ear
(5, 6), # left_shoulder -> right_shoulder (5, 6), # left_shoulder -> right_shoulder
(5, 7), # left_shoulder -> left_elbow (5, 7), # left_shoulder -> left_elbow
(7, 9), # left_elbow -> left_wrist (7, 9), # left_elbow -> left_wrist
(6, 8), # right_shoulder -> right_elbow (6, 8), # right_shoulder -> right_elbow
(8, 10), # right_elbow -> right_wrist (8, 10), # right_elbow -> right_wrist
(11, 12), # left_hip -> right_hip (11, 12), # left_hip -> right_hip
(5, 11), # left_shoulder -> left_hip (5, 11), # left_shoulder -> left_hip
(6, 12), # right_shoulder -> right_hip (6, 12), # right_shoulder -> right_hip
(11, 13), # left_hip -> left_knee (11, 13), # left_hip -> left_knee
(13, 15), # left_knee -> left_ankle (13, 15), # left_knee -> left_ankle
(12, 14), # right_hip -> right_knee (12, 14), # right_hip -> right_knee
(14, 16), # right_knee -> right_ankle (14, 16), # right_knee -> right_ankle
] ]
# Keypoint names for COCO format (17 keypoints)
KEYPOINT_NAMES: list[str] = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"
]
# Joints where angles are typically calculated (for scoliosis/ gait analysis)
ANGLE_JOINTS: list[tuple[int, int, int]] = [
(5, 7, 9), # left_shoulder -> left_elbow -> left_wrist
(6, 8, 10), # right_shoulder -> right_elbow -> right_wrist
(7, 5, 11), # left_elbow -> left_shoulder -> left_hip
(8, 6, 12), # right_elbow -> right_shoulder -> right_hip
(5, 11, 13), # left_shoulder -> left_hip -> left_knee
(6, 12, 14), # right_shoulder -> right_hip -> right_knee
(11, 13, 15),# left_hip -> left_knee -> left_ankle
(12, 14, 16),# right_hip -> right_knee -> right_ankle
]
class OpenCVVisualizer: class OpenCVVisualizer:
def __init__(self) -> None: def __init__(self) -> None:
@@ -210,7 +189,7 @@ class OpenCVVisualizer:
if pose_data is None: if pose_data is None:
return return
keypoints_obj = pose_data.get('keypoints') keypoints_obj = pose_data.get("keypoints")
if keypoints_obj is None: if keypoints_obj is None:
return return
@@ -222,7 +201,7 @@ class OpenCVVisualizer:
h, w = frame.shape[:2] h, w = frame.shape[:2]
# Get confidence scores if available # Get confidence scores if available
confidence_obj = pose_data.get('confidence') confidence_obj = pose_data.get("confidence")
confidences = ( confidences = (
np.asarray(confidence_obj, dtype=np.float32) np.asarray(confidence_obj, dtype=np.float32)
if confidence_obj is not None if confidence_obj is not None
@@ -267,7 +246,7 @@ class OpenCVVisualizer:
if pose_data is None: if pose_data is None:
return return
angles_obj = pose_data.get('angles') angles_obj = pose_data.get("angles")
if angles_obj is None: if angles_obj is None:
return return
@@ -467,12 +446,10 @@ class OpenCVVisualizer:
def _prepare_segmentation_view( def _prepare_segmentation_view(
self, self,
mask_raw: ImageArray | None, _mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None, silhouette: NDArray[np.float32] | None,
bbox: BBoxXYXY | None, _bbox: BBoxXYXY | None,
) -> ImageArray: ) -> ImageArray:
_ = mask_raw
_ = bbox
return self._prepare_normalized_view(silhouette) return self._prepare_normalized_view(silhouette)
def _fit_gray_to_display( def _fit_gray_to_display(
@@ -661,20 +638,7 @@ class OpenCVVisualizer:
y_pos = h - 8 y_pos = h - 8
y_top = max(0, h - MODE_LABEL_PAD) y_top = max(0, h - MODE_LABEL_PAD)
_ = cv2.rectangle( _ = cv2.rectangle(image, (0, y_top), (w, h), COLOR_DARK_GRAY, -1)
image,
(0, y_top),
(w, h),
COLOR_DARK_GRAY,
-1,
)
_ = cv2.rectangle(
image,
(x_pos - 6, y_pos - text_height - 6),
(x_pos + text_width + 8, y_pos + 6),
COLOR_DARK_GRAY,
-1,
)
# Draw text # Draw text
_ = cv2.putText( _ = cv2.putText(
@@ -706,9 +670,11 @@ class OpenCVVisualizer:
Args: Args:
frame: Input frame (H, W, C) uint8 frame: Input frame (H, W, C) uint8
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
bbox_mask: Bounding box in mask coordinates (x1, y1, x2, y2) or None
track_id: Tracking ID track_id: Tracking ID
mask_raw: Raw binary mask (H, W) uint8 or None mask_raw: Raw binary mask (H, W) uint8 or None
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
segmentation_input: Windowed silhouette stack for model input visualization
label: Classification label or None label: Classification label or None
confidence: Classification confidence [0,1] or None confidence: Classification confidence [0,1] or None
fps: Current FPS fps: Current FPS
+2 -5
View File
@@ -7,6 +7,8 @@ name = "opengait"
version = "0.0.0" version = "0.0.0"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"torch>=2.0",
"torchvision",
"pyyaml", "pyyaml",
"tensorboard", "tensorboard",
"opencv-python", "opencv-python",
@@ -24,10 +26,6 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
torch = [
"torch>=1.10",
"torchvision",
]
parquet = [ parquet = [
"pyarrow", "pyarrow",
] ]
@@ -45,7 +43,6 @@ include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"]
dev = [ dev = [
"basedpyright>=1.38.1", "basedpyright>=1.38.1",
"pytest", "pytest",
"nats-py",
"ultralytics", "ultralytics",
"jaxtyping", "jaxtyping",
"beartype", "beartype",
Generated
+5 -9
View File
@@ -1747,6 +1747,8 @@ dependencies = [
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
{ name = "tensorboard" }, { name = "tensorboard" },
{ name = "torch" },
{ name = "torchvision" },
{ name = "tqdm" }, { name = "tqdm" },
] ]
@@ -1754,10 +1756,6 @@ dependencies = [
parquet = [ parquet = [
{ name = "pyarrow" }, { name = "pyarrow" },
] ]
torch = [
{ name = "torch" },
{ name = "torchvision" },
]
wandb = [ wandb = [
{ name = "wandb" }, { name = "wandb" },
] ]
@@ -1769,7 +1767,6 @@ dev = [
{ name = "click" }, { name = "click" },
{ name = "jaxtyping", version = "0.3.7", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "jaxtyping", version = "0.3.7", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
{ name = "jaxtyping", version = "0.3.9", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, { name = "jaxtyping", version = "0.3.9", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
{ name = "nats-py" },
{ name = "pytest" }, { name = "pytest" },
{ name = "ultralytics" }, { name = "ultralytics" },
] ]
@@ -1790,12 +1787,12 @@ requires-dist = [
{ name = "pyyaml" }, { name = "pyyaml" },
{ name = "scikit-learn" }, { name = "scikit-learn" },
{ name = "tensorboard" }, { name = "tensorboard" },
{ name = "torch", marker = "extra == 'torch'", specifier = ">=1.10" }, { name = "torch", specifier = ">=2.0" },
{ name = "torchvision", marker = "extra == 'torch'" }, { name = "torchvision" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "wandb", marker = "extra == 'wandb'" }, { name = "wandb", marker = "extra == 'wandb'" },
] ]
provides-extras = ["torch", "parquet", "wandb"] provides-extras = ["parquet", "wandb"]
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
@@ -1803,7 +1800,6 @@ dev = [
{ name = "beartype" }, { name = "beartype" },
{ name = "click" }, { name = "click" },
{ name = "jaxtyping" }, { name = "jaxtyping" },
{ name = "nats-py" },
{ name = "pytest" }, { name = "pytest" },
{ name = "ultralytics" }, { name = "ultralytics" },
] ]