feat: extract opengait_studio monorepo module

Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
2026-03-03 17:16:17 +08:00
parent 5c6bef1ca1
commit 00fcda4fe3
39 changed files with 359 additions and 270 deletions
+1 -1
View File
@@ -57,7 +57,7 @@ CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \
Demo CLI entry:
```bash
uv run python -m opengait.demo --help
uv run python -m opengait_studio --help
```
## DDP Constraints (Important)
@@ -0,0 +1,7 @@
"""Module entry point: allows running the demo via ``python -m opengait_studio``."""

from __future__ import annotations

# Re-export the Click CLI entry point from the pipeline module.
from .pipeline import main

if __name__ == "__main__":
    main()
@@ -11,6 +11,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import nats
import sys
import threading
import time
@@ -81,7 +82,7 @@ class ConsolePublisher:
json_line = json.dumps(result, ensure_ascii=False, default=str)
_ = self._output.write(json_line + "\n")
self._output.flush()
except Exception as e:
except (OSError, ValueError, TypeError) as e:
logger.warning(f"Failed to publish to console: {e}")
def close(self) -> None:
@@ -173,7 +174,7 @@ class NatsPublisher:
self._thread = threading.Thread(target=run_loop, daemon=True)
self._thread.start()
return True
except Exception as e:
except (RuntimeError, OSError) as e:
logger.warning(f"Failed to start background event loop: {e}")
return False
@@ -204,7 +205,6 @@ class NatsPublisher:
return False
try:
import nats
async def _connect() -> _NatsClient:
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
@@ -219,12 +219,7 @@ class NatsPublisher:
self._connected = True
logger.info(f"Connected to NATS at {self._nats_url}")
return True
except ImportError:
logger.warning(
"nats-py package not installed. Install with: pip install nats-py"
)
return False
except Exception as e:
except (RuntimeError, OSError, TimeoutError) as e:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False
@@ -254,15 +249,27 @@ class NatsPublisher:
_ = await self._nc.publish(self._subject, payload)
_ = await self._nc.flush()
# Run publish in background loop
future = asyncio.run_coroutine_threadsafe(
_publish(),
self._loop, # pyright: ignore[reportArgumentType]
)
future.result(timeout=5.0) # Wait for publish to complete
except Exception as e:
logger.warning(f"Failed to publish to NATS: {e}")
self._connected = False # Mark for reconnection on next publish
def _on_done(publish_future: object) -> None:
fut = cast("asyncio.Future[None]", publish_future)
try:
exc = fut.exception()
except (RuntimeError, OSError) as callback_error:
logger.warning(f"NATS publish callback failed: {callback_error}")
self._connected = False
return
if exc is not None:
logger.warning(f"Failed to publish to NATS: {exc}")
self._connected = False
future.add_done_callback(_on_done)
except (RuntimeError, OSError, ValueError, TypeError) as e:
logger.warning(f"Failed to schedule NATS publish: {e}")
self._connected = False
def close(self) -> None:
"""Close NATS connection."""
@@ -279,7 +286,7 @@ class NatsPublisher:
self._loop,
)
future.result(timeout=5.0)
except Exception as e:
except (RuntimeError, OSError, TimeoutError) as e:
logger.debug(f"Error closing NATS connection: {e}")
finally:
self._nc = None
@@ -93,6 +93,19 @@ class _SelectedSilhouette(TypedDict):
track_id: int
class _VizPayload(TypedDict, total=False):
result: DemoResult
mask_raw: UInt8[ndarray, "h w"] | None
bbox: BBoxXYXY | None
bbox_mask: BBoxXYXY | None
silhouette: Float[ndarray, "64 44"] | None
segmentation_input: NDArray[np.float32] | None
track_id: int
label: str | None
confidence: float | None
pose: dict[str, object] | None
class _FramePacer:
_interval_ns: int
_next_emit_ns: int | None
@@ -131,7 +144,7 @@ class ScoliosisPipeline:
_result_export_format: str
_result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
_last_viz_payload: _VizPayload | None
_frame_pacer: _FramePacer | None
def __init__(
@@ -208,10 +221,9 @@ class ScoliosisPipeline:
return time.monotonic_ns()
@staticmethod
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32) # type: ignore[reportAssignmentType]
binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8)
return cast(UInt8[ndarray, "h w"], binary)
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
@@ -294,7 +306,7 @@ class ScoliosisPipeline:
self,
frame: UInt8[ndarray, "h w c"],
metadata: dict[str, object],
) -> dict[str, object] | None:
) -> _VizPayload | None:
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
timestamp_ns = self._extract_timestamp(metadata)
@@ -323,8 +335,8 @@ class ScoliosisPipeline:
bbox = selected["bbox_frame"]
bbox_mask = selected["bbox_mask"]
track_id = selected["track_id"]
pose_data = None
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only:
self._silhouette_buffer.append(
{
@@ -350,7 +362,9 @@ class ScoliosisPipeline:
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
@@ -364,9 +378,8 @@ class ScoliosisPipeline:
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify():
@@ -380,8 +393,8 @@ class ScoliosisPipeline:
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast(
tuple[str, float],
@@ -415,6 +428,7 @@ class ScoliosisPipeline:
"track_id": track_id,
"label": label,
"confidence": confidence,
"pose": pose_data,
}
def run(self) -> int:
@@ -445,7 +459,7 @@ class ScoliosisPipeline:
viz_payload = None
try:
viz_payload = self.process_frame(frame_u8, metadata)
except Exception as frame_error:
except (RuntimeError, ValueError, TypeError, OSError) as frame_error:
logger.warning(
"Skipping frame %d due to processing error: %s",
frame_idx,
@@ -457,8 +471,8 @@ class ScoliosisPipeline:
# Cache valid payload for no-detection frames
if viz_payload is not None:
# Cache a copy to prevent mutation of original data
viz_payload_dict = cast(dict[str, object], viz_payload)
cached: dict[str, object] = {}
viz_payload_dict = cast(_VizPayload, viz_payload)
cached: _VizPayload = {}
for k, v in viz_payload_dict.items():
copy_method = cast(
Callable[[], object] | None, getattr(v, "copy", None)
@@ -477,12 +491,12 @@ class ScoliosisPipeline:
viz_data["bbox_mask"] = None
viz_data["label"] = None
viz_data["confidence"] = None
viz_data["pose"] = None
else:
viz_data = None
if viz_data is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data)
viz_dict = cast(_VizPayload, viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask")
@@ -492,8 +506,8 @@ class ScoliosisPipeline:
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
pose_obj = viz_dict.get("pose")
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj)
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
@@ -504,6 +518,7 @@ class ScoliosisPipeline:
)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
pose_data = cast(dict[str, object] | None, pose_obj)
else:
# No detection and no cache - use default values
mask_raw = None
@@ -514,19 +529,37 @@ class ScoliosisPipeline:
segmentation_input = None
label = None
confidence = None
pose_data = None
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
)
# Try keyword arg for pose_data (backward compatible with old signatures)
try:
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
pose_data=pose_data,
)
except TypeError:
# Fallback for legacy visualizers that don't accept pose_data
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
@@ -635,7 +668,7 @@ class ScoliosisPipeline:
frames.append(item["frame"])
track_ids.append(item["track_id"])
timestamps.append(item["timestamp_ns"])
silhouette_array = cast(ndarray, item["silhouette"])
silhouette_array = cast(NDArray[np.float32], item["silhouette"])
silhouettes.append(silhouette_array.flatten().tolist())
table = pa.table(
@@ -830,6 +863,12 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
default=None,
help="Directory to save silhouette PNG visualizations.",
)
@click.option(
"--visualize",
is_flag=True,
default=False,
help="Enable real-time visualization.",
)
def main(
source: str,
checkpoint: str,
@@ -839,7 +878,7 @@ def main(
window: int,
stride: int,
window_mode: str,
target_fps: float | None,
target_fps: float,
no_target_fps: bool,
nats_url: str | None,
nats_subject: str,
@@ -850,7 +889,10 @@ def main(
result_export_path: str | None,
result_export_format: str,
silhouette_visualize_dir: str | None,
visualize: bool,
) -> None:
# Resolve effective target_fps: respect --no-target_fps to disable pacing
effective_target_fps = None if no_target_fps else target_fps
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
@@ -884,7 +926,6 @@ def main(
yolo_model=yolo_model,
window=window,
stride=effective_stride,
target_fps=None if no_target_fps else target_fps,
nats_url=nats_url,
nats_subject=nats_subject,
max_frames=max_frames,
@@ -894,6 +935,8 @@ def main(
silhouette_visualize_dir=silhouette_visualize_dir,
result_export_path=result_export_path,
result_export_format=result_export_format,
visualize=visualize,
target_fps=effective_target_fps,
)
raise SystemExit(pipeline.run())
except ValueError as err:
@@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override
import torch
@@ -13,10 +12,6 @@ from jaxtyping import Float
import jaxtyping
from torch import Tensor
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
HorizontalPoolingPyramid,
@@ -41,10 +41,54 @@ COLOR_BLACK = (0, 0, 0)
COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255)
# Type alias for image arrays (NDArray or cv2.Mat)
# Extra BGR colors used by the pose overlay drawing helpers.
COLOR_CYAN = (255, 255, 0)
COLOR_ORANGE = (0, 165, 255)
COLOR_MAGENTA = (255, 0, 255)

# Alias for uint8 image buffers (plain NDArray or cv2.Mat).
ImageArray = NDArray[np.uint8]

# COCO 17-keypoint skeleton limbs, as (start index, end index) pairs.
# Index map: 0 nose, 1/2 eyes, 3/4 ears, 5/6 shoulders, 7/8 elbows,
# 9/10 wrists, 11/12 hips, 13/14 knees, 15/16 ankles.
SKELETON_CONNECTIONS: list[tuple[int, int]] = [
    (0, 1),    # nose - left_eye
    (0, 2),    # nose - right_eye
    (1, 3),    # left_eye - left_ear
    (2, 4),    # right_eye - right_ear
    (5, 6),    # shoulder - shoulder
    (5, 7),    # left_shoulder - left_elbow
    (7, 9),    # left_elbow - left_wrist
    (6, 8),    # right_shoulder - right_elbow
    (8, 10),   # right_elbow - right_wrist
    (11, 12),  # hip - hip
    (5, 11),   # left_shoulder - left_hip
    (6, 12),   # right_shoulder - right_hip
    (11, 13),  # left_hip - left_knee
    (13, 15),  # left_knee - left_ankle
    (12, 14),  # right_hip - right_knee
    (14, 16),  # right_knee - right_ankle
]

# Human-readable names for the 17 COCO keypoints, index-aligned with the
# skeleton indices above.
KEYPOINT_NAMES: list[str] = [
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]

# Joint triples (a, b, c): the angle of interest is at vertex b.
# Used for scoliosis / gait analysis overlays.
ANGLE_JOINTS: list[tuple[int, int, int]] = [
    (5, 7, 9),     # left elbow angle
    (6, 8, 10),    # right elbow angle
    (7, 5, 11),    # left shoulder angle
    (8, 6, 12),    # right shoulder angle
    (5, 11, 13),   # left hip angle
    (6, 12, 14),   # right hip angle
    (11, 13, 15),  # left knee angle
    (12, 14, 16),  # right knee angle
]
class OpenCVVisualizer:
def __init__(self) -> None:
@@ -149,6 +193,134 @@ class OpenCVVisualizer:
thickness,
)
def _draw_pose_skeleton(
    self,
    frame: ImageArray,
    pose_data: dict[str, object] | None,
) -> None:
    """Draw a COCO-17 pose skeleton on *frame* in place.

    Args:
        frame: Input frame (H, W, 3) uint8 - modified in place.
        pose_data: Optional pose dictionary. Expected keys:
            'keypoints': sequence of (x, y) pairs in COCO-17 order;
            'confidence': optional per-keypoint scores (defaults to 1.0
            for every keypoint when absent).
            Keypoints and limbs below confidence 0.3 are skipped.
    """
    if pose_data is None:
        return

    keypoints_obj = pose_data.get('keypoints')
    if keypoints_obj is None:
        return

    keypoints = np.asarray(keypoints_obj, dtype=np.float32)
    if keypoints.size == 0:
        return

    h, w = frame.shape[:2]

    confidence_obj = pose_data.get('confidence')
    confidences = (
        np.asarray(confidence_obj, dtype=np.float32)
        if confidence_obj is not None
        else np.ones(len(keypoints), dtype=np.float32)
    )
    # Guard against a confidence array shorter than the keypoints so the
    # index checks below cannot raise IndexError.
    n_valid = min(len(keypoints), len(confidences))

    def _clipped(idx: int) -> tuple[int, int]:
        # Clip a keypoint to frame bounds as integer pixel coordinates.
        x = max(0, min(w - 1, int(keypoints[idx][0])))
        y = max(0, min(h - 1, int(keypoints[idx][1])))
        return (x, y)

    # Limbs: draw only when both endpoints pass the confidence threshold.
    for idx1, idx2 in SKELETON_CONNECTIONS:
        if idx1 < n_valid and idx2 < n_valid:
            if confidences[idx1] > 0.3 and confidences[idx2] > 0.3:
                _ = cv2.line(frame, _clipped(idx1), _clipped(idx2), COLOR_CYAN, 2)

    # Keypoints: filled magenta dot with a white outline.
    for i in range(n_valid):
        if confidences[i] > 0.3:
            pt = _clipped(i)
            _ = cv2.circle(frame, pt, 4, COLOR_MAGENTA, -1)
            _ = cv2.circle(frame, pt, 4, COLOR_WHITE, 1)
def _draw_pose_angles(
    self,
    frame: ImageArray,
    pose_data: dict[str, object] | None,
) -> None:
    """Overlay joint angles as text in the top-right corner of *frame*.

    Args:
        frame: Input frame (H, W, 3) uint8 - modified in place.
        pose_data: Optional pose dictionary with an 'angles' key mapping
            joint names to angles in degrees. Angles outside [0, 180] are
            ignored; at most 8 rows are displayed, sorted by joint name.
    """
    if pose_data is None:
        return

    angles_obj = pose_data.get('angles')
    if angles_obj is None:
        return

    angles = cast(dict[str, float], angles_obj)
    if not angles:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.45
    thickness = 1
    line_height = 20
    margin = 10
    w = frame.shape[1]  # only the width is needed for right alignment

    # Keep only plausible angle values (0-180 degrees), sorted by joint
    # name so the overlay ordering is stable between frames.
    angle_texts: list[tuple[str, float]] = sorted(
        (str(name), float(angle))
        for name, angle in angles.items()
        if 0 <= angle <= 180
    )

    for i, (name, angle) in enumerate(angle_texts[:8]):  # cap at 8 rows
        text = f"{name}: {angle:.1f}"
        (text_width, text_height), _ = cv2.getTextSize(
            text, font, font_scale, thickness
        )
        x_pos = w - margin - text_width - 10
        y_pos = margin + (i + 1) * line_height

        # Black background box keeps the orange text legible on any frame.
        _ = cv2.rectangle(
            frame,
            (x_pos - 4, y_pos - text_height - 4),
            (x_pos + text_width + 4, y_pos + 4),
            COLOR_BLACK,
            -1,
        )
        _ = cv2.putText(
            frame,
            text,
            (x_pos, y_pos),
            font,
            font_scale,
            COLOR_ORANGE,
            thickness,
        )
def _prepare_main_frame(
self,
frame: ImageArray,
@@ -157,6 +329,7 @@ class OpenCVVisualizer:
fps: float,
label: str | None,
confidence: float | None,
pose_data: dict[str, object] | None = None,
) -> ImageArray:
"""Prepare main display frame with bbox and text overlay.
@@ -167,6 +340,7 @@ class OpenCVVisualizer:
fps: Current FPS
label: Classification label or None
confidence: Classification confidence or None
pose_data: Pose data dictionary or None
Returns:
Processed frame ready for display
@@ -187,6 +361,10 @@ class OpenCVVisualizer:
self._draw_bbox(display_frame, bbox)
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
# Draw pose skeleton and angles if available
self._draw_pose_skeleton(display_frame, pose_data)
self._draw_pose_angles(display_frame, pose_data)
return display_frame
def _upscale_silhouette(
@@ -521,6 +699,7 @@ class OpenCVVisualizer:
label: str | None,
confidence: float | None,
fps: float,
pose_data: dict[str, object] | None = None,
) -> bool:
"""Update visualization with new frame data.
@@ -533,6 +712,7 @@ class OpenCVVisualizer:
label: Classification label or None
confidence: Classification confidence [0,1] or None
fps: Current FPS
pose_data: Pose data dictionary or None
Returns:
False if user requested quit (pressed 'q'), True otherwise
@@ -541,7 +721,7 @@ class OpenCVVisualizer:
# Prepare and show main window
main_display = self._prepare_main_frame(
frame, bbox, track_id, fps, label, confidence
frame, bbox, track_id, fps, label, confidence, pose_data
)
cv2.imshow(MAIN_WINDOW, main_display)
+1 -1
View File
@@ -1,7 +1,7 @@
import math
import random
import numpy as np
from utils import get_msg_mgr
from opengait.utils import get_msg_mgr
class CollateFn(object):
+1 -1
View File
@@ -3,7 +3,7 @@ import pickle
import os.path as osp
import torch.utils.data as tordata
import json
from utils import get_msg_mgr
from opengait.utils import get_msg_mgr
class DataSet(tordata.Dataset):
+1 -1
View File
@@ -4,7 +4,7 @@ import torchvision.transforms as T
import cv2
import math
from data import transform as base_transform
from utils import is_list, is_dict, get_valid_args
from opengait.utils import is_list, is_dict, get_valid_args
class NoOperation():
-149
View File
@@ -1,149 +0,0 @@
from __future__ import annotations
import argparse
import logging
import sys
from typing import cast
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
def _positive_float(value: str) -> float:
parsed = float(value)
if parsed <= 0:
raise argparse.ArgumentTypeError("target-fps must be positive")
return parsed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")
parser.add_argument(
"--source", type=str, required=True, help="Video source path or camera ID"
)
parser.add_argument(
"--checkpoint", type=str, required=True, help="Model checkpoint path"
)
parser.add_argument(
"--config",
type=str,
default="configs/sconet/sconet_scoliosis1k.yaml",
help="Config file path",
)
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
parser.add_argument(
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
)
parser.add_argument(
"--window", type=int, default=30, help="Window size for classification"
)
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
parser.add_argument(
"--target-fps",
type=_positive_float,
default=15.0,
help="Target FPS for temporal downsampling before windowing",
)
parser.add_argument(
"--window-mode",
type=str,
choices=["manual", "sliding", "chunked"],
default="manual",
help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
)
parser.add_argument(
"--no-target-fps",
action="store_true",
help="Disable temporal downsampling and use all frames",
)
parser.add_argument(
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
)
parser.add_argument(
"--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
)
parser.add_argument(
"--max-frames", type=int, default=None, help="Maximum frames to process"
)
parser.add_argument(
"--preprocess-only", action="store_true", help="Only preprocess silhouettes"
)
parser.add_argument(
"--silhouette-export-path",
type=str,
default=None,
help="Path to export silhouettes",
)
parser.add_argument(
"--silhouette-export-format", type=str, default="pickle", help="Export format"
)
parser.add_argument(
"--silhouette-visualize-dir",
type=str,
default=None,
help="Directory for silhouette visualizations",
)
parser.add_argument(
"--result-export-path", type=str, default=None, help="Path to export results"
)
parser.add_argument(
"--result-export-format", type=str, default="json", help="Result export format"
)
parser.add_argument(
"--visualize", action="store_true", help="Enable real-time visualization"
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Validate preprocess-only mode requires silhouette export path
if args.preprocess_only and not args.silhouette_export_path:
print(
"Error: --silhouette-export-path is required when using --preprocess-only",
file=sys.stderr,
)
raise SystemExit(2)
try:
# Import here to avoid circular imports
from .pipeline import validate_runtime_inputs
validate_runtime_inputs(
source=args.source, checkpoint=args.checkpoint, config=args.config
)
effective_stride = resolve_stride(
window=cast(int, args.window),
stride=cast(int, args.stride),
window_mode=cast(WindowMode, args.window_mode),
)
pipeline = ScoliosisPipeline(
source=cast(str, args.source),
checkpoint=cast(str, args.checkpoint),
config=cast(str, args.config),
device=cast(str, args.device),
yolo_model=cast(str, args.yolo_model),
window=cast(int, args.window),
stride=effective_stride,
target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
nats_url=cast(str | None, args.nats_url),
nats_subject=cast(str, args.nats_subject),
max_frames=cast(int | None, args.max_frames),
preprocess_only=cast(bool, args.preprocess_only),
silhouette_export_path=cast(str | None, args.silhouette_export_path),
silhouette_export_format=cast(str, args.silhouette_export_format),
silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
result_export_path=cast(str | None, args.result_export_path),
result_export_format=cast(str, args.result_export_format),
visualize=cast(bool, args.visualize),
)
raise SystemExit(pipeline.run())
except ValueError as err:
print(f"Error: {err}", file=sys.stderr)
raise SystemExit(2) from err
except RuntimeError as err:
print(f"Runtime error: {err}", file=sys.stderr)
raise SystemExit(1) from err
+1 -1
View File
@@ -1,7 +1,7 @@
import os
from time import strftime, localtime
import numpy as np
from utils import get_msg_mgr, mkdir
from opengait.utils import get_msg_mgr, mkdir
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
from .re_rank import re_ranking
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np
import torch.nn.functional as F
from utils import is_tensor
from opengait.utils import is_tensor
def cuda_dist(x, y, metric='euc'):
+1 -1
View File
@@ -4,7 +4,7 @@ import argparse
import torch
import torch.nn as nn
from modeling import models
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument('--local_rank', type=int, default=0,
+4 -4
View File
@@ -28,11 +28,11 @@ from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import Odict, mkdir, ddp_all_gather
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from utils import NoOp
from utils import get_msg_mgr
from opengait.utils import NoOp
from opengait.utils import get_msg_mgr
__all__ = ['BaseModel']
+3 -3
View File
@@ -3,9 +3,9 @@
import torch
import torch.nn as nn
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr
from opengait.utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from opengait.utils import Odict
from opengait.utils import get_msg_mgr
class LossAggregator(nn.Module):
+2 -2
View File
@@ -1,9 +1,9 @@
from ctypes import ArgumentError
import torch.nn as nn
import torch
from utils import Odict
from opengait.utils import Odict
import functools
from utils import ddp_all_gather
from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(func):
@@ -132,7 +132,7 @@ class Post_ResNet9(ResNet):
return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones
class Baseline(nn.Module):
def __init__(self, model_cfg):
+1 -1
View File
@@ -4,7 +4,7 @@ from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc
import torch.optim as optim
from einops import rearrange
from utils import get_valid_args
from opengait.utils import get_valid_args
import warnings
import random
from torchvision.utils import flow_to_image
@@ -161,7 +161,7 @@ class Post_ResNet9(ResNet):
return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones
class GaitBaseFusion_denoise(nn.Module):
def __init__(self, model_cfg):
+1 -1
View File
@@ -6,7 +6,7 @@ from ..base_model import BaseModel
from .gaitgl import GaitGL
from ..modules import GaitAlign
from torchvision.transforms import Resize
from utils import get_valid_args, get_attr_from, is_list_or_tuple
from opengait.utils import get_valid_args, get_attr_from, is_list_or_tuple
import os.path as osp
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
from utils import clones
from opengait.utils import clones
class BasicConv1d(nn.Module):
+2 -2
View File
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
from utils import np2var, list2var, get_valid_args, ddp_all_gather
from opengait.utils import np2var, list2var, get_valid_args, ddp_all_gather
from data.transform import get_transform
from einops import rearrange
@@ -143,7 +143,7 @@ class GaitSSB_Pretrain(BaseModel):
import torch.optim as optim
import numpy as np
from utils import get_valid_args, list2var
from opengait.utils import get_valid_args, list2var
class no_grad(torch.no_grad):
def __init__(self, enable=True):
+1 -1
View File
@@ -782,7 +782,7 @@ from ..modules import BasicBlock2D, BasicBlockP3D
import torch.optim as optim
import os.path as osp
from collections import OrderedDict
from utils import get_valid_args, get_attr_from
from opengait.utils import get_valid_args, get_attr_from
class SwinGait(BaseModel):
def __init__(self, cfgs, training):
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
from opengait.utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign
+5 -1
View File
@@ -20,6 +20,7 @@ dependencies = [
"scikit-learn",
"matplotlib",
"cvmmap-client",
"nats-py",
]
[project.optional-dependencies]
@@ -35,7 +36,10 @@ wandb = [
]
[tool.setuptools]
packages = ["opengait"]
[tool.setuptools.packages.find]
where = [".", "opengait-studio"]
include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"]
[dependency-groups]
dev = [
@@ -251,7 +251,7 @@ class TestNatsPublisherIntegration:
except ImportError:
pytest.skip("nats-py not installed")
from opengait.demo.output import NatsPublisher, create_result
from opengait_studio.output import NatsPublisher, create_result
# Create publisher
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
@@ -341,7 +341,7 @@ class TestNatsPublisherIntegration:
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
"""Test that publisher handles missing server gracefully."""
try:
from opengait.demo.output import NatsPublisher
from opengait_studio.output import NatsPublisher
except ImportError:
pytest.skip("output module not available")
@@ -380,7 +380,7 @@ class TestNatsPublisherIntegration:
import asyncio
import nats # type: ignore[import-untyped]
from opengait.demo.output import NatsPublisher, create_result
from opengait_studio.output import NatsPublisher, create_result
except ImportError as e:
pytest.skip(f"Required module not available: {e}")
@@ -15,7 +15,7 @@ from numpy.typing import NDArray
import pytest
import torch
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
@@ -31,7 +31,7 @@ def _device_for_runtime() -> str:
def _run_pipeline_cli(
*args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
command = [sys.executable, "-m", "opengait.demo", *args]
command = [sys.executable, "-m", "opengait_studio", *args]
return subprocess.run(
command,
cwd=REPO_ROOT,
@@ -728,14 +728,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
This is a regression test for the window freeze issue when no person is detected.
The window should refresh every frame to prevent freezing.
"""
from opengait.demo.pipeline import ScoliosisPipeline
from opengait_studio.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies
with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
mock.patch("opengait_studio.pipeline.create_source") as mock_source,
mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
):
# Setup mock detector that returns no detections (causing process_frame to return None)
mock_detector = mock.MagicMock()
@@ -791,16 +791,16 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
from opengait.demo.pipeline import ScoliosisPipeline
from opengait_studio.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies
with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
mock.patch("opengait_studio.pipeline.create_source") as mock_source,
mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
):
# Create mock detection result for frames 0-1 (valid detection)
mock_box = mock.MagicMock()
@@ -919,7 +919,7 @@ def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
def test_frame_pacer_emission_count_24_to_15() -> None:
from opengait.demo.pipeline import _FramePacer
from opengait_studio.pipeline import _FramePacer
pacer = _FramePacer(15.0)
interval_ns = int(1_000_000_000 / 24)
@@ -928,7 +928,7 @@ def test_frame_pacer_emission_count_24_to_15() -> None:
def test_frame_pacer_requires_positive_target_fps() -> None:
from opengait.demo.pipeline import _FramePacer
from opengait_studio.pipeline import _FramePacer
with pytest.raises(ValueError, match="target_fps must be positive"):
_FramePacer(0.0)
@@ -950,6 +950,6 @@ def test_resolve_stride_modes(
mode: Literal["manual", "sliding", "chunked"],
expected: int,
) -> None:
from opengait.demo.pipeline import resolve_stride
from opengait_studio.pipeline import resolve_stride
assert resolve_stride(window, stride, mode) == expected
@@ -8,7 +8,7 @@ import pytest
from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError
from opengait.demo.preprocess import mask_to_silhouette
from opengait_studio.preprocess import mask_to_silhouette
class TestMaskToSilhouette:
@@ -18,7 +18,7 @@ import torch
from torch import Tensor
if TYPE_CHECKING:
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
# Constants for test configuration
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@@ -27,7 +27,7 @@ CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@pytest.fixture
def demo() -> "ScoNetDemo":
"""Create ScoNetDemo without loading checkpoint (CPU-only)."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
return ScoNetDemo(
cfg_path=str(CONFIG_PATH),
@@ -71,7 +71,7 @@ class TestScoNetDemoConstruction:
def test_construction_from_config_no_checkpoint(self) -> None:
"""Test construction with config only, no checkpoint."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
demo = ScoNetDemo(
cfg_path=str(CONFIG_PATH),
@@ -87,7 +87,7 @@ class TestScoNetDemoConstruction:
def test_construction_with_relative_path(self) -> None:
"""Test construction handles relative config path correctly."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
demo = ScoNetDemo(
cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
@@ -100,7 +100,7 @@ class TestScoNetDemoConstruction:
def test_construction_invalid_config_raises(self) -> None:
"""Test construction raises with invalid config path."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
with pytest.raises((FileNotFoundError, TypeError)):
_ = ScoNetDemo(
@@ -180,7 +180,7 @@ class TestScoNetDemoPredict:
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None:
"""Test predict returns (str, float) tuple with valid label."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
result_raw = demo.predict(dummy_sils_single)
result = cast(tuple[str, float], result_raw)
@@ -217,7 +217,7 @@ class TestScoNetDemoNoDDP:
def test_no_distributed_init_in_construction(self) -> None:
"""Test that construction does not call torch.distributed."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
with patch("torch.distributed.is_initialized") as mock_is_init:
with patch("torch.distributed.init_process_group") as mock_init_pg:
@@ -327,14 +327,14 @@ class TestScoNetDemoLabelMap:
def test_label_map_has_three_classes(self) -> None:
"""Test LABEL_MAP has exactly 3 classes."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
assert len(ScoNetDemo.LABEL_MAP) == 3
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
def test_label_map_values_are_valid_strings(self) -> None:
"""Test LABEL_MAP values are valid non-empty strings."""
from opengait.demo.sconet_demo import ScoNetDemo
from opengait_studio.sconet_demo import ScoNetDemo
for value in ScoNetDemo.LABEL_MAP.values():
assert isinstance(value, str)
@@ -7,14 +7,14 @@ from unittest import mock
import numpy as np
import pytest
from opengait.demo.input import create_source
from opengait.demo.visualizer import (
from opengait_studio.input import create_source
from opengait_studio.visualizer import (
DISPLAY_HEIGHT,
DISPLAY_WIDTH,
ImageArray,
OpenCVVisualizer,
)
from opengait.demo.window import select_person
from opengait_studio.window import select_person
REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
@@ -8,7 +8,7 @@ import pytest
import torch
from numpy.typing import NDArray
from opengait.demo.window import SilhouetteWindow, select_person
from opengait_studio.window import SilhouetteWindow, select_person
class TestSilhouetteWindow:
Generated
+9 -7
View File
@@ -614,7 +614,7 @@ name = "cuda-bindings"
version = "12.9.4"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "cuda-pathfinder", marker = "sys_platform == 'linux'" },
{ name = "cuda-pathfinder", marker = "sys_platform != 'win32'" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" },
@@ -1611,7 +1611,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@@ -1622,7 +1622,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@@ -1649,9 +1649,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform != 'win32'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@@ -1662,7 +1662,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
]
wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@@ -1737,6 +1737,7 @@ dependencies = [
{ name = "imageio" },
{ name = "kornia" },
{ name = "matplotlib" },
{ name = "nats-py" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
{ name = "opencv-python" },
@@ -1780,6 +1781,7 @@ requires-dist = [
{ name = "imageio" },
{ name = "kornia" },
{ name = "matplotlib" },
{ name = "nats-py" },
{ name = "numpy" },
{ name = "opencv-python" },
{ name = "pillow" },
+1 -1
View File
@@ -1,4 +1,4 @@
uv run python -m opengait.demo \
uv run python -m opengait_studio \
--source "cvmmap://camera_5602" \
--checkpoint "ckpt/ScoNet-20000.pt" \
--config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \