From 00fcda4fe3f4da5ca8d93be1153739e423003eec Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 3 Mar 2026 17:16:17 +0800 Subject: [PATCH] feat: extract opengait_studio monorepo module Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management. --- AGENTS.md | 2 +- .../opengait_studio}/__init__.py | 0 opengait-studio/opengait_studio/__main__.py | 7 + .../opengait_studio}/input.py | 0 .../opengait_studio}/output.py | 37 ++-- .../opengait_studio}/pipeline.py | 105 +++++++--- .../opengait_studio}/preprocess.py | 0 .../opengait_studio}/sconet_demo.py | 5 - .../opengait_studio}/visualizer.py | 184 +++++++++++++++++- .../opengait_studio}/window.py | 0 opengait/data/collate_fn.py | 2 +- opengait/data/dataset.py | 2 +- opengait/data/transform.py | 2 +- opengait/demo/__main__.py | 149 -------------- opengait/evaluation/evaluator.py | 2 +- opengait/evaluation/metric.py | 2 +- opengait/main.py | 2 +- opengait/modeling/base_model.py | 8 +- opengait/modeling/loss_aggregator.py | 6 +- opengait/modeling/losses/base.py | 4 +- .../models/BigGait_utils/BigGait_GaitBase.py | 2 +- opengait/modeling/models/denoisinggait.py | 2 +- .../GaitBase_fusion_denoise_flow26_attn.py | 2 +- opengait/modeling/models/gaitedge.py | 2 +- opengait/modeling/models/gaitpart.py | 2 +- opengait/modeling/models/gaitssb.py | 4 +- opengait/modeling/models/swingait.py | 2 +- opengait/modeling/modules.py | 2 +- pyproject.toml | 6 +- tests/{demo => opengait_studio}/__init__.py | 0 tests/{demo => opengait_studio}/conftest.py | 0 tests/{demo => opengait_studio}/test_nats.py | 6 +- .../test_pipeline.py | 34 ++-- .../test_preprocess.py | 2 +- .../test_sconet_demo.py | 18 +- .../test_visualizer.py | 6 +- .../{demo => opengait_studio}/test_window.py | 2 +- uv.lock | 16 +- vis.sh | 2 +- 39 files changed, 359 insertions(+), 270 deletions(-) rename {opengait/demo => opengait-studio/opengait_studio}/__init__.py (100%) create mode 100644 opengait-studio/opengait_studio/__main__.py rename {opengait/demo => opengait-studio/opengait_studio}/input.py (100%) rename {opengait/demo => opengait-studio/opengait_studio}/output.py (91%) rename {opengait/demo => opengait-studio/opengait_studio}/pipeline.py (90%) rename {opengait/demo => opengait-studio/opengait_studio}/preprocess.py (100%) rename {opengait/demo => opengait-studio/opengait_studio}/sconet_demo.py (98%) rename {opengait/demo => opengait-studio/opengait_studio}/visualizer.py (73%) rename {opengait/demo => opengait-studio/opengait_studio}/window.py (100%) delete mode 100644 opengait/demo/__main__.py rename tests/{demo => opengait_studio}/__init__.py (100%) rename tests/{demo => opengait_studio}/conftest.py (100%) rename tests/{demo => opengait_studio}/test_nats.py (98%) rename tests/{demo => opengait_studio}/test_pipeline.py (96%) rename tests/{demo => opengait_studio}/test_preprocess.py (99%) rename tests/{demo => opengait_studio}/test_sconet_demo.py (95%) rename tests/{demo => opengait_studio}/test_visualizer.py (97%) rename tests/{demo => opengait_studio}/test_window.py (99%) diff --git a/AGENTS.md b/AGENTS.md index 7dd41e9..292f8a6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -57,7 +57,7 @@ CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \ Demo CLI entry: ```bash -uv run python -m opengait.demo --help +uv run python -m opengait_studio --help ``` ## DDP Constraints (Important) diff --git a/opengait/demo/__init__.py b/opengait-studio/opengait_studio/__init__.py similarity index 100% rename from opengait/demo/__init__.py rename to opengait-studio/opengait_studio/__init__.py diff --git a/opengait-studio/opengait_studio/__main__.py b/opengait-studio/opengait_studio/__main__.py new file mode 100644 index 0000000..2590b2a --- /dev/null +++ b/opengait-studio/opengait_studio/__main__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from .pipeline import main + + +if __name__ == "__main__": + main() diff --git a/opengait/demo/input.py b/opengait-studio/opengait_studio/input.py similarity index 100% rename from opengait/demo/input.py rename to opengait-studio/opengait_studio/input.py diff --git a/opengait/demo/output.py b/opengait-studio/opengait_studio/output.py similarity index 91% rename from opengait/demo/output.py rename to opengait-studio/opengait_studio/output.py index 261a16c..5917a1f 100644 --- a/opengait/demo/output.py +++ b/opengait-studio/opengait_studio/output.py @@ -11,6 +11,7 @@ from __future__ import annotations import asyncio import json import logging +import nats import sys import threading import time @@ -81,7 +82,7 @@ class ConsolePublisher: json_line = json.dumps(result, ensure_ascii=False, default=str) _ = self._output.write(json_line + "\n") self._output.flush() - except Exception as e: + except (OSError, ValueError, TypeError) as e: logger.warning(f"Failed to publish to console: {e}") def close(self) -> None: @@ -173,7 +174,7 @@ class NatsPublisher: self._thread = threading.Thread(target=run_loop, daemon=True) self._thread.start() return True - except Exception as e: + except (RuntimeError, OSError) as e: logger.warning(f"Failed to start background event loop: {e}") return False @@ -204,7 +205,6 @@ class NatsPublisher: return False try: - import nats async def _connect() -> _NatsClient: nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType] @@ -219,12 +219,7 @@ class NatsPublisher: self._connected = True logger.info(f"Connected to NATS at {self._nats_url}") return True - except ImportError: - logger.warning( - "nats-py package not installed. Install with: pip install nats-py" - ) - return False - except Exception as e: + except (RuntimeError, OSError, TimeoutError) as e: logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") return False @@ -254,15 +249,27 @@ class NatsPublisher: _ = await self._nc.publish(self._subject, payload) _ = await self._nc.flush() - # Run publish in background loop future = asyncio.run_coroutine_threadsafe( _publish(), self._loop, # pyright: ignore[reportArgumentType] ) - future.result(timeout=5.0) # Wait for publish to complete - except Exception as e: - logger.warning(f"Failed to publish to NATS: {e}") - self._connected = False # Mark for reconnection on next publish + + def _on_done(publish_future: object) -> None: + fut = cast("asyncio.Future[None]", publish_future) + try: + exc = fut.exception() + except (RuntimeError, OSError) as callback_error: + logger.warning(f"NATS publish callback failed: {callback_error}") + self._connected = False + return + if exc is not None: + logger.warning(f"Failed to publish to NATS: {exc}") + self._connected = False + + future.add_done_callback(_on_done) + except (RuntimeError, OSError, ValueError, TypeError) as e: + logger.warning(f"Failed to schedule NATS publish: {e}") + self._connected = False def close(self) -> None: """Close NATS connection.""" @@ -279,7 +286,7 @@ class NatsPublisher: self._loop, ) future.result(timeout=5.0) - except Exception as e: + except (RuntimeError, OSError, TimeoutError) as e: logger.debug(f"Error closing NATS connection: {e}") finally: self._nc = None diff --git a/opengait/demo/pipeline.py b/opengait-studio/opengait_studio/pipeline.py similarity index 90% rename from opengait/demo/pipeline.py rename to opengait-studio/opengait_studio/pipeline.py index 8388d2a..62da000 100644 --- a/opengait/demo/pipeline.py +++ b/opengait-studio/opengait_studio/pipeline.py @@ -93,6 +93,19 @@ class _SelectedSilhouette(TypedDict): track_id: int +class _VizPayload(TypedDict, total=False): + result: DemoResult + mask_raw: UInt8[ndarray, "h w"] | None + bbox: BBoxXYXY | None + bbox_mask: BBoxXYXY | None + silhouette: Float[ndarray, "64 44"] | None + segmentation_input: NDArray[np.float32] | None + track_id: int + label: str | None + confidence: float | None + pose: dict[str, object] | None + + class _FramePacer: _interval_ns: int _next_emit_ns: int | None @@ -131,7 +144,7 @@ class ScoliosisPipeline: _result_export_format: str _result_buffer: list[DemoResult] _visualizer: OpenCVVisualizer | None - _last_viz_payload: dict[str, object] | None + _last_viz_payload: _VizPayload | None _frame_pacer: _FramePacer | None def __init__( @@ -208,10 +221,9 @@ class ScoliosisPipeline: return time.monotonic_ns() @staticmethod - def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]: - binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype( - np.uint8 - ) + def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]: + mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32) # type: ignore[reportAssignmentType] + binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8) return cast(UInt8[ndarray, "h w"], binary) def _first_result(self, detections: object) -> _DetectionResultsLike | None: @@ -294,7 +306,7 @@ class ScoliosisPipeline: self, frame: UInt8[ndarray, "h w c"], metadata: dict[str, object], - ) -> dict[str, object] | None: + ) -> _VizPayload | None: frame_idx = self._extract_int(metadata, "frame_count", fallback=0) timestamp_ns = self._extract_timestamp(metadata) @@ -323,8 +335,8 @@ class ScoliosisPipeline: bbox = selected["bbox_frame"] bbox_mask = selected["bbox_mask"] track_id = selected["track_id"] + pose_data = None - # Store silhouette for export if in preprocess-only mode or if export requested if self._silhouette_export_path is not None or self._preprocess_only: self._silhouette_buffer.append( { @@ -350,7 +362,9 @@ class ScoliosisPipeline: "track_id": track_id, "label": None, "confidence": None, + "pose": pose_data, } + self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id) if self._frame_pacer is not None and not self._frame_pacer.should_emit( timestamp_ns @@ -364,9 +378,8 @@ class ScoliosisPipeline: "track_id": track_id, "label": None, "confidence": None, + "pose": pose_data, } - - self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id) segmentation_input = self._window.buffered_silhouettes if not self._window.should_classify(): @@ -380,8 +393,8 @@ class ScoliosisPipeline: "track_id": track_id, "label": None, "confidence": None, + "pose": pose_data, } - window_tensor = self._window.get_tensor(device=self._device) label, confidence = cast( tuple[str, float], @@ -415,6 +428,7 @@ class ScoliosisPipeline: "track_id": track_id, "label": label, "confidence": confidence, + "pose": pose_data, } def run(self) -> int: @@ -445,7 +459,7 @@ class ScoliosisPipeline: viz_payload = None try: viz_payload = self.process_frame(frame_u8, metadata) - except Exception as frame_error: + except (RuntimeError, ValueError, TypeError, OSError) as frame_error: logger.warning( "Skipping frame %d due to processing error: %s", frame_idx, @@ -457,8 +471,8 @@ class ScoliosisPipeline: # Cache valid payload for no-detection frames if viz_payload is not None: # Cache a copy to prevent mutation of original data - viz_payload_dict = cast(dict[str, object], viz_payload) - cached: dict[str, object] = {} + viz_payload_dict = cast(_VizPayload, viz_payload) + cached: _VizPayload = {} for k, v in viz_payload_dict.items(): copy_method = cast( Callable[[], object] | None, getattr(v, "copy", None) @@ -477,12 +491,12 @@ class ScoliosisPipeline: viz_data["bbox_mask"] = None viz_data["label"] = None viz_data["confidence"] = None + viz_data["pose"] = None else: viz_data = None - if viz_data is not None: # Cast viz_payload to dict for type checking - viz_dict = cast(dict[str, object], viz_data) + viz_dict = cast(_VizPayload, viz_data) mask_raw_obj = viz_dict.get("mask_raw") bbox_obj = viz_dict.get("bbox") bbox_mask_obj = viz_dict.get("bbox_mask") @@ -492,8 +506,8 @@ class ScoliosisPipeline: track_id = track_id_val if isinstance(track_id_val, int) else 0 label_obj = viz_dict.get("label") confidence_obj = viz_dict.get("confidence") + pose_obj = viz_dict.get("pose") - # Cast extracted values to expected types mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj) bbox = cast(BBoxXYXY | None, bbox_obj) bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj) @@ -504,6 +518,7 @@ class ScoliosisPipeline: ) label = cast(str | None, label_obj) confidence = cast(float | None, confidence_obj) + pose_data = cast(dict[str, object] | None, pose_obj) else: # No detection and no cache - use default values mask_raw = None @@ -514,19 +529,37 @@ class ScoliosisPipeline: segmentation_input = None label = None confidence = None + pose_data = None - keep_running = self._visualizer.update( - frame_u8, - bbox, - bbox_mask, - track_id, - mask_raw, - silhouette, - segmentation_input, - label, - confidence, - ema_fps, - ) + # Try keyword arg for pose_data (backward compatible with old signatures) + try: + keep_running = self._visualizer.update( + frame_u8, + bbox, + bbox_mask, + track_id, + mask_raw, + silhouette, + segmentation_input, + label, + confidence, + ema_fps, + pose_data=pose_data, + ) + except TypeError: + # Fallback for legacy visualizers that don't accept pose_data + keep_running = self._visualizer.update( + frame_u8, + bbox, + bbox_mask, + track_id, + mask_raw, + silhouette, + segmentation_input, + label, + confidence, + ema_fps, + ) if not keep_running: logger.info("Visualization closed by user.") break @@ -635,7 +668,7 @@ class ScoliosisPipeline: frames.append(item["frame"]) track_ids.append(item["track_id"]) timestamps.append(item["timestamp_ns"]) - silhouette_array = cast(ndarray, item["silhouette"]) + silhouette_array = cast(NDArray[np.float32], item["silhouette"]) silhouettes.append(silhouette_array.flatten().tolist()) table = pa.table( @@ -830,6 +863,12 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None: default=None, help="Directory to save silhouette PNG visualizations.", ) +@click.option( + "--visualize", + is_flag=True, + default=False, + help="Enable real-time visualization.", +) def main( source: str, checkpoint: str, @@ -839,7 +878,7 @@ def main( window: int, stride: int, window_mode: str, - target_fps: float | None, + target_fps: float, no_target_fps: bool, nats_url: str | None, nats_subject: str, @@ -850,7 +889,10 @@ def main( result_export_path: str | None, result_export_format: str, silhouette_visualize_dir: str | None, + visualize: bool, ) -> None: + # Resolve effective target_fps: respect --no-target_fps to disable pacing + effective_target_fps = None if no_target_fps else target_fps logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", @@ -884,7 +926,6 @@ def main( yolo_model=yolo_model, window=window, stride=effective_stride, - target_fps=None if no_target_fps else target_fps, nats_url=nats_url, nats_subject=nats_subject, max_frames=max_frames, @@ -894,6 +935,8 @@ def main( silhouette_visualize_dir=silhouette_visualize_dir, result_export_path=result_export_path, result_export_format=result_export_format, + visualize=visualize, + target_fps=effective_target_fps, ) raise SystemExit(pipeline.run()) except ValueError as err: diff --git a/opengait/demo/preprocess.py b/opengait-studio/opengait_studio/preprocess.py similarity index 100% rename from opengait/demo/preprocess.py rename to opengait-studio/opengait_studio/preprocess.py diff --git a/opengait/demo/sconet_demo.py b/opengait-studio/opengait_studio/sconet_demo.py similarity index 98% rename from opengait/demo/sconet_demo.py rename to opengait-studio/opengait_studio/sconet_demo.py index 377649c..28fd4e6 100644 --- a/opengait/demo/sconet_demo.py +++ b/opengait-studio/opengait_studio/sconet_demo.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Callable from pathlib import Path -import sys from typing import ClassVar, Protocol, cast, override import torch @@ -13,10 +12,6 @@ from jaxtyping import Float import jaxtyping from torch import Tensor -_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1] -if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path: - sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT)) - from opengait.modeling.backbones.resnet import ResNet9 from opengait.modeling.modules import ( HorizontalPoolingPyramid, diff --git a/opengait/demo/visualizer.py b/opengait-studio/opengait_studio/visualizer.py similarity index 73% rename from opengait/demo/visualizer.py rename to opengait-studio/opengait_studio/visualizer.py index 79024db..49bb3bc 100644 --- a/opengait/demo/visualizer.py +++ b/opengait-studio/opengait_studio/visualizer.py @@ -41,10 +41,54 @@ COLOR_BLACK = (0, 0, 0) COLOR_DARK_GRAY = (56, 56, 56) COLOR_RED = (0, 0, 255) COLOR_YELLOW = (0, 255, 255) - # Type alias for image arrays (NDArray or cv2.Mat) +COLOR_CYAN = (255, 255, 0) +COLOR_ORANGE = (0, 165, 255) +COLOR_MAGENTA = (255, 0, 255) ImageArray = NDArray[np.uint8] +# COCO-format skeleton connections (17 keypoints) +# Connections are pairs of keypoint indices +SKELETON_CONNECTIONS: list[tuple[int, int]] = [ + (0, 1), # nose -> left_eye + (0, 2), # nose -> right_eye + (1, 3), # left_eye -> left_ear + (2, 4), # right_eye -> right_ear + (5, 6), # left_shoulder -> right_shoulder + (5, 7), # left_shoulder -> left_elbow + (7, 9), # left_elbow -> left_wrist + (6, 8), # right_shoulder -> right_elbow + (8, 10), # right_elbow -> right_wrist + (11, 12), # left_hip -> right_hip + (5, 11), # left_shoulder -> left_hip + (6, 12), # right_shoulder -> right_hip + (11, 13), # left_hip -> left_knee + (13, 15), # left_knee -> left_ankle + (12, 14), # right_hip -> right_knee + (14, 16), # right_knee -> right_ankle +] + +# Keypoint names for COCO format (17 keypoints) +KEYPOINT_NAMES: list[str] = [ + "nose", "left_eye", "right_eye", "left_ear", "right_ear", + "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", + "left_wrist", "right_wrist", "left_hip", "right_hip", + "left_knee", "right_knee", "left_ankle", "right_ankle" +] + +# Joints where angles are typically calculated (for scoliosis/ gait analysis) +ANGLE_JOINTS: list[tuple[int, int, int]] = [ + (5, 7, 9), # left_shoulder -> left_elbow -> left_wrist + (6, 8, 10), # right_shoulder -> right_elbow -> right_wrist + (7, 5, 11), # left_elbow -> left_shoulder -> left_hip + (8, 6, 12), # right_elbow -> right_shoulder -> right_hip + (5, 11, 13), # left_shoulder -> left_hip -> left_knee + (6, 12, 14), # right_shoulder -> right_hip -> right_knee + (11, 13, 15),# left_hip -> left_knee -> left_ankle + (12, 14, 16),# right_hip -> right_knee -> right_ankle +] + + class OpenCVVisualizer: def __init__(self) -> None: @@ -149,6 +193,134 @@ class OpenCVVisualizer: thickness, ) + def _draw_pose_skeleton( + self, + frame: ImageArray, + pose_data: dict[str, object] | None, + ) -> None: + """Draw pose skeleton on frame. + + Args: + frame: Input frame (H, W, 3) uint8 - modified in place + pose_data: Pose data dictionary from Sports2D or similar + Expected format: {'keypoints': [[x1, y1], [x2, y2], ...], + 'confidence': [c1, c2, ...], + 'angles': {'joint_name': angle, ...}} + """ + if pose_data is None: + return + + keypoints_obj = pose_data.get('keypoints') + if keypoints_obj is None: + return + + # Convert keypoints to numpy array + keypoints = np.asarray(keypoints_obj, dtype=np.float32) + if keypoints.size == 0: + return + + h, w = frame.shape[:2] + + # Get confidence scores if available + confidence_obj = pose_data.get('confidence') + confidences = ( + np.asarray(confidence_obj, dtype=np.float32) + if confidence_obj is not None + else np.ones(len(keypoints), dtype=np.float32) + ) + + # Draw skeleton connections + for connection in SKELETON_CONNECTIONS: + idx1, idx2 = connection + if idx1 < len(keypoints) and idx2 < len(keypoints): + # Check confidence threshold (0.3) + if confidences[idx1] > 0.3 and confidences[idx2] > 0.3: + pt1 = (int(keypoints[idx1][0]), int(keypoints[idx1][1])) + pt2 = (int(keypoints[idx2][0]), int(keypoints[idx2][1])) + # Clip to frame bounds + pt1 = (max(0, min(w - 1, pt1[0])), max(0, min(h - 1, pt1[1]))) + pt2 = (max(0, min(w - 1, pt2[0])), max(0, min(h - 1, pt2[1]))) + _ = cv2.line(frame, pt1, pt2, COLOR_CYAN, 2) + + # Draw keypoints + for i, (kp, conf) in enumerate(zip(keypoints, confidences)): + if conf > 0.3 and i < len(keypoints): + x, y = int(kp[0]), int(kp[1]) + # Clip to frame bounds + x = max(0, min(w - 1, x)) + y = max(0, min(h - 1, y)) + # Draw keypoint as circle + _ = cv2.circle(frame, (x, y), 4, COLOR_MAGENTA, -1) + _ = cv2.circle(frame, (x, y), 4, COLOR_WHITE, 1) + + def _draw_pose_angles( + self, + frame: ImageArray, + pose_data: dict[str, object] | None, + ) -> None: + """Draw pose angles as text overlay. + + Args: + frame: Input frame (H, W, 3) uint8 - modified in place + pose_data: Pose data dictionary with 'angles' key + """ + if pose_data is None: + return + + angles_obj = pose_data.get('angles') + if angles_obj is None: + return + + angles = cast(dict[str, float], angles_obj) + if not angles: + return + + # Draw angles in top-right corner + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.45 + thickness = 1 + line_height = 20 + margin = 10 + h, w = frame.shape[:2] + + # Filter and format angles + angle_texts: list[tuple[str, float]] = [] + for name, angle in angles.items(): + # Only show angles that are reasonable (0-180 degrees) + if 0 <= angle <= 180: + angle_texts.append((str(name), float(angle))) + + # Sort by name for consistent display + angle_texts.sort(key=lambda x: x[0]) + + # Draw from top-right + for i, (name, angle) in enumerate(angle_texts[:8]): # Limit to 8 angles + text = f"{name}: {angle:.1f}" + (text_width, text_height), _ = cv2.getTextSize( + text, font, font_scale, thickness + ) + x_pos = w - margin - text_width - 10 + y_pos = margin + (i + 1) * line_height + + # Draw background rectangle + _ = cv2.rectangle( + frame, + (x_pos - 4, y_pos - text_height - 4), + (x_pos + text_width + 4, y_pos + 4), + COLOR_BLACK, + -1, + ) + # Draw text in orange + _ = cv2.putText( + frame, + text, + (x_pos, y_pos), + font, + font_scale, + COLOR_ORANGE, + thickness, + ) + def _prepare_main_frame( self, frame: ImageArray, @@ -157,6 +329,7 @@ class OpenCVVisualizer: fps: float, label: str | None, confidence: float | None, + pose_data: dict[str, object] | None = None, ) -> ImageArray: """Prepare main display frame with bbox and text overlay. @@ -167,6 +340,7 @@ class OpenCVVisualizer: fps: Current FPS label: Classification label or None confidence: Classification confidence or None + pose_data: Pose data dictionary or None Returns: Processed frame ready for display @@ -187,6 +361,10 @@ class OpenCVVisualizer: self._draw_bbox(display_frame, bbox) self._draw_text_overlay(display_frame, track_id, fps, label, confidence) + # Draw pose skeleton and angles if available + self._draw_pose_skeleton(display_frame, pose_data) + self._draw_pose_angles(display_frame, pose_data) + return display_frame def _upscale_silhouette( @@ -521,6 +699,7 @@ class OpenCVVisualizer: label: str | None, confidence: float | None, fps: float, + pose_data: dict[str, object] | None = None, ) -> bool: """Update visualization with new frame data. @@ -533,6 +712,7 @@ class OpenCVVisualizer: label: Classification label or None confidence: Classification confidence [0,1] or None fps: Current FPS + pose_data: Pose data dictionary or None Returns: False if user requested quit (pressed 'q'), True otherwise @@ -541,7 +721,7 @@ class OpenCVVisualizer: # Prepare and show main window main_display = self._prepare_main_frame( - frame, bbox, track_id, fps, label, confidence + frame, bbox, track_id, fps, label, confidence, pose_data ) cv2.imshow(MAIN_WINDOW, main_display) diff --git a/opengait/demo/window.py b/opengait-studio/opengait_studio/window.py similarity index 100% rename from opengait/demo/window.py rename to opengait-studio/opengait_studio/window.py diff --git a/opengait/data/collate_fn.py b/opengait/data/collate_fn.py index a78c73d..3eeb7e9 100644 --- a/opengait/data/collate_fn.py +++ b/opengait/data/collate_fn.py @@ -1,7 +1,7 @@ import math import random import numpy as np -from utils import get_msg_mgr +from opengait.utils import get_msg_mgr class CollateFn(object): diff --git a/opengait/data/dataset.py b/opengait/data/dataset.py index a06ed58..bbf9e07 100644 --- a/opengait/data/dataset.py +++ b/opengait/data/dataset.py @@ -3,7 +3,7 @@ import pickle import os.path as osp import torch.utils.data as tordata import json -from utils import get_msg_mgr +from opengait.utils import get_msg_mgr class DataSet(tordata.Dataset): diff --git a/opengait/data/transform.py b/opengait/data/transform.py index 9e4d403..f853e2a 100644 --- a/opengait/data/transform.py +++ b/opengait/data/transform.py @@ -4,7 +4,7 @@ import torchvision.transforms as T import cv2 import math from data import transform as base_transform -from utils import is_list, is_dict, get_valid_args +from opengait.utils import is_list, is_dict, get_valid_args class NoOperation(): diff --git a/opengait/demo/__main__.py b/opengait/demo/__main__.py deleted file mode 100644 index 3604203..0000000 --- a/opengait/demo/__main__.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -import argparse -import logging -import sys -from typing import cast - -from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride - - -def _positive_float(value: str) -> float: - parsed = float(value) - if parsed <= 0: - raise argparse.ArgumentTypeError("target-fps must be positive") - return parsed - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline") - parser.add_argument( - "--source", type=str, required=True, help="Video source path or camera ID" - ) - parser.add_argument( - "--checkpoint", type=str, required=True, help="Model checkpoint path" - ) - parser.add_argument( - "--config", - type=str, - default="configs/sconet/sconet_scoliosis1k.yaml", - help="Config file path", - ) - parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on") - parser.add_argument( - "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name" - ) - parser.add_argument( - "--window", type=int, default=30, help="Window size for classification" - ) - parser.add_argument("--stride", type=int, default=30, help="Stride for window") - parser.add_argument( - "--target-fps", - type=_positive_float, - default=15.0, - help="Target FPS for temporal downsampling before windowing", - ) - parser.add_argument( - "--window-mode", - type=str, - choices=["manual", "sliding", "chunked"], - default="manual", - help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window", - ) - parser.add_argument( - "--no-target-fps", - action="store_true", - help="Disable temporal downsampling and use all frames", - ) - parser.add_argument( - "--nats-url", type=str, default=None, help="NATS URL for result publishing" - ) - parser.add_argument( - "--nats-subject", type=str, default="scoliosis.result", help="NATS subject" - ) - parser.add_argument( - "--max-frames", type=int, default=None, help="Maximum frames to process" - ) - parser.add_argument( - "--preprocess-only", action="store_true", help="Only preprocess silhouettes" - ) - parser.add_argument( - "--silhouette-export-path", - type=str, - default=None, - help="Path to export silhouettes", - ) - parser.add_argument( - "--silhouette-export-format", type=str, default="pickle", help="Export format" - ) - parser.add_argument( - "--silhouette-visualize-dir", - type=str, - default=None, - help="Directory for silhouette visualizations", - ) - parser.add_argument( - "--result-export-path", type=str, default=None, help="Path to export results" - ) - parser.add_argument( - "--result-export-format", type=str, default="json", help="Result export format" - ) - parser.add_argument( - "--visualize", action="store_true", help="Enable real-time visualization" - ) - args = parser.parse_args() - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(name)s: %(message)s", - ) - - # Validate preprocess-only mode requires silhouette export path - if args.preprocess_only and not args.silhouette_export_path: - print( - "Error: --silhouette-export-path is required when using --preprocess-only", - file=sys.stderr, - ) - raise SystemExit(2) - - try: - # Import here to avoid circular imports - from .pipeline import validate_runtime_inputs - - validate_runtime_inputs( - source=args.source, checkpoint=args.checkpoint, config=args.config - ) - - effective_stride = resolve_stride( - window=cast(int, args.window), - stride=cast(int, args.stride), - window_mode=cast(WindowMode, args.window_mode), - ) - - pipeline = ScoliosisPipeline( - source=cast(str, args.source), - checkpoint=cast(str, args.checkpoint), - config=cast(str, args.config), - device=cast(str, args.device), - yolo_model=cast(str, args.yolo_model), - window=cast(int, args.window), - stride=effective_stride, - target_fps=(None if args.no_target_fps else cast(float, args.target_fps)), - nats_url=cast(str | None, args.nats_url), - nats_subject=cast(str, args.nats_subject), - max_frames=cast(int | None, args.max_frames), - preprocess_only=cast(bool, args.preprocess_only), - silhouette_export_path=cast(str | None, args.silhouette_export_path), - silhouette_export_format=cast(str, args.silhouette_export_format), - silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir), - result_export_path=cast(str | None, args.result_export_path), - result_export_format=cast(str, args.result_export_format), - visualize=cast(bool, args.visualize), - ) - raise SystemExit(pipeline.run()) - except ValueError as err: - print(f"Error: {err}", file=sys.stderr) - raise SystemExit(2) from err - except RuntimeError as err: - print(f"Runtime error: {err}", file=sys.stderr) - raise SystemExit(1) from err diff --git a/opengait/evaluation/evaluator.py b/opengait/evaluation/evaluator.py index 053b7c4..9a2bb96 100644 --- a/opengait/evaluation/evaluator.py +++ b/opengait/evaluation/evaluator.py @@ -1,7 +1,7 @@ import os from time import strftime, localtime import numpy as np -from utils import get_msg_mgr, mkdir +from opengait.utils import get_msg_mgr, mkdir from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many from .re_rank import re_ranking diff --git a/opengait/evaluation/metric.py b/opengait/evaluation/metric.py index 42a6337..4a93ec8 100644 --- a/opengait/evaluation/metric.py +++ b/opengait/evaluation/metric.py @@ -2,7 +2,7 @@ import torch import numpy as np import torch.nn.functional as F -from utils import is_tensor +from opengait.utils import is_tensor def cuda_dist(x, y, metric='euc'): diff --git a/opengait/main.py b/opengait/main.py index 22851a0..dc2eec2 100644 --- a/opengait/main.py +++ b/opengait/main.py @@ -4,7 +4,7 @@ import argparse import torch import torch.nn as nn from modeling import models -from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr +from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr parser = argparse.ArgumentParser(description='Main program for opengait.') parser.add_argument('--local_rank', type=int, default=0, diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index 86b5e4c..cb99778 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -28,11 +28,11 @@ from data.transform import get_transform from data.collate_fn import CollateFn from data.dataset import DataSet import data.sampler as Samplers -from utils import Odict, mkdir, ddp_all_gather -from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from +from opengait.utils import Odict, mkdir, ddp_all_gather +from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from evaluation import evaluator as eval_functions -from utils import NoOp -from utils import get_msg_mgr +from opengait.utils import NoOp +from opengait.utils import get_msg_mgr __all__ = ['BaseModel'] diff --git a/opengait/modeling/loss_aggregator.py b/opengait/modeling/loss_aggregator.py index a3f5982..fecdd70 100644 --- a/opengait/modeling/loss_aggregator.py +++ b/opengait/modeling/loss_aggregator.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn from . import losses -from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module -from utils import Odict -from utils import get_msg_mgr +from opengait.utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module +from opengait.utils import Odict +from opengait.utils import get_msg_mgr class LossAggregator(nn.Module): diff --git a/opengait/modeling/losses/base.py b/opengait/modeling/losses/base.py index ba4d94f..c235ca0 100644 --- a/opengait/modeling/losses/base.py +++ b/opengait/modeling/losses/base.py @@ -1,9 +1,9 @@ from ctypes import ArgumentError import torch.nn as nn import torch -from utils import Odict +from opengait.utils import Odict import functools -from utils import ddp_all_gather +from opengait.utils import ddp_all_gather def gather_and_scale_wrapper(func): diff --git a/opengait/modeling/models/BigGait_utils/BigGait_GaitBase.py b/opengait/modeling/models/BigGait_utils/BigGait_GaitBase.py index 61f3b8a..ba2b64f 100644 --- a/opengait/modeling/models/BigGait_utils/BigGait_GaitBase.py +++ b/opengait/modeling/models/BigGait_utils/BigGait_GaitBase.py @@ -132,7 +132,7 @@ class Post_ResNet9(ResNet): return x -from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from +from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from ... import backbones class Baseline(nn.Module): def __init__(self, model_cfg): diff --git a/opengait/modeling/models/denoisinggait.py b/opengait/modeling/models/denoisinggait.py index 69e9168..cca01a3 100644 --- a/opengait/modeling/models/denoisinggait.py +++ b/opengait/modeling/models/denoisinggait.py @@ -4,7 +4,7 @@ from ..base_model import BaseModel from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc import torch.optim as optim from einops import rearrange -from utils import get_valid_args +from opengait.utils import get_valid_args import warnings import random from torchvision.utils import flow_to_image diff --git a/opengait/modeling/models/diffgait_utils/GaitBase_fusion_denoise_flow26_attn.py b/opengait/modeling/models/diffgait_utils/GaitBase_fusion_denoise_flow26_attn.py index 1b916c1..8a95705 100644 --- a/opengait/modeling/models/diffgait_utils/GaitBase_fusion_denoise_flow26_attn.py +++ b/opengait/modeling/models/diffgait_utils/GaitBase_fusion_denoise_flow26_attn.py @@ -161,7 +161,7 @@ class Post_ResNet9(ResNet): return x -from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from +from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from from ... import backbones class GaitBaseFusion_denoise(nn.Module): def __init__(self, model_cfg): diff --git a/opengait/modeling/models/gaitedge.py b/opengait/modeling/models/gaitedge.py index 0f52f10..5e8988d 100644 --- a/opengait/modeling/models/gaitedge.py +++ b/opengait/modeling/models/gaitedge.py @@ -6,7 +6,7 @@ from ..base_model import BaseModel from .gaitgl import GaitGL from ..modules import GaitAlign from torchvision.transforms import Resize -from utils import get_valid_args, get_attr_from, is_list_or_tuple +from opengait.utils import get_valid_args, get_attr_from, is_list_or_tuple import os.path as osp diff --git a/opengait/modeling/models/gaitpart.py b/opengait/modeling/models/gaitpart.py index 3dcef4e..38cbc3d 100644 --- a/opengait/modeling/models/gaitpart.py +++ b/opengait/modeling/models/gaitpart.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from ..base_model import BaseModel from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs -from utils import clones +from opengait.utils import clones class BasicConv1d(nn.Module): diff --git a/opengait/modeling/models/gaitssb.py b/opengait/modeling/models/gaitssb.py index fd46a2a..25973b1 100644 --- a/opengait/modeling/models/gaitssb.py +++ b/opengait/modeling/models/gaitssb.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from ..base_model import BaseModel from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs -from utils import np2var, list2var, get_valid_args, ddp_all_gather +from opengait.utils import np2var, list2var, get_valid_args, ddp_all_gather from data.transform import get_transform from einops import rearrange @@ -143,7 +143,7 @@ class GaitSSB_Pretrain(BaseModel): import torch.optim as optim import numpy as np -from utils import get_valid_args, list2var +from opengait.utils import get_valid_args, list2var class no_grad(torch.no_grad): def __init__(self, enable=True): diff --git a/opengait/modeling/models/swingait.py b/opengait/modeling/models/swingait.py index 868e255..28fdbb7 100644 --- a/opengait/modeling/models/swingait.py +++ b/opengait/modeling/models/swingait.py @@ -782,7 +782,7 @@ from ..modules import BasicBlock2D, BasicBlockP3D import torch.optim as optim import os.path as osp from collections import OrderedDict -from utils import get_valid_args, get_attr_from +from opengait.utils import get_valid_args, get_attr_from class SwinGait(BaseModel): def __init__(self, cfgs, training): diff --git a/opengait/modeling/modules.py b/opengait/modeling/modules.py index 2d165b0..9a88363 100644 --- a/opengait/modeling/modules.py +++ b/opengait/modeling/modules.py @@ -2,7 +2,7 @@ import torch import numpy as np import torch.nn as nn import torch.nn.functional as F -from utils import clones, is_list_or_tuple +from opengait.utils import clones, is_list_or_tuple from torchvision.ops import RoIAlign diff --git a/pyproject.toml b/pyproject.toml index 938be87..2671f02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "scikit-learn", "matplotlib", "cvmmap-client", + "nats-py", ] [project.optional-dependencies] @@ -35,7 +36,10 @@ wandb = [ ] [tool.setuptools] -packages = ["opengait"] + +[tool.setuptools.packages.find] +where = [".", "opengait-studio"] +include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"] [dependency-groups] dev = [ diff --git a/tests/demo/__init__.py b/tests/opengait_studio/__init__.py similarity index 100% rename from tests/demo/__init__.py rename to tests/opengait_studio/__init__.py diff --git a/tests/demo/conftest.py b/tests/opengait_studio/conftest.py similarity index 100% rename from tests/demo/conftest.py rename to tests/opengait_studio/conftest.py diff --git a/tests/demo/test_nats.py b/tests/opengait_studio/test_nats.py similarity index 98% rename from tests/demo/test_nats.py rename to tests/opengait_studio/test_nats.py index f72d41d..e1f7ba9 100644 --- a/tests/demo/test_nats.py +++ b/tests/opengait_studio/test_nats.py @@ -251,7 +251,7 @@ class TestNatsPublisherIntegration: except ImportError: pytest.skip("nats-py not installed") - from opengait.demo.output import NatsPublisher, create_result + from opengait_studio.output import NatsPublisher, create_result # Create publisher publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT) @@ -341,7 +341,7 @@ class TestNatsPublisherIntegration: def test_nats_publisher_graceful_when_server_unavailable(self) -> None: """Test that publisher handles missing server gracefully.""" try: - from opengait.demo.output import NatsPublisher + from opengait_studio.output import NatsPublisher except ImportError: pytest.skip("output module not available") @@ -380,7 +380,7 @@ class TestNatsPublisherIntegration: import asyncio import nats # type: ignore[import-untyped] - from opengait.demo.output import NatsPublisher, create_result + from opengait_studio.output import NatsPublisher, create_result except ImportError as e: pytest.skip(f"Required module not available: {e}") diff --git a/tests/demo/test_pipeline.py b/tests/opengait_studio/test_pipeline.py similarity index 96% rename from tests/demo/test_pipeline.py rename to tests/opengait_studio/test_pipeline.py index bb07366..8c716a5 100644 --- a/tests/demo/test_pipeline.py +++ b/tests/opengait_studio/test_pipeline.py @@ -15,7 +15,7 @@ from numpy.typing import NDArray import pytest import torch -from opengait.demo.sconet_demo import ScoNetDemo +from opengait_studio.sconet_demo import ScoNetDemo REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2] SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4" @@ -31,7 +31,7 @@ def _device_for_runtime() -> str: def _run_pipeline_cli( *args: str, timeout_seconds: int = 120 ) -> subprocess.CompletedProcess[str]: - command = [sys.executable, "-m", "opengait.demo", *args] + command = [sys.executable, "-m", "opengait_studio", *args] return subprocess.run( command, cwd=REPO_ROOT, @@ -728,14 +728,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None: This is a regression test for the window freeze issue when no person is detected. The window should refresh every frame to prevent freezing. """ - from opengait.demo.pipeline import ScoliosisPipeline + from opengait_studio.pipeline import ScoliosisPipeline # Create a minimal pipeline with mocked dependencies with ( - mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, - mock.patch("opengait.demo.pipeline.create_source") as mock_source, - mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, - mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, + mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo, + mock.patch("opengait_studio.pipeline.create_source") as mock_source, + mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher, + mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier, ): # Setup mock detector that returns no detections (causing process_frame to return None) mock_detector = mock.MagicMock() @@ -791,16 +791,16 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None: def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None: - from opengait.demo.pipeline import ScoliosisPipeline + from opengait_studio.pipeline import ScoliosisPipeline # Create a minimal pipeline with mocked dependencies with ( - mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, - mock.patch("opengait.demo.pipeline.create_source") as mock_source, - mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, - mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, - mock.patch("opengait.demo.pipeline.select_person") as mock_select_person, - mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil, + mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo, + mock.patch("opengait_studio.pipeline.create_source") as mock_source, + mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher, + mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier, + mock.patch("opengait_studio.pipeline.select_person") as mock_select_person, + mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil, ): # Create mock detection result for frames 0-1 (valid detection) mock_box = mock.MagicMock() @@ -919,7 +919,7 @@ def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None: def test_frame_pacer_emission_count_24_to_15() -> None: - from opengait.demo.pipeline import _FramePacer + from opengait_studio.pipeline import _FramePacer pacer = _FramePacer(15.0) interval_ns = int(1_000_000_000 / 24) @@ -928,7 +928,7 @@ def test_frame_pacer_emission_count_24_to_15() -> None: def test_frame_pacer_requires_positive_target_fps() -> None: - from opengait.demo.pipeline import _FramePacer + from opengait_studio.pipeline import _FramePacer with pytest.raises(ValueError, match="target_fps must be positive"): _FramePacer(0.0) @@ -950,6 +950,6 @@ def test_resolve_stride_modes( mode: Literal["manual", "sliding", "chunked"], expected: int, ) -> None: - from opengait.demo.pipeline import resolve_stride + from opengait_studio.pipeline import resolve_stride assert resolve_stride(window, stride, mode) == expected diff --git a/tests/demo/test_preprocess.py b/tests/opengait_studio/test_preprocess.py similarity index 99% rename from tests/demo/test_preprocess.py rename to tests/opengait_studio/test_preprocess.py index 7290db6..70fa4f4 100644 --- a/tests/demo/test_preprocess.py +++ b/tests/opengait_studio/test_preprocess.py @@ -8,7 +8,7 @@ import pytest from beartype.roar import BeartypeCallHintParamViolation from jaxtyping import TypeCheckError -from opengait.demo.preprocess import mask_to_silhouette +from opengait_studio.preprocess import mask_to_silhouette class TestMaskToSilhouette: diff --git a/tests/demo/test_sconet_demo.py b/tests/opengait_studio/test_sconet_demo.py similarity index 95% rename from tests/demo/test_sconet_demo.py rename to tests/opengait_studio/test_sconet_demo.py index 613c165..0b6a01b 100644 --- a/tests/demo/test_sconet_demo.py +++ b/tests/opengait_studio/test_sconet_demo.py @@ -18,7 +18,7 @@ import torch from torch import Tensor if TYPE_CHECKING: - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo # Constants for test configuration CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml") @@ -27,7 +27,7 @@ CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml") @pytest.fixture def demo() -> "ScoNetDemo": """Create ScoNetDemo without loading checkpoint (CPU-only).""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo return ScoNetDemo( cfg_path=str(CONFIG_PATH), @@ -71,7 +71,7 @@ class TestScoNetDemoConstruction: def test_construction_from_config_no_checkpoint(self) -> None: """Test construction with config only, no checkpoint.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo demo = ScoNetDemo( cfg_path=str(CONFIG_PATH), @@ -87,7 +87,7 @@ class TestScoNetDemoConstruction: def test_construction_with_relative_path(self) -> None: """Test construction handles relative config path correctly.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo demo = ScoNetDemo( cfg_path="configs/sconet/sconet_scoliosis1k.yaml", @@ -100,7 +100,7 @@ class TestScoNetDemoConstruction: def test_construction_invalid_config_raises(self) -> None: """Test construction raises with invalid config path.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo with pytest.raises((FileNotFoundError, TypeError)): _ = ScoNetDemo( @@ -180,7 +180,7 @@ class TestScoNetDemoPredict: self, demo: "ScoNetDemo", dummy_sils_single: Tensor ) -> None: """Test predict returns (str, float) tuple with valid label.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo result_raw = demo.predict(dummy_sils_single) result = cast(tuple[str, float], result_raw) @@ -217,7 +217,7 @@ class TestScoNetDemoNoDDP: def test_no_distributed_init_in_construction(self) -> None: """Test that construction does not call torch.distributed.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo with patch("torch.distributed.is_initialized") as mock_is_init: with patch("torch.distributed.init_process_group") as mock_init_pg: @@ -327,14 +327,14 @@ class TestScoNetDemoLabelMap: def test_label_map_has_three_classes(self) -> None: """Test LABEL_MAP has exactly 3 classes.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo assert len(ScoNetDemo.LABEL_MAP) == 3 assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2} def test_label_map_values_are_valid_strings(self) -> None: """Test LABEL_MAP values are valid non-empty strings.""" - from opengait.demo.sconet_demo import ScoNetDemo + from opengait_studio.sconet_demo import ScoNetDemo for value in ScoNetDemo.LABEL_MAP.values(): assert isinstance(value, str) diff --git a/tests/demo/test_visualizer.py b/tests/opengait_studio/test_visualizer.py similarity index 97% rename from tests/demo/test_visualizer.py rename to tests/opengait_studio/test_visualizer.py index 1ef61d2..f9d0e14 100644 --- a/tests/demo/test_visualizer.py +++ b/tests/opengait_studio/test_visualizer.py @@ -7,14 +7,14 @@ from unittest import mock import numpy as np import pytest -from opengait.demo.input import create_source -from opengait.demo.visualizer import ( +from opengait_studio.input import create_source +from opengait_studio.visualizer import ( DISPLAY_HEIGHT, DISPLAY_WIDTH, ImageArray, OpenCVVisualizer, ) -from opengait.demo.window import select_person +from opengait_studio.window import select_person REPO_ROOT = Path(__file__).resolve().parents[2] SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4" diff --git a/tests/demo/test_window.py b/tests/opengait_studio/test_window.py similarity index 99% rename from tests/demo/test_window.py rename to tests/opengait_studio/test_window.py index 0b67268..bc8ce7b 100644 --- a/tests/demo/test_window.py +++ b/tests/opengait_studio/test_window.py @@ -8,7 +8,7 @@ import pytest import torch from numpy.typing import NDArray -from opengait.demo.window import SilhouetteWindow, select_person +from opengait_studio.window import SilhouetteWindow, select_person class TestSilhouetteWindow: diff --git a/uv.lock b/uv.lock index 1b185e5..977c9a7 100644 --- a/uv.lock +++ b/uv.lock @@ -614,7 +614,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform == 'linux'" }, + { name = "cuda-pathfinder", marker = "sys_platform != 'win32'" }, ] wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, @@ -1611,7 +1611,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" }, ] wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1622,7 +1622,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" }, ] wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1649,9 +1649,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" }, ] wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1662,7 +1662,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" }, ] wheels = [ { url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -1737,6 +1737,7 @@ dependencies = [ { name = "imageio" }, { name = "kornia" }, { name = "matplotlib" }, + { name = "nats-py" }, { name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" }, { name = "opencv-python" }, @@ -1780,6 +1781,7 @@ requires-dist = [ { name = "imageio" }, { name = "kornia" }, { name = "matplotlib" }, + { name = "nats-py" }, { name = "numpy" }, { name = "opencv-python" }, { name = "pillow" }, diff --git a/vis.sh b/vis.sh index db2673d..8535a96 100644 --- a/vis.sh +++ b/vis.sh @@ -1,4 +1,4 @@ -uv run python -m opengait.demo \ +uv run python -m opengait_studio \ --source "cvmmap://camera_5602" \ --checkpoint "ckpt/ScoNet-20000.pt" \ --config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \