feat: extract opengait_studio monorepo module

Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
2026-03-03 17:16:17 +08:00
parent 5c6bef1ca1
commit 00fcda4fe3
39 changed files with 359 additions and 270 deletions
+1 -1
View File
@@ -57,7 +57,7 @@ CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \
Demo CLI entry: Demo CLI entry:
```bash ```bash
uv run python -m opengait.demo --help uv run python -m opengait_studio --help
``` ```
## DDP Constraints (Important) ## DDP Constraints (Important)
@@ -0,0 +1,7 @@
"""Package entry point: delegates to the pipeline CLI when run as a module."""

from __future__ import annotations

from .pipeline import main

# Allow execution via `python -m <package>`.
if __name__ == "__main__":
    main()
@@ -11,6 +11,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import nats
import sys import sys
import threading import threading
import time import time
@@ -81,7 +82,7 @@ class ConsolePublisher:
json_line = json.dumps(result, ensure_ascii=False, default=str) json_line = json.dumps(result, ensure_ascii=False, default=str)
_ = self._output.write(json_line + "\n") _ = self._output.write(json_line + "\n")
self._output.flush() self._output.flush()
except Exception as e: except (OSError, ValueError, TypeError) as e:
logger.warning(f"Failed to publish to console: {e}") logger.warning(f"Failed to publish to console: {e}")
def close(self) -> None: def close(self) -> None:
@@ -173,7 +174,7 @@ class NatsPublisher:
self._thread = threading.Thread(target=run_loop, daemon=True) self._thread = threading.Thread(target=run_loop, daemon=True)
self._thread.start() self._thread.start()
return True return True
except Exception as e: except (RuntimeError, OSError) as e:
logger.warning(f"Failed to start background event loop: {e}") logger.warning(f"Failed to start background event loop: {e}")
return False return False
@@ -204,7 +205,6 @@ class NatsPublisher:
return False return False
try: try:
import nats
async def _connect() -> _NatsClient: async def _connect() -> _NatsClient:
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType] nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
@@ -219,12 +219,7 @@ class NatsPublisher:
self._connected = True self._connected = True
logger.info(f"Connected to NATS at {self._nats_url}") logger.info(f"Connected to NATS at {self._nats_url}")
return True return True
except ImportError: except (RuntimeError, OSError, TimeoutError) as e:
logger.warning(
"nats-py package not installed. Install with: pip install nats-py"
)
return False
except Exception as e:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False return False
@@ -254,15 +249,27 @@ class NatsPublisher:
_ = await self._nc.publish(self._subject, payload) _ = await self._nc.publish(self._subject, payload)
_ = await self._nc.flush() _ = await self._nc.flush()
# Run publish in background loop
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(
_publish(), _publish(),
self._loop, # pyright: ignore[reportArgumentType] self._loop, # pyright: ignore[reportArgumentType]
) )
future.result(timeout=5.0) # Wait for publish to complete
except Exception as e: def _on_done(publish_future: object) -> None:
logger.warning(f"Failed to publish to NATS: {e}") fut = cast("asyncio.Future[None]", publish_future)
self._connected = False # Mark for reconnection on next publish try:
exc = fut.exception()
except (RuntimeError, OSError) as callback_error:
logger.warning(f"NATS publish callback failed: {callback_error}")
self._connected = False
return
if exc is not None:
logger.warning(f"Failed to publish to NATS: {exc}")
self._connected = False
future.add_done_callback(_on_done)
except (RuntimeError, OSError, ValueError, TypeError) as e:
logger.warning(f"Failed to schedule NATS publish: {e}")
self._connected = False
def close(self) -> None: def close(self) -> None:
"""Close NATS connection.""" """Close NATS connection."""
@@ -279,7 +286,7 @@ class NatsPublisher:
self._loop, self._loop,
) )
future.result(timeout=5.0) future.result(timeout=5.0)
except Exception as e: except (RuntimeError, OSError, TimeoutError) as e:
logger.debug(f"Error closing NATS connection: {e}") logger.debug(f"Error closing NATS connection: {e}")
finally: finally:
self._nc = None self._nc = None
@@ -93,6 +93,19 @@ class _SelectedSilhouette(TypedDict):
track_id: int track_id: int
class _VizPayload(TypedDict, total=False):
    """Per-frame visualization payload handed from the pipeline to the visualizer.

    ``total=False``: any key may be absent; consumers should read via ``.get``.
    """

    result: DemoResult
    mask_raw: UInt8[ndarray, "h w"] | None  # binary mask in frame coordinates
    bbox: BBoxXYXY | None  # detection bbox in frame coordinates
    bbox_mask: BBoxXYXY | None  # bbox derived from the segmentation mask
    silhouette: Float[ndarray, "64 44"] | None  # 64x44 float silhouette crop
    segmentation_input: NDArray[np.float32] | None  # buffered window silhouettes
    track_id: int
    label: str | None  # classification label once a window has been classified
    confidence: float | None  # classification confidence, or None before classify
    pose: dict[str, object] | None  # optional pose data (keypoints/angles)
class _FramePacer: class _FramePacer:
_interval_ns: int _interval_ns: int
_next_emit_ns: int | None _next_emit_ns: int | None
@@ -131,7 +144,7 @@ class ScoliosisPipeline:
_result_export_format: str _result_export_format: str
_result_buffer: list[DemoResult] _result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None _visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None _last_viz_payload: _VizPayload | None
_frame_pacer: _FramePacer | None _frame_pacer: _FramePacer | None
def __init__( def __init__(
@@ -208,10 +221,9 @@ class ScoliosisPipeline:
return time.monotonic_ns() return time.monotonic_ns()
@staticmethod @staticmethod
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]: def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype( mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32) # type: ignore[reportAssignmentType]
np.uint8 binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8)
)
return cast(UInt8[ndarray, "h w"], binary) return cast(UInt8[ndarray, "h w"], binary)
def _first_result(self, detections: object) -> _DetectionResultsLike | None: def _first_result(self, detections: object) -> _DetectionResultsLike | None:
@@ -294,7 +306,7 @@ class ScoliosisPipeline:
self, self,
frame: UInt8[ndarray, "h w c"], frame: UInt8[ndarray, "h w c"],
metadata: dict[str, object], metadata: dict[str, object],
) -> dict[str, object] | None: ) -> _VizPayload | None:
frame_idx = self._extract_int(metadata, "frame_count", fallback=0) frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
timestamp_ns = self._extract_timestamp(metadata) timestamp_ns = self._extract_timestamp(metadata)
@@ -323,8 +335,8 @@ class ScoliosisPipeline:
bbox = selected["bbox_frame"] bbox = selected["bbox_frame"]
bbox_mask = selected["bbox_mask"] bbox_mask = selected["bbox_mask"]
track_id = selected["track_id"] track_id = selected["track_id"]
pose_data = None
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only: if self._silhouette_export_path is not None or self._preprocess_only:
self._silhouette_buffer.append( self._silhouette_buffer.append(
{ {
@@ -350,7 +362,9 @@ class ScoliosisPipeline:
"track_id": track_id, "track_id": track_id,
"label": None, "label": None,
"confidence": None, "confidence": None,
"pose": pose_data,
} }
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
if self._frame_pacer is not None and not self._frame_pacer.should_emit( if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns timestamp_ns
@@ -364,9 +378,8 @@ class ScoliosisPipeline:
"track_id": track_id, "track_id": track_id,
"label": None, "label": None,
"confidence": None, "confidence": None,
"pose": pose_data,
} }
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify(): if not self._window.should_classify():
@@ -380,8 +393,8 @@ class ScoliosisPipeline:
"track_id": track_id, "track_id": track_id,
"label": None, "label": None,
"confidence": None, "confidence": None,
"pose": pose_data,
} }
window_tensor = self._window.get_tensor(device=self._device) window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast( label, confidence = cast(
tuple[str, float], tuple[str, float],
@@ -415,6 +428,7 @@ class ScoliosisPipeline:
"track_id": track_id, "track_id": track_id,
"label": label, "label": label,
"confidence": confidence, "confidence": confidence,
"pose": pose_data,
} }
def run(self) -> int: def run(self) -> int:
@@ -445,7 +459,7 @@ class ScoliosisPipeline:
viz_payload = None viz_payload = None
try: try:
viz_payload = self.process_frame(frame_u8, metadata) viz_payload = self.process_frame(frame_u8, metadata)
except Exception as frame_error: except (RuntimeError, ValueError, TypeError, OSError) as frame_error:
logger.warning( logger.warning(
"Skipping frame %d due to processing error: %s", "Skipping frame %d due to processing error: %s",
frame_idx, frame_idx,
@@ -457,8 +471,8 @@ class ScoliosisPipeline:
# Cache valid payload for no-detection frames # Cache valid payload for no-detection frames
if viz_payload is not None: if viz_payload is not None:
# Cache a copy to prevent mutation of original data # Cache a copy to prevent mutation of original data
viz_payload_dict = cast(dict[str, object], viz_payload) viz_payload_dict = cast(_VizPayload, viz_payload)
cached: dict[str, object] = {} cached: _VizPayload = {}
for k, v in viz_payload_dict.items(): for k, v in viz_payload_dict.items():
copy_method = cast( copy_method = cast(
Callable[[], object] | None, getattr(v, "copy", None) Callable[[], object] | None, getattr(v, "copy", None)
@@ -477,12 +491,12 @@ class ScoliosisPipeline:
viz_data["bbox_mask"] = None viz_data["bbox_mask"] = None
viz_data["label"] = None viz_data["label"] = None
viz_data["confidence"] = None viz_data["confidence"] = None
viz_data["pose"] = None
else: else:
viz_data = None viz_data = None
if viz_data is not None: if viz_data is not None:
# Cast viz_payload to dict for type checking # Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data) viz_dict = cast(_VizPayload, viz_data)
mask_raw_obj = viz_dict.get("mask_raw") mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox") bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask") bbox_mask_obj = viz_dict.get("bbox_mask")
@@ -492,8 +506,8 @@ class ScoliosisPipeline:
track_id = track_id_val if isinstance(track_id_val, int) else 0 track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label") label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence") confidence_obj = viz_dict.get("confidence")
pose_obj = viz_dict.get("pose")
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj) mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj) bbox = cast(BBoxXYXY | None, bbox_obj)
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj) bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
@@ -504,6 +518,7 @@ class ScoliosisPipeline:
) )
label = cast(str | None, label_obj) label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj) confidence = cast(float | None, confidence_obj)
pose_data = cast(dict[str, object] | None, pose_obj)
else: else:
# No detection and no cache - use default values # No detection and no cache - use default values
mask_raw = None mask_raw = None
@@ -514,19 +529,37 @@ class ScoliosisPipeline:
segmentation_input = None segmentation_input = None
label = None label = None
confidence = None confidence = None
pose_data = None
keep_running = self._visualizer.update( # Try keyword arg for pose_data (backward compatible with old signatures)
frame_u8, try:
bbox, keep_running = self._visualizer.update(
bbox_mask, frame_u8,
track_id, bbox,
mask_raw, bbox_mask,
silhouette, track_id,
segmentation_input, mask_raw,
label, silhouette,
confidence, segmentation_input,
ema_fps, label,
) confidence,
ema_fps,
pose_data=pose_data,
)
except TypeError:
# Fallback for legacy visualizers that don't accept pose_data
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
)
if not keep_running: if not keep_running:
logger.info("Visualization closed by user.") logger.info("Visualization closed by user.")
break break
@@ -635,7 +668,7 @@ class ScoliosisPipeline:
frames.append(item["frame"]) frames.append(item["frame"])
track_ids.append(item["track_id"]) track_ids.append(item["track_id"])
timestamps.append(item["timestamp_ns"]) timestamps.append(item["timestamp_ns"])
silhouette_array = cast(ndarray, item["silhouette"]) silhouette_array = cast(NDArray[np.float32], item["silhouette"])
silhouettes.append(silhouette_array.flatten().tolist()) silhouettes.append(silhouette_array.flatten().tolist())
table = pa.table( table = pa.table(
@@ -830,6 +863,12 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
default=None, default=None,
help="Directory to save silhouette PNG visualizations.", help="Directory to save silhouette PNG visualizations.",
) )
@click.option(
"--visualize",
is_flag=True,
default=False,
help="Enable real-time visualization.",
)
def main( def main(
source: str, source: str,
checkpoint: str, checkpoint: str,
@@ -839,7 +878,7 @@ def main(
window: int, window: int,
stride: int, stride: int,
window_mode: str, window_mode: str,
target_fps: float | None, target_fps: float,
no_target_fps: bool, no_target_fps: bool,
nats_url: str | None, nats_url: str | None,
nats_subject: str, nats_subject: str,
@@ -850,7 +889,10 @@ def main(
result_export_path: str | None, result_export_path: str | None,
result_export_format: str, result_export_format: str,
silhouette_visualize_dir: str | None, silhouette_visualize_dir: str | None,
visualize: bool,
) -> None: ) -> None:
# Resolve effective target_fps: respect --no-target_fps to disable pacing
effective_target_fps = None if no_target_fps else target_fps
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s", format="%(asctime)s %(levelname)s %(name)s: %(message)s",
@@ -884,7 +926,6 @@ def main(
yolo_model=yolo_model, yolo_model=yolo_model,
window=window, window=window,
stride=effective_stride, stride=effective_stride,
target_fps=None if no_target_fps else target_fps,
nats_url=nats_url, nats_url=nats_url,
nats_subject=nats_subject, nats_subject=nats_subject,
max_frames=max_frames, max_frames=max_frames,
@@ -894,6 +935,8 @@ def main(
silhouette_visualize_dir=silhouette_visualize_dir, silhouette_visualize_dir=silhouette_visualize_dir,
result_export_path=result_export_path, result_export_path=result_export_path,
result_export_format=result_export_format, result_export_format=result_export_format,
visualize=visualize,
target_fps=effective_target_fps,
) )
raise SystemExit(pipeline.run()) raise SystemExit(pipeline.run())
except ValueError as err: except ValueError as err:
@@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override from typing import ClassVar, Protocol, cast, override
import torch import torch
@@ -13,10 +12,6 @@ from jaxtyping import Float
import jaxtyping import jaxtyping
from torch import Tensor from torch import Tensor
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
from opengait.modeling.backbones.resnet import ResNet9 from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import ( from opengait.modeling.modules import (
HorizontalPoolingPyramid, HorizontalPoolingPyramid,
@@ -41,10 +41,54 @@ COLOR_BLACK = (0, 0, 0)
COLOR_DARK_GRAY = (56, 56, 56) COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255) COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255) COLOR_YELLOW = (0, 255, 255)
# Type alias for image arrays (NDArray or cv2.Mat) # Type alias for image arrays (NDArray or cv2.Mat)
# BGR (OpenCV) color tuples used by the pose-overlay drawing helpers.
COLOR_CYAN = (255, 255, 0)
COLOR_ORANGE = (0, 165, 255)
COLOR_MAGENTA = (255, 0, 255)
ImageArray = NDArray[np.uint8] ImageArray = NDArray[np.uint8]
# COCO 17-keypoint skeleton edges, as (start_index, end_index) pairs.
# Grouped by body region: head, arms, torso, legs.
SKELETON_CONNECTIONS: list[tuple[int, int]] = [
    # Head
    (0, 1),   # nose -> left_eye
    (0, 2),   # nose -> right_eye
    (1, 3),   # left_eye -> left_ear
    (2, 4),   # right_eye -> right_ear
    # Arms
    (5, 6),   # left_shoulder -> right_shoulder
    (5, 7),   # left_shoulder -> left_elbow
    (7, 9),   # left_elbow -> left_wrist
    (6, 8),   # right_shoulder -> right_elbow
    (8, 10),  # right_elbow -> right_wrist
    # Torso
    (11, 12), # left_hip -> right_hip
    (5, 11),  # left_shoulder -> left_hip
    (6, 12),  # right_shoulder -> right_hip
    # Legs
    (11, 13), # left_hip -> left_knee
    (13, 15), # left_knee -> left_ankle
    (12, 14), # right_hip -> right_knee
    (14, 16), # right_knee -> right_ankle
]

# COCO keypoint names; index positions match the indices used above.
KEYPOINT_NAMES: list[str] = [
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]

# Joint triples (a, b, c) whose interior angle at b is of interest for
# scoliosis / gait analysis.
ANGLE_JOINTS: list[tuple[int, int, int]] = [
    (5, 7, 9),    # left_shoulder -> left_elbow -> left_wrist
    (6, 8, 10),   # right_shoulder -> right_elbow -> right_wrist
    (7, 5, 11),   # left_elbow -> left_shoulder -> left_hip
    (8, 6, 12),   # right_elbow -> right_shoulder -> right_hip
    (5, 11, 13),  # left_shoulder -> left_hip -> left_knee
    (6, 12, 14),  # right_shoulder -> right_hip -> right_knee
    (11, 13, 15), # left_hip -> left_knee -> left_ankle
    (12, 14, 16), # right_hip -> right_knee -> right_ankle
]
class OpenCVVisualizer: class OpenCVVisualizer:
def __init__(self) -> None: def __init__(self) -> None:
@@ -149,6 +193,134 @@ class OpenCVVisualizer:
thickness, thickness,
) )
def _draw_pose_skeleton(
    self,
    frame: ImageArray,
    pose_data: dict[str, object] | None,
) -> None:
    """Draw a COCO-format pose skeleton on frame (modified in place).

    Args:
        frame: Input frame (H, W, 3) uint8 - modified in place
        pose_data: Pose data dictionary from Sports2D or similar
            Expected format: {'keypoints': [[x1, y1], [x2, y2], ...],
            'confidence': [c1, c2, ...],
            'angles': {'joint_name': angle, ...}}
            Keypoints/edges below the 0.3 confidence threshold are skipped.
    """
    if pose_data is None:
        return

    keypoints_obj = pose_data.get('keypoints')
    if keypoints_obj is None:
        return

    # Convert keypoints to numpy array; reject empty or malformed
    # (non 2-D / missing x,y columns) payloads instead of crashing.
    keypoints = np.asarray(keypoints_obj, dtype=np.float32)
    if keypoints.size == 0 or keypoints.ndim != 2 or keypoints.shape[1] < 2:
        return

    h, w = frame.shape[:2]
    n_kpts = len(keypoints)

    # Align confidences with keypoints. The previous implementation indexed
    # confidences[idx] with keypoint indices, so a shorter confidence array
    # raised IndexError; pad missing entries with 0 (treated as low-confidence)
    # and truncate extras.
    confidence_obj = pose_data.get('confidence')
    if confidence_obj is None:
        confidences = np.ones(n_kpts, dtype=np.float32)
    else:
        confidences = np.asarray(confidence_obj, dtype=np.float32).ravel()
        if len(confidences) < n_kpts:
            confidences = np.pad(
                confidences, (0, n_kpts - len(confidences)), constant_values=0.0
            )
        elif len(confidences) > n_kpts:
            confidences = confidences[:n_kpts]

    conf_threshold = 0.3

    def _clipped_point(idx: int) -> tuple[int, int]:
        # Clip to frame bounds so cv2 never draws outside the image.
        x = max(0, min(w - 1, int(keypoints[idx][0])))
        y = max(0, min(h - 1, int(keypoints[idx][1])))
        return x, y

    # Draw skeleton connections whose both endpoints are confident enough.
    for idx1, idx2 in SKELETON_CONNECTIONS:
        if idx1 < n_kpts and idx2 < n_kpts:
            if confidences[idx1] > conf_threshold and confidences[idx2] > conf_threshold:
                _ = cv2.line(
                    frame, _clipped_point(idx1), _clipped_point(idx2), COLOR_CYAN, 2
                )

    # Draw keypoints as filled circles with a white outline.
    for i in range(n_kpts):
        if confidences[i] > conf_threshold:
            pt = _clipped_point(i)
            _ = cv2.circle(frame, pt, 4, COLOR_MAGENTA, -1)
            _ = cv2.circle(frame, pt, 4, COLOR_WHITE, 1)
def _draw_pose_angles(
    self,
    frame: ImageArray,
    pose_data: dict[str, object] | None,
) -> None:
    """Overlay joint-angle readouts in the top-right corner of the frame.

    Args:
        frame: Input frame (H, W, 3) uint8 - modified in place
        pose_data: Pose data dictionary with an 'angles' key mapping joint
            names to angles in degrees; None (or no angles) draws nothing.
    """
    if pose_data is None:
        return

    angles_obj = pose_data.get('angles')
    if angles_obj is None:
        return
    angles = cast(dict[str, float], angles_obj)
    if not angles:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.45
    thickness = 1
    line_height = 20
    margin = 10
    frame_h, frame_w = frame.shape[:2]

    # Keep only plausible angles (0-180 degrees); sorting by joint name
    # keeps the on-screen ordering stable across frames.
    entries = sorted(
        (str(joint), float(value))
        for joint, value in angles.items()
        if 0 <= value <= 180
    )

    # Render at most 8 rows, stacked downward from the top-right corner.
    for row, (joint, value) in enumerate(entries[:8]):
        text = f"{joint}: {value:.1f}"
        (text_w, text_h), _baseline = cv2.getTextSize(
            text, font, font_scale, thickness
        )
        x_pos = frame_w - margin - text_w - 10
        y_pos = margin + (row + 1) * line_height

        # Black backing rectangle keeps the text legible on busy frames.
        _ = cv2.rectangle(
            frame,
            (x_pos - 4, y_pos - text_h - 4),
            (x_pos + text_w + 4, y_pos + 4),
            COLOR_BLACK,
            -1,
        )
        _ = cv2.putText(
            frame,
            text,
            (x_pos, y_pos),
            font,
            font_scale,
            COLOR_ORANGE,
            thickness,
        )
def _prepare_main_frame( def _prepare_main_frame(
self, self,
frame: ImageArray, frame: ImageArray,
@@ -157,6 +329,7 @@ class OpenCVVisualizer:
fps: float, fps: float,
label: str | None, label: str | None,
confidence: float | None, confidence: float | None,
pose_data: dict[str, object] | None = None,
) -> ImageArray: ) -> ImageArray:
"""Prepare main display frame with bbox and text overlay. """Prepare main display frame with bbox and text overlay.
@@ -167,6 +340,7 @@ class OpenCVVisualizer:
fps: Current FPS fps: Current FPS
label: Classification label or None label: Classification label or None
confidence: Classification confidence or None confidence: Classification confidence or None
pose_data: Pose data dictionary or None
Returns: Returns:
Processed frame ready for display Processed frame ready for display
@@ -187,6 +361,10 @@ class OpenCVVisualizer:
self._draw_bbox(display_frame, bbox) self._draw_bbox(display_frame, bbox)
self._draw_text_overlay(display_frame, track_id, fps, label, confidence) self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
# Draw pose skeleton and angles if available
self._draw_pose_skeleton(display_frame, pose_data)
self._draw_pose_angles(display_frame, pose_data)
return display_frame return display_frame
def _upscale_silhouette( def _upscale_silhouette(
@@ -521,6 +699,7 @@ class OpenCVVisualizer:
label: str | None, label: str | None,
confidence: float | None, confidence: float | None,
fps: float, fps: float,
pose_data: dict[str, object] | None = None,
) -> bool: ) -> bool:
"""Update visualization with new frame data. """Update visualization with new frame data.
@@ -533,6 +712,7 @@ class OpenCVVisualizer:
label: Classification label or None label: Classification label or None
confidence: Classification confidence [0,1] or None confidence: Classification confidence [0,1] or None
fps: Current FPS fps: Current FPS
pose_data: Pose data dictionary or None
Returns: Returns:
False if user requested quit (pressed 'q'), True otherwise False if user requested quit (pressed 'q'), True otherwise
@@ -541,7 +721,7 @@ class OpenCVVisualizer:
# Prepare and show main window # Prepare and show main window
main_display = self._prepare_main_frame( main_display = self._prepare_main_frame(
frame, bbox, track_id, fps, label, confidence frame, bbox, track_id, fps, label, confidence, pose_data
) )
cv2.imshow(MAIN_WINDOW, main_display) cv2.imshow(MAIN_WINDOW, main_display)
+1 -1
View File
@@ -1,7 +1,7 @@
import math import math
import random import random
import numpy as np import numpy as np
from utils import get_msg_mgr from opengait.utils import get_msg_mgr
class CollateFn(object): class CollateFn(object):
+1 -1
View File
@@ -3,7 +3,7 @@ import pickle
import os.path as osp import os.path as osp
import torch.utils.data as tordata import torch.utils.data as tordata
import json import json
from utils import get_msg_mgr from opengait.utils import get_msg_mgr
class DataSet(tordata.Dataset): class DataSet(tordata.Dataset):
+1 -1
View File
@@ -4,7 +4,7 @@ import torchvision.transforms as T
import cv2 import cv2
import math import math
from data import transform as base_transform from data import transform as base_transform
from utils import is_list, is_dict, get_valid_args from opengait.utils import is_list, is_dict, get_valid_args
class NoOperation(): class NoOperation():
-149
View File
@@ -1,149 +0,0 @@
from __future__ import annotations
import argparse
import logging
import sys
from typing import cast
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
def _positive_float(value: str) -> float:
parsed = float(value)
if parsed <= 0:
raise argparse.ArgumentTypeError("target-fps must be positive")
return parsed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")
parser.add_argument(
"--source", type=str, required=True, help="Video source path or camera ID"
)
parser.add_argument(
"--checkpoint", type=str, required=True, help="Model checkpoint path"
)
parser.add_argument(
"--config",
type=str,
default="configs/sconet/sconet_scoliosis1k.yaml",
help="Config file path",
)
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
parser.add_argument(
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
)
parser.add_argument(
"--window", type=int, default=30, help="Window size for classification"
)
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
parser.add_argument(
"--target-fps",
type=_positive_float,
default=15.0,
help="Target FPS for temporal downsampling before windowing",
)
parser.add_argument(
"--window-mode",
type=str,
choices=["manual", "sliding", "chunked"],
default="manual",
help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
)
parser.add_argument(
"--no-target-fps",
action="store_true",
help="Disable temporal downsampling and use all frames",
)
parser.add_argument(
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
)
parser.add_argument(
"--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
)
parser.add_argument(
"--max-frames", type=int, default=None, help="Maximum frames to process"
)
parser.add_argument(
"--preprocess-only", action="store_true", help="Only preprocess silhouettes"
)
parser.add_argument(
"--silhouette-export-path",
type=str,
default=None,
help="Path to export silhouettes",
)
parser.add_argument(
"--silhouette-export-format", type=str, default="pickle", help="Export format"
)
parser.add_argument(
"--silhouette-visualize-dir",
type=str,
default=None,
help="Directory for silhouette visualizations",
)
parser.add_argument(
"--result-export-path", type=str, default=None, help="Path to export results"
)
parser.add_argument(
"--result-export-format", type=str, default="json", help="Result export format"
)
parser.add_argument(
"--visualize", action="store_true", help="Enable real-time visualization"
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Validate preprocess-only mode requires silhouette export path
if args.preprocess_only and not args.silhouette_export_path:
print(
"Error: --silhouette-export-path is required when using --preprocess-only",
file=sys.stderr,
)
raise SystemExit(2)
try:
# Import here to avoid circular imports
from .pipeline import validate_runtime_inputs
validate_runtime_inputs(
source=args.source, checkpoint=args.checkpoint, config=args.config
)
effective_stride = resolve_stride(
window=cast(int, args.window),
stride=cast(int, args.stride),
window_mode=cast(WindowMode, args.window_mode),
)
pipeline = ScoliosisPipeline(
source=cast(str, args.source),
checkpoint=cast(str, args.checkpoint),
config=cast(str, args.config),
device=cast(str, args.device),
yolo_model=cast(str, args.yolo_model),
window=cast(int, args.window),
stride=effective_stride,
target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
nats_url=cast(str | None, args.nats_url),
nats_subject=cast(str, args.nats_subject),
max_frames=cast(int | None, args.max_frames),
preprocess_only=cast(bool, args.preprocess_only),
silhouette_export_path=cast(str | None, args.silhouette_export_path),
silhouette_export_format=cast(str, args.silhouette_export_format),
silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
result_export_path=cast(str | None, args.result_export_path),
result_export_format=cast(str, args.result_export_format),
visualize=cast(bool, args.visualize),
)
raise SystemExit(pipeline.run())
except ValueError as err:
print(f"Error: {err}", file=sys.stderr)
raise SystemExit(2) from err
except RuntimeError as err:
print(f"Runtime error: {err}", file=sys.stderr)
raise SystemExit(1) from err
+1 -1
View File
@@ -1,7 +1,7 @@
import os import os
from time import strftime, localtime from time import strftime, localtime
import numpy as np import numpy as np
from utils import get_msg_mgr, mkdir from opengait.utils import get_msg_mgr, mkdir
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
from .re_rank import re_ranking from .re_rank import re_ranking
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
from utils import is_tensor from opengait.utils import is_tensor
def cuda_dist(x, y, metric='euc'): def cuda_dist(x, y, metric='euc'):
+1 -1
View File
@@ -4,7 +4,7 @@ import argparse
import torch import torch
import torch.nn as nn import torch.nn as nn
from modeling import models from modeling import models
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
parser = argparse.ArgumentParser(description='Main program for opengait.') parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument('--local_rank', type=int, default=0, parser.add_argument('--local_rank', type=int, default=0,
+4 -4
View File
@@ -28,11 +28,11 @@ from data.transform import get_transform
from data.collate_fn import CollateFn from data.collate_fn import CollateFn
from data.dataset import DataSet from data.dataset import DataSet
import data.sampler as Samplers import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather from opengait.utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions from evaluation import evaluator as eval_functions
from utils import NoOp from opengait.utils import NoOp
from utils import get_msg_mgr from opengait.utils import get_msg_mgr
__all__ = ['BaseModel'] __all__ = ['BaseModel']
+3 -3
View File
@@ -3,9 +3,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from . import losses from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module from opengait.utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict from opengait.utils import Odict
from utils import get_msg_mgr from opengait.utils import get_msg_mgr
class LossAggregator(nn.Module): class LossAggregator(nn.Module):
+2 -2
View File
@@ -1,9 +1,9 @@
from ctypes import ArgumentError from ctypes import ArgumentError
import torch.nn as nn import torch.nn as nn
import torch import torch
from utils import Odict from opengait.utils import Odict
import functools import functools
from utils import ddp_all_gather from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(func): def gather_and_scale_wrapper(func):
@@ -132,7 +132,7 @@ class Post_ResNet9(ResNet):
return x return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones from ... import backbones
class Baseline(nn.Module): class Baseline(nn.Module):
def __init__(self, model_cfg): def __init__(self, model_cfg):
+1 -1
View File
@@ -4,7 +4,7 @@ from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc
import torch.optim as optim import torch.optim as optim
from einops import rearrange from einops import rearrange
from utils import get_valid_args from opengait.utils import get_valid_args
import warnings import warnings
import random import random
from torchvision.utils import flow_to_image from torchvision.utils import flow_to_image
@@ -161,7 +161,7 @@ class Post_ResNet9(ResNet):
return x return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones from ... import backbones
class GaitBaseFusion_denoise(nn.Module): class GaitBaseFusion_denoise(nn.Module):
def __init__(self, model_cfg): def __init__(self, model_cfg):
+1 -1
View File
@@ -6,7 +6,7 @@ from ..base_model import BaseModel
from .gaitgl import GaitGL from .gaitgl import GaitGL
from ..modules import GaitAlign from ..modules import GaitAlign
from torchvision.transforms import Resize from torchvision.transforms import Resize
from utils import get_valid_args, get_attr_from, is_list_or_tuple from opengait.utils import get_valid_args, get_attr_from, is_list_or_tuple
import os.path as osp import os.path as osp
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..base_model import BaseModel from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
from utils import clones from opengait.utils import clones
class BasicConv1d(nn.Module): class BasicConv1d(nn.Module):
+2 -2
View File
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from ..base_model import BaseModel from ..base_model import BaseModel
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
from utils import np2var, list2var, get_valid_args, ddp_all_gather from opengait.utils import np2var, list2var, get_valid_args, ddp_all_gather
from data.transform import get_transform from data.transform import get_transform
from einops import rearrange from einops import rearrange
@@ -143,7 +143,7 @@ class GaitSSB_Pretrain(BaseModel):
import torch.optim as optim import torch.optim as optim
import numpy as np import numpy as np
from utils import get_valid_args, list2var from opengait.utils import get_valid_args, list2var
class no_grad(torch.no_grad): class no_grad(torch.no_grad):
def __init__(self, enable=True): def __init__(self, enable=True):
+1 -1
View File
@@ -782,7 +782,7 @@ from ..modules import BasicBlock2D, BasicBlockP3D
import torch.optim as optim import torch.optim as optim
import os.path as osp import os.path as osp
from collections import OrderedDict from collections import OrderedDict
from utils import get_valid_args, get_attr_from from opengait.utils import get_valid_args, get_attr_from
class SwinGait(BaseModel): class SwinGait(BaseModel):
def __init__(self, cfgs, training): def __init__(self, cfgs, training):
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from utils import clones, is_list_or_tuple from opengait.utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign from torchvision.ops import RoIAlign
+5 -1
View File
@@ -20,6 +20,7 @@ dependencies = [
"scikit-learn", "scikit-learn",
"matplotlib", "matplotlib",
"cvmmap-client", "cvmmap-client",
"nats-py",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -35,7 +36,10 @@ wandb = [
] ]
[tool.setuptools] [tool.setuptools]
packages = ["opengait"]
[tool.setuptools.packages.find]
where = [".", "opengait-studio"]
include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"]
[dependency-groups] [dependency-groups]
dev = [ dev = [
@@ -251,7 +251,7 @@ class TestNatsPublisherIntegration:
except ImportError: except ImportError:
pytest.skip("nats-py not installed") pytest.skip("nats-py not installed")
from opengait.demo.output import NatsPublisher, create_result from opengait_studio.output import NatsPublisher, create_result
# Create publisher # Create publisher
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT) publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
@@ -341,7 +341,7 @@ class TestNatsPublisherIntegration:
def test_nats_publisher_graceful_when_server_unavailable(self) -> None: def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
"""Test that publisher handles missing server gracefully.""" """Test that publisher handles missing server gracefully."""
try: try:
from opengait.demo.output import NatsPublisher from opengait_studio.output import NatsPublisher
except ImportError: except ImportError:
pytest.skip("output module not available") pytest.skip("output module not available")
@@ -380,7 +380,7 @@ class TestNatsPublisherIntegration:
import asyncio import asyncio
import nats # type: ignore[import-untyped] import nats # type: ignore[import-untyped]
from opengait.demo.output import NatsPublisher, create_result from opengait_studio.output import NatsPublisher, create_result
except ImportError as e: except ImportError as e:
pytest.skip(f"Required module not available: {e}") pytest.skip(f"Required module not available: {e}")
@@ -15,7 +15,7 @@ from numpy.typing import NDArray
import pytest import pytest
import torch import torch
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2] REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4" SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
@@ -31,7 +31,7 @@ def _device_for_runtime() -> str:
def _run_pipeline_cli( def _run_pipeline_cli(
*args: str, timeout_seconds: int = 120 *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]: ) -> subprocess.CompletedProcess[str]:
command = [sys.executable, "-m", "opengait.demo", *args] command = [sys.executable, "-m", "opengait_studio", *args]
return subprocess.run( return subprocess.run(
command, command,
cwd=REPO_ROOT, cwd=REPO_ROOT,
@@ -728,14 +728,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
This is a regression test for the window freeze issue when no person is detected. This is a regression test for the window freeze issue when no person is detected.
The window should refresh every frame to prevent freezing. The window should refresh every frame to prevent freezing.
""" """
from opengait.demo.pipeline import ScoliosisPipeline from opengait_studio.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies # Create a minimal pipeline with mocked dependencies
with ( with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source, mock.patch("opengait_studio.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
): ):
# Setup mock detector that returns no detections (causing process_frame to return None) # Setup mock detector that returns no detections (causing process_frame to return None)
mock_detector = mock.MagicMock() mock_detector = mock.MagicMock()
@@ -791,16 +791,16 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None: def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
from opengait.demo.pipeline import ScoliosisPipeline from opengait_studio.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies # Create a minimal pipeline with mocked dependencies
with ( with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source, mock.patch("opengait_studio.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
mock.patch("opengait.demo.pipeline.select_person") as mock_select_person, mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil, mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
): ):
# Create mock detection result for frames 0-1 (valid detection) # Create mock detection result for frames 0-1 (valid detection)
mock_box = mock.MagicMock() mock_box = mock.MagicMock()
@@ -919,7 +919,7 @@ def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
def test_frame_pacer_emission_count_24_to_15() -> None: def test_frame_pacer_emission_count_24_to_15() -> None:
from opengait.demo.pipeline import _FramePacer from opengait_studio.pipeline import _FramePacer
pacer = _FramePacer(15.0) pacer = _FramePacer(15.0)
interval_ns = int(1_000_000_000 / 24) interval_ns = int(1_000_000_000 / 24)
@@ -928,7 +928,7 @@ def test_frame_pacer_emission_count_24_to_15() -> None:
def test_frame_pacer_requires_positive_target_fps() -> None: def test_frame_pacer_requires_positive_target_fps() -> None:
from opengait.demo.pipeline import _FramePacer from opengait_studio.pipeline import _FramePacer
with pytest.raises(ValueError, match="target_fps must be positive"): with pytest.raises(ValueError, match="target_fps must be positive"):
_FramePacer(0.0) _FramePacer(0.0)
@@ -950,6 +950,6 @@ def test_resolve_stride_modes(
mode: Literal["manual", "sliding", "chunked"], mode: Literal["manual", "sliding", "chunked"],
expected: int, expected: int,
) -> None: ) -> None:
from opengait.demo.pipeline import resolve_stride from opengait_studio.pipeline import resolve_stride
assert resolve_stride(window, stride, mode) == expected assert resolve_stride(window, stride, mode) == expected
@@ -8,7 +8,7 @@ import pytest
from beartype.roar import BeartypeCallHintParamViolation from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError from jaxtyping import TypeCheckError
from opengait.demo.preprocess import mask_to_silhouette from opengait_studio.preprocess import mask_to_silhouette
class TestMaskToSilhouette: class TestMaskToSilhouette:
@@ -18,7 +18,7 @@ import torch
from torch import Tensor from torch import Tensor
if TYPE_CHECKING: if TYPE_CHECKING:
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
# Constants for test configuration # Constants for test configuration
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml") CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@@ -27,7 +27,7 @@ CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@pytest.fixture @pytest.fixture
def demo() -> "ScoNetDemo": def demo() -> "ScoNetDemo":
"""Create ScoNetDemo without loading checkpoint (CPU-only).""" """Create ScoNetDemo without loading checkpoint (CPU-only)."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
return ScoNetDemo( return ScoNetDemo(
cfg_path=str(CONFIG_PATH), cfg_path=str(CONFIG_PATH),
@@ -71,7 +71,7 @@ class TestScoNetDemoConstruction:
def test_construction_from_config_no_checkpoint(self) -> None: def test_construction_from_config_no_checkpoint(self) -> None:
"""Test construction with config only, no checkpoint.""" """Test construction with config only, no checkpoint."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
demo = ScoNetDemo( demo = ScoNetDemo(
cfg_path=str(CONFIG_PATH), cfg_path=str(CONFIG_PATH),
@@ -87,7 +87,7 @@ class TestScoNetDemoConstruction:
def test_construction_with_relative_path(self) -> None: def test_construction_with_relative_path(self) -> None:
"""Test construction handles relative config path correctly.""" """Test construction handles relative config path correctly."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
demo = ScoNetDemo( demo = ScoNetDemo(
cfg_path="configs/sconet/sconet_scoliosis1k.yaml", cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
@@ -100,7 +100,7 @@ class TestScoNetDemoConstruction:
def test_construction_invalid_config_raises(self) -> None: def test_construction_invalid_config_raises(self) -> None:
"""Test construction raises with invalid config path.""" """Test construction raises with invalid config path."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
with pytest.raises((FileNotFoundError, TypeError)): with pytest.raises((FileNotFoundError, TypeError)):
_ = ScoNetDemo( _ = ScoNetDemo(
@@ -180,7 +180,7 @@ class TestScoNetDemoPredict:
self, demo: "ScoNetDemo", dummy_sils_single: Tensor self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None: ) -> None:
"""Test predict returns (str, float) tuple with valid label.""" """Test predict returns (str, float) tuple with valid label."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
result_raw = demo.predict(dummy_sils_single) result_raw = demo.predict(dummy_sils_single)
result = cast(tuple[str, float], result_raw) result = cast(tuple[str, float], result_raw)
@@ -217,7 +217,7 @@ class TestScoNetDemoNoDDP:
def test_no_distributed_init_in_construction(self) -> None: def test_no_distributed_init_in_construction(self) -> None:
"""Test that construction does not call torch.distributed.""" """Test that construction does not call torch.distributed."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
with patch("torch.distributed.is_initialized") as mock_is_init: with patch("torch.distributed.is_initialized") as mock_is_init:
with patch("torch.distributed.init_process_group") as mock_init_pg: with patch("torch.distributed.init_process_group") as mock_init_pg:
@@ -327,14 +327,14 @@ class TestScoNetDemoLabelMap:
def test_label_map_has_three_classes(self) -> None: def test_label_map_has_three_classes(self) -> None:
"""Test LABEL_MAP has exactly 3 classes.""" """Test LABEL_MAP has exactly 3 classes."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
assert len(ScoNetDemo.LABEL_MAP) == 3 assert len(ScoNetDemo.LABEL_MAP) == 3
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2} assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
def test_label_map_values_are_valid_strings(self) -> None: def test_label_map_values_are_valid_strings(self) -> None:
"""Test LABEL_MAP values are valid non-empty strings.""" """Test LABEL_MAP values are valid non-empty strings."""
from opengait.demo.sconet_demo import ScoNetDemo from opengait_studio.sconet_demo import ScoNetDemo
for value in ScoNetDemo.LABEL_MAP.values(): for value in ScoNetDemo.LABEL_MAP.values():
assert isinstance(value, str) assert isinstance(value, str)
@@ -7,14 +7,14 @@ from unittest import mock
import numpy as np import numpy as np
import pytest import pytest
from opengait.demo.input import create_source from opengait_studio.input import create_source
from opengait.demo.visualizer import ( from opengait_studio.visualizer import (
DISPLAY_HEIGHT, DISPLAY_HEIGHT,
DISPLAY_WIDTH, DISPLAY_WIDTH,
ImageArray, ImageArray,
OpenCVVisualizer, OpenCVVisualizer,
) )
from opengait.demo.window import select_person from opengait_studio.window import select_person
REPO_ROOT = Path(__file__).resolve().parents[2] REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4" SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
@@ -8,7 +8,7 @@ import pytest
import torch import torch
from numpy.typing import NDArray from numpy.typing import NDArray
from opengait.demo.window import SilhouetteWindow, select_person from opengait_studio.window import SilhouetteWindow, select_person
class TestSilhouetteWindow: class TestSilhouetteWindow:
Generated
+9 -7
View File
@@ -614,7 +614,7 @@ name = "cuda-bindings"
version = "12.9.4" version = "12.9.4"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [ dependencies = [
{ name = "cuda-pathfinder", marker = "sys_platform == 'linux'" }, { name = "cuda-pathfinder", marker = "sys_platform != 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" },
@@ -1611,7 +1611,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21" version = "9.10.2.21"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@@ -1622,7 +1622,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83" version = "11.3.3.83"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@@ -1649,9 +1649,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90" version = "11.7.3.90"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'win32'" },
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@@ -1662,7 +1662,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93" version = "12.5.8.93"
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@@ -1737,6 +1737,7 @@ dependencies = [
{ name = "imageio" }, { name = "imageio" },
{ name = "kornia" }, { name = "kornia" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "nats-py" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, { name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
{ name = "opencv-python" }, { name = "opencv-python" },
@@ -1780,6 +1781,7 @@ requires-dist = [
{ name = "imageio" }, { name = "imageio" },
{ name = "kornia" }, { name = "kornia" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "nats-py" },
{ name = "numpy" }, { name = "numpy" },
{ name = "opencv-python" }, { name = "opencv-python" },
{ name = "pillow" }, { name = "pillow" },
+1 -1
View File
@@ -1,4 +1,4 @@
uv run python -m opengait.demo \ uv run python -m opengait_studio \
--source "cvmmap://camera_5602" \ --source "cvmmap://camera_5602" \
--checkpoint "ckpt/ScoNet-20000.pt" \ --checkpoint "ckpt/ScoNet-20000.pt" \
--config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \ --config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \