feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
@@ -57,7 +57,7 @@ CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \
|
|||||||
Demo CLI entry:
|
Demo CLI entry:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run python -m opengait.demo --help
|
uv run python -m opengait_studio --help
|
||||||
```
|
```
|
||||||
|
|
||||||
## DDP Constraints (Important)
|
## DDP Constraints (Important)
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .pipeline import main
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -11,6 +11,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import nats
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -81,7 +82,7 @@ class ConsolePublisher:
|
|||||||
json_line = json.dumps(result, ensure_ascii=False, default=str)
|
json_line = json.dumps(result, ensure_ascii=False, default=str)
|
||||||
_ = self._output.write(json_line + "\n")
|
_ = self._output.write(json_line + "\n")
|
||||||
self._output.flush()
|
self._output.flush()
|
||||||
except Exception as e:
|
except (OSError, ValueError, TypeError) as e:
|
||||||
logger.warning(f"Failed to publish to console: {e}")
|
logger.warning(f"Failed to publish to console: {e}")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
@@ -173,7 +174,7 @@ class NatsPublisher:
|
|||||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except (RuntimeError, OSError) as e:
|
||||||
logger.warning(f"Failed to start background event loop: {e}")
|
logger.warning(f"Failed to start background event loop: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -204,7 +205,6 @@ class NatsPublisher:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import nats
|
|
||||||
|
|
||||||
async def _connect() -> _NatsClient:
|
async def _connect() -> _NatsClient:
|
||||||
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
|
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||||
@@ -219,12 +219,7 @@ class NatsPublisher:
|
|||||||
self._connected = True
|
self._connected = True
|
||||||
logger.info(f"Connected to NATS at {self._nats_url}")
|
logger.info(f"Connected to NATS at {self._nats_url}")
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except (RuntimeError, OSError, TimeoutError) as e:
|
||||||
logger.warning(
|
|
||||||
"nats-py package not installed. Install with: pip install nats-py"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -254,15 +249,27 @@ class NatsPublisher:
|
|||||||
_ = await self._nc.publish(self._subject, payload)
|
_ = await self._nc.publish(self._subject, payload)
|
||||||
_ = await self._nc.flush()
|
_ = await self._nc.flush()
|
||||||
|
|
||||||
# Run publish in background loop
|
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
_publish(),
|
_publish(),
|
||||||
self._loop, # pyright: ignore[reportArgumentType]
|
self._loop, # pyright: ignore[reportArgumentType]
|
||||||
)
|
)
|
||||||
future.result(timeout=5.0) # Wait for publish to complete
|
|
||||||
except Exception as e:
|
def _on_done(publish_future: object) -> None:
|
||||||
logger.warning(f"Failed to publish to NATS: {e}")
|
fut = cast("asyncio.Future[None]", publish_future)
|
||||||
self._connected = False # Mark for reconnection on next publish
|
try:
|
||||||
|
exc = fut.exception()
|
||||||
|
except (RuntimeError, OSError) as callback_error:
|
||||||
|
logger.warning(f"NATS publish callback failed: {callback_error}")
|
||||||
|
self._connected = False
|
||||||
|
return
|
||||||
|
if exc is not None:
|
||||||
|
logger.warning(f"Failed to publish to NATS: {exc}")
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
future.add_done_callback(_on_done)
|
||||||
|
except (RuntimeError, OSError, ValueError, TypeError) as e:
|
||||||
|
logger.warning(f"Failed to schedule NATS publish: {e}")
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close NATS connection."""
|
"""Close NATS connection."""
|
||||||
@@ -279,7 +286,7 @@ class NatsPublisher:
|
|||||||
self._loop,
|
self._loop,
|
||||||
)
|
)
|
||||||
future.result(timeout=5.0)
|
future.result(timeout=5.0)
|
||||||
except Exception as e:
|
except (RuntimeError, OSError, TimeoutError) as e:
|
||||||
logger.debug(f"Error closing NATS connection: {e}")
|
logger.debug(f"Error closing NATS connection: {e}")
|
||||||
finally:
|
finally:
|
||||||
self._nc = None
|
self._nc = None
|
||||||
@@ -93,6 +93,19 @@ class _SelectedSilhouette(TypedDict):
|
|||||||
track_id: int
|
track_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class _VizPayload(TypedDict, total=False):
|
||||||
|
result: DemoResult
|
||||||
|
mask_raw: UInt8[ndarray, "h w"] | None
|
||||||
|
bbox: BBoxXYXY | None
|
||||||
|
bbox_mask: BBoxXYXY | None
|
||||||
|
silhouette: Float[ndarray, "64 44"] | None
|
||||||
|
segmentation_input: NDArray[np.float32] | None
|
||||||
|
track_id: int
|
||||||
|
label: str | None
|
||||||
|
confidence: float | None
|
||||||
|
pose: dict[str, object] | None
|
||||||
|
|
||||||
|
|
||||||
class _FramePacer:
|
class _FramePacer:
|
||||||
_interval_ns: int
|
_interval_ns: int
|
||||||
_next_emit_ns: int | None
|
_next_emit_ns: int | None
|
||||||
@@ -131,7 +144,7 @@ class ScoliosisPipeline:
|
|||||||
_result_export_format: str
|
_result_export_format: str
|
||||||
_result_buffer: list[DemoResult]
|
_result_buffer: list[DemoResult]
|
||||||
_visualizer: OpenCVVisualizer | None
|
_visualizer: OpenCVVisualizer | None
|
||||||
_last_viz_payload: dict[str, object] | None
|
_last_viz_payload: _VizPayload | None
|
||||||
_frame_pacer: _FramePacer | None
|
_frame_pacer: _FramePacer | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -208,10 +221,9 @@ class ScoliosisPipeline:
|
|||||||
return time.monotonic_ns()
|
return time.monotonic_ns()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
|
def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
|
||||||
binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
|
mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32) # type: ignore[reportAssignmentType]
|
||||||
np.uint8
|
binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||||
)
|
|
||||||
return cast(UInt8[ndarray, "h w"], binary)
|
return cast(UInt8[ndarray, "h w"], binary)
|
||||||
|
|
||||||
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
|
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
|
||||||
@@ -294,7 +306,7 @@ class ScoliosisPipeline:
|
|||||||
self,
|
self,
|
||||||
frame: UInt8[ndarray, "h w c"],
|
frame: UInt8[ndarray, "h w c"],
|
||||||
metadata: dict[str, object],
|
metadata: dict[str, object],
|
||||||
) -> dict[str, object] | None:
|
) -> _VizPayload | None:
|
||||||
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
||||||
timestamp_ns = self._extract_timestamp(metadata)
|
timestamp_ns = self._extract_timestamp(metadata)
|
||||||
|
|
||||||
@@ -323,8 +335,8 @@ class ScoliosisPipeline:
|
|||||||
bbox = selected["bbox_frame"]
|
bbox = selected["bbox_frame"]
|
||||||
bbox_mask = selected["bbox_mask"]
|
bbox_mask = selected["bbox_mask"]
|
||||||
track_id = selected["track_id"]
|
track_id = selected["track_id"]
|
||||||
|
pose_data = None
|
||||||
|
|
||||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
|
||||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||||
self._silhouette_buffer.append(
|
self._silhouette_buffer.append(
|
||||||
{
|
{
|
||||||
@@ -350,7 +362,9 @@ class ScoliosisPipeline:
|
|||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": None,
|
"label": None,
|
||||||
"confidence": None,
|
"confidence": None,
|
||||||
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
|
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||||
|
|
||||||
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
||||||
timestamp_ns
|
timestamp_ns
|
||||||
@@ -364,9 +378,8 @@ class ScoliosisPipeline:
|
|||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": None,
|
"label": None,
|
||||||
"confidence": None,
|
"confidence": None,
|
||||||
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
|
||||||
segmentation_input = self._window.buffered_silhouettes
|
segmentation_input = self._window.buffered_silhouettes
|
||||||
|
|
||||||
if not self._window.should_classify():
|
if not self._window.should_classify():
|
||||||
@@ -380,8 +393,8 @@ class ScoliosisPipeline:
|
|||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": None,
|
"label": None,
|
||||||
"confidence": None,
|
"confidence": None,
|
||||||
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
window_tensor = self._window.get_tensor(device=self._device)
|
window_tensor = self._window.get_tensor(device=self._device)
|
||||||
label, confidence = cast(
|
label, confidence = cast(
|
||||||
tuple[str, float],
|
tuple[str, float],
|
||||||
@@ -415,6 +428,7 @@ class ScoliosisPipeline:
|
|||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": label,
|
"label": label,
|
||||||
"confidence": confidence,
|
"confidence": confidence,
|
||||||
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
def run(self) -> int:
|
def run(self) -> int:
|
||||||
@@ -445,7 +459,7 @@ class ScoliosisPipeline:
|
|||||||
viz_payload = None
|
viz_payload = None
|
||||||
try:
|
try:
|
||||||
viz_payload = self.process_frame(frame_u8, metadata)
|
viz_payload = self.process_frame(frame_u8, metadata)
|
||||||
except Exception as frame_error:
|
except (RuntimeError, ValueError, TypeError, OSError) as frame_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping frame %d due to processing error: %s",
|
"Skipping frame %d due to processing error: %s",
|
||||||
frame_idx,
|
frame_idx,
|
||||||
@@ -457,8 +471,8 @@ class ScoliosisPipeline:
|
|||||||
# Cache valid payload for no-detection frames
|
# Cache valid payload for no-detection frames
|
||||||
if viz_payload is not None:
|
if viz_payload is not None:
|
||||||
# Cache a copy to prevent mutation of original data
|
# Cache a copy to prevent mutation of original data
|
||||||
viz_payload_dict = cast(dict[str, object], viz_payload)
|
viz_payload_dict = cast(_VizPayload, viz_payload)
|
||||||
cached: dict[str, object] = {}
|
cached: _VizPayload = {}
|
||||||
for k, v in viz_payload_dict.items():
|
for k, v in viz_payload_dict.items():
|
||||||
copy_method = cast(
|
copy_method = cast(
|
||||||
Callable[[], object] | None, getattr(v, "copy", None)
|
Callable[[], object] | None, getattr(v, "copy", None)
|
||||||
@@ -477,12 +491,12 @@ class ScoliosisPipeline:
|
|||||||
viz_data["bbox_mask"] = None
|
viz_data["bbox_mask"] = None
|
||||||
viz_data["label"] = None
|
viz_data["label"] = None
|
||||||
viz_data["confidence"] = None
|
viz_data["confidence"] = None
|
||||||
|
viz_data["pose"] = None
|
||||||
else:
|
else:
|
||||||
viz_data = None
|
viz_data = None
|
||||||
|
|
||||||
if viz_data is not None:
|
if viz_data is not None:
|
||||||
# Cast viz_payload to dict for type checking
|
# Cast viz_payload to dict for type checking
|
||||||
viz_dict = cast(dict[str, object], viz_data)
|
viz_dict = cast(_VizPayload, viz_data)
|
||||||
mask_raw_obj = viz_dict.get("mask_raw")
|
mask_raw_obj = viz_dict.get("mask_raw")
|
||||||
bbox_obj = viz_dict.get("bbox")
|
bbox_obj = viz_dict.get("bbox")
|
||||||
bbox_mask_obj = viz_dict.get("bbox_mask")
|
bbox_mask_obj = viz_dict.get("bbox_mask")
|
||||||
@@ -492,8 +506,8 @@ class ScoliosisPipeline:
|
|||||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||||
label_obj = viz_dict.get("label")
|
label_obj = viz_dict.get("label")
|
||||||
confidence_obj = viz_dict.get("confidence")
|
confidence_obj = viz_dict.get("confidence")
|
||||||
|
pose_obj = viz_dict.get("pose")
|
||||||
|
|
||||||
# Cast extracted values to expected types
|
|
||||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||||
bbox = cast(BBoxXYXY | None, bbox_obj)
|
bbox = cast(BBoxXYXY | None, bbox_obj)
|
||||||
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
|
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
|
||||||
@@ -504,6 +518,7 @@ class ScoliosisPipeline:
|
|||||||
)
|
)
|
||||||
label = cast(str | None, label_obj)
|
label = cast(str | None, label_obj)
|
||||||
confidence = cast(float | None, confidence_obj)
|
confidence = cast(float | None, confidence_obj)
|
||||||
|
pose_data = cast(dict[str, object] | None, pose_obj)
|
||||||
else:
|
else:
|
||||||
# No detection and no cache - use default values
|
# No detection and no cache - use default values
|
||||||
mask_raw = None
|
mask_raw = None
|
||||||
@@ -514,19 +529,37 @@ class ScoliosisPipeline:
|
|||||||
segmentation_input = None
|
segmentation_input = None
|
||||||
label = None
|
label = None
|
||||||
confidence = None
|
confidence = None
|
||||||
|
pose_data = None
|
||||||
|
|
||||||
keep_running = self._visualizer.update(
|
# Try keyword arg for pose_data (backward compatible with old signatures)
|
||||||
frame_u8,
|
try:
|
||||||
bbox,
|
keep_running = self._visualizer.update(
|
||||||
bbox_mask,
|
frame_u8,
|
||||||
track_id,
|
bbox,
|
||||||
mask_raw,
|
bbox_mask,
|
||||||
silhouette,
|
track_id,
|
||||||
segmentation_input,
|
mask_raw,
|
||||||
label,
|
silhouette,
|
||||||
confidence,
|
segmentation_input,
|
||||||
ema_fps,
|
label,
|
||||||
)
|
confidence,
|
||||||
|
ema_fps,
|
||||||
|
pose_data=pose_data,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
# Fallback for legacy visualizers that don't accept pose_data
|
||||||
|
keep_running = self._visualizer.update(
|
||||||
|
frame_u8,
|
||||||
|
bbox,
|
||||||
|
bbox_mask,
|
||||||
|
track_id,
|
||||||
|
mask_raw,
|
||||||
|
silhouette,
|
||||||
|
segmentation_input,
|
||||||
|
label,
|
||||||
|
confidence,
|
||||||
|
ema_fps,
|
||||||
|
)
|
||||||
if not keep_running:
|
if not keep_running:
|
||||||
logger.info("Visualization closed by user.")
|
logger.info("Visualization closed by user.")
|
||||||
break
|
break
|
||||||
@@ -635,7 +668,7 @@ class ScoliosisPipeline:
|
|||||||
frames.append(item["frame"])
|
frames.append(item["frame"])
|
||||||
track_ids.append(item["track_id"])
|
track_ids.append(item["track_id"])
|
||||||
timestamps.append(item["timestamp_ns"])
|
timestamps.append(item["timestamp_ns"])
|
||||||
silhouette_array = cast(ndarray, item["silhouette"])
|
silhouette_array = cast(NDArray[np.float32], item["silhouette"])
|
||||||
silhouettes.append(silhouette_array.flatten().tolist())
|
silhouettes.append(silhouette_array.flatten().tolist())
|
||||||
|
|
||||||
table = pa.table(
|
table = pa.table(
|
||||||
@@ -830,6 +863,12 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Directory to save silhouette PNG visualizations.",
|
help="Directory to save silhouette PNG visualizations.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--visualize",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Enable real-time visualization.",
|
||||||
|
)
|
||||||
def main(
|
def main(
|
||||||
source: str,
|
source: str,
|
||||||
checkpoint: str,
|
checkpoint: str,
|
||||||
@@ -839,7 +878,7 @@ def main(
|
|||||||
window: int,
|
window: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
window_mode: str,
|
window_mode: str,
|
||||||
target_fps: float | None,
|
target_fps: float,
|
||||||
no_target_fps: bool,
|
no_target_fps: bool,
|
||||||
nats_url: str | None,
|
nats_url: str | None,
|
||||||
nats_subject: str,
|
nats_subject: str,
|
||||||
@@ -850,7 +889,10 @@ def main(
|
|||||||
result_export_path: str | None,
|
result_export_path: str | None,
|
||||||
result_export_format: str,
|
result_export_format: str,
|
||||||
silhouette_visualize_dir: str | None,
|
silhouette_visualize_dir: str | None,
|
||||||
|
visualize: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# Resolve effective target_fps: respect --no-target_fps to disable pacing
|
||||||
|
effective_target_fps = None if no_target_fps else target_fps
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
@@ -884,7 +926,6 @@ def main(
|
|||||||
yolo_model=yolo_model,
|
yolo_model=yolo_model,
|
||||||
window=window,
|
window=window,
|
||||||
stride=effective_stride,
|
stride=effective_stride,
|
||||||
target_fps=None if no_target_fps else target_fps,
|
|
||||||
nats_url=nats_url,
|
nats_url=nats_url,
|
||||||
nats_subject=nats_subject,
|
nats_subject=nats_subject,
|
||||||
max_frames=max_frames,
|
max_frames=max_frames,
|
||||||
@@ -894,6 +935,8 @@ def main(
|
|||||||
silhouette_visualize_dir=silhouette_visualize_dir,
|
silhouette_visualize_dir=silhouette_visualize_dir,
|
||||||
result_export_path=result_export_path,
|
result_export_path=result_export_path,
|
||||||
result_export_format=result_export_format,
|
result_export_format=result_export_format,
|
||||||
|
visualize=visualize,
|
||||||
|
target_fps=effective_target_fps,
|
||||||
)
|
)
|
||||||
raise SystemExit(pipeline.run())
|
raise SystemExit(pipeline.run())
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
from typing import ClassVar, Protocol, cast, override
|
from typing import ClassVar, Protocol, cast, override
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,10 +12,6 @@ from jaxtyping import Float
|
|||||||
import jaxtyping
|
import jaxtyping
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
|
||||||
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
|
|
||||||
|
|
||||||
from opengait.modeling.backbones.resnet import ResNet9
|
from opengait.modeling.backbones.resnet import ResNet9
|
||||||
from opengait.modeling.modules import (
|
from opengait.modeling.modules import (
|
||||||
HorizontalPoolingPyramid,
|
HorizontalPoolingPyramid,
|
||||||
@@ -41,10 +41,54 @@ COLOR_BLACK = (0, 0, 0)
|
|||||||
COLOR_DARK_GRAY = (56, 56, 56)
|
COLOR_DARK_GRAY = (56, 56, 56)
|
||||||
COLOR_RED = (0, 0, 255)
|
COLOR_RED = (0, 0, 255)
|
||||||
COLOR_YELLOW = (0, 255, 255)
|
COLOR_YELLOW = (0, 255, 255)
|
||||||
|
|
||||||
# Type alias for image arrays (NDArray or cv2.Mat)
|
# Type alias for image arrays (NDArray or cv2.Mat)
|
||||||
|
COLOR_CYAN = (255, 255, 0)
|
||||||
|
COLOR_ORANGE = (0, 165, 255)
|
||||||
|
COLOR_MAGENTA = (255, 0, 255)
|
||||||
ImageArray = NDArray[np.uint8]
|
ImageArray = NDArray[np.uint8]
|
||||||
|
|
||||||
|
# COCO-format skeleton connections (17 keypoints)
|
||||||
|
# Connections are pairs of keypoint indices
|
||||||
|
SKELETON_CONNECTIONS: list[tuple[int, int]] = [
|
||||||
|
(0, 1), # nose -> left_eye
|
||||||
|
(0, 2), # nose -> right_eye
|
||||||
|
(1, 3), # left_eye -> left_ear
|
||||||
|
(2, 4), # right_eye -> right_ear
|
||||||
|
(5, 6), # left_shoulder -> right_shoulder
|
||||||
|
(5, 7), # left_shoulder -> left_elbow
|
||||||
|
(7, 9), # left_elbow -> left_wrist
|
||||||
|
(6, 8), # right_shoulder -> right_elbow
|
||||||
|
(8, 10), # right_elbow -> right_wrist
|
||||||
|
(11, 12), # left_hip -> right_hip
|
||||||
|
(5, 11), # left_shoulder -> left_hip
|
||||||
|
(6, 12), # right_shoulder -> right_hip
|
||||||
|
(11, 13), # left_hip -> left_knee
|
||||||
|
(13, 15), # left_knee -> left_ankle
|
||||||
|
(12, 14), # right_hip -> right_knee
|
||||||
|
(14, 16), # right_knee -> right_ankle
|
||||||
|
]
|
||||||
|
|
||||||
|
# Keypoint names for COCO format (17 keypoints)
|
||||||
|
KEYPOINT_NAMES: list[str] = [
|
||||||
|
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||||
|
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||||
|
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||||
|
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Joints where angles are typically calculated (for scoliosis/ gait analysis)
|
||||||
|
ANGLE_JOINTS: list[tuple[int, int, int]] = [
|
||||||
|
(5, 7, 9), # left_shoulder -> left_elbow -> left_wrist
|
||||||
|
(6, 8, 10), # right_shoulder -> right_elbow -> right_wrist
|
||||||
|
(7, 5, 11), # left_elbow -> left_shoulder -> left_hip
|
||||||
|
(8, 6, 12), # right_elbow -> right_shoulder -> right_hip
|
||||||
|
(5, 11, 13), # left_shoulder -> left_hip -> left_knee
|
||||||
|
(6, 12, 14), # right_shoulder -> right_hip -> right_knee
|
||||||
|
(11, 13, 15),# left_hip -> left_knee -> left_ankle
|
||||||
|
(12, 14, 16),# right_hip -> right_knee -> right_ankle
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OpenCVVisualizer:
|
class OpenCVVisualizer:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -149,6 +193,134 @@ class OpenCVVisualizer:
|
|||||||
thickness,
|
thickness,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _draw_pose_skeleton(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
pose_data: dict[str, object] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Draw pose skeleton on frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||||
|
pose_data: Pose data dictionary from Sports2D or similar
|
||||||
|
Expected format: {'keypoints': [[x1, y1], [x2, y2], ...],
|
||||||
|
'confidence': [c1, c2, ...],
|
||||||
|
'angles': {'joint_name': angle, ...}}
|
||||||
|
"""
|
||||||
|
if pose_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
keypoints_obj = pose_data.get('keypoints')
|
||||||
|
if keypoints_obj is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert keypoints to numpy array
|
||||||
|
keypoints = np.asarray(keypoints_obj, dtype=np.float32)
|
||||||
|
if keypoints.size == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
|
||||||
|
# Get confidence scores if available
|
||||||
|
confidence_obj = pose_data.get('confidence')
|
||||||
|
confidences = (
|
||||||
|
np.asarray(confidence_obj, dtype=np.float32)
|
||||||
|
if confidence_obj is not None
|
||||||
|
else np.ones(len(keypoints), dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw skeleton connections
|
||||||
|
for connection in SKELETON_CONNECTIONS:
|
||||||
|
idx1, idx2 = connection
|
||||||
|
if idx1 < len(keypoints) and idx2 < len(keypoints):
|
||||||
|
# Check confidence threshold (0.3)
|
||||||
|
if confidences[idx1] > 0.3 and confidences[idx2] > 0.3:
|
||||||
|
pt1 = (int(keypoints[idx1][0]), int(keypoints[idx1][1]))
|
||||||
|
pt2 = (int(keypoints[idx2][0]), int(keypoints[idx2][1]))
|
||||||
|
# Clip to frame bounds
|
||||||
|
pt1 = (max(0, min(w - 1, pt1[0])), max(0, min(h - 1, pt1[1])))
|
||||||
|
pt2 = (max(0, min(w - 1, pt2[0])), max(0, min(h - 1, pt2[1])))
|
||||||
|
_ = cv2.line(frame, pt1, pt2, COLOR_CYAN, 2)
|
||||||
|
|
||||||
|
# Draw keypoints
|
||||||
|
for i, (kp, conf) in enumerate(zip(keypoints, confidences)):
|
||||||
|
if conf > 0.3 and i < len(keypoints):
|
||||||
|
x, y = int(kp[0]), int(kp[1])
|
||||||
|
# Clip to frame bounds
|
||||||
|
x = max(0, min(w - 1, x))
|
||||||
|
y = max(0, min(h - 1, y))
|
||||||
|
# Draw keypoint as circle
|
||||||
|
_ = cv2.circle(frame, (x, y), 4, COLOR_MAGENTA, -1)
|
||||||
|
_ = cv2.circle(frame, (x, y), 4, COLOR_WHITE, 1)
|
||||||
|
|
||||||
|
def _draw_pose_angles(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
pose_data: dict[str, object] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Draw pose angles as text overlay.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||||
|
pose_data: Pose data dictionary with 'angles' key
|
||||||
|
"""
|
||||||
|
if pose_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
angles_obj = pose_data.get('angles')
|
||||||
|
if angles_obj is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
angles = cast(dict[str, float], angles_obj)
|
||||||
|
if not angles:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Draw angles in top-right corner
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = 0.45
|
||||||
|
thickness = 1
|
||||||
|
line_height = 20
|
||||||
|
margin = 10
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
|
||||||
|
# Filter and format angles
|
||||||
|
angle_texts: list[tuple[str, float]] = []
|
||||||
|
for name, angle in angles.items():
|
||||||
|
# Only show angles that are reasonable (0-180 degrees)
|
||||||
|
if 0 <= angle <= 180:
|
||||||
|
angle_texts.append((str(name), float(angle)))
|
||||||
|
|
||||||
|
# Sort by name for consistent display
|
||||||
|
angle_texts.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
# Draw from top-right
|
||||||
|
for i, (name, angle) in enumerate(angle_texts[:8]): # Limit to 8 angles
|
||||||
|
text = f"{name}: {angle:.1f}"
|
||||||
|
(text_width, text_height), _ = cv2.getTextSize(
|
||||||
|
text, font, font_scale, thickness
|
||||||
|
)
|
||||||
|
x_pos = w - margin - text_width - 10
|
||||||
|
y_pos = margin + (i + 1) * line_height
|
||||||
|
|
||||||
|
# Draw background rectangle
|
||||||
|
_ = cv2.rectangle(
|
||||||
|
frame,
|
||||||
|
(x_pos - 4, y_pos - text_height - 4),
|
||||||
|
(x_pos + text_width + 4, y_pos + 4),
|
||||||
|
COLOR_BLACK,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
# Draw text in orange
|
||||||
|
_ = cv2.putText(
|
||||||
|
frame,
|
||||||
|
text,
|
||||||
|
(x_pos, y_pos),
|
||||||
|
font,
|
||||||
|
font_scale,
|
||||||
|
COLOR_ORANGE,
|
||||||
|
thickness,
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_main_frame(
|
def _prepare_main_frame(
|
||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
@@ -157,6 +329,7 @@ class OpenCVVisualizer:
|
|||||||
fps: float,
|
fps: float,
|
||||||
label: str | None,
|
label: str | None,
|
||||||
confidence: float | None,
|
confidence: float | None,
|
||||||
|
pose_data: dict[str, object] | None = None,
|
||||||
) -> ImageArray:
|
) -> ImageArray:
|
||||||
"""Prepare main display frame with bbox and text overlay.
|
"""Prepare main display frame with bbox and text overlay.
|
||||||
|
|
||||||
@@ -167,6 +340,7 @@ class OpenCVVisualizer:
|
|||||||
fps: Current FPS
|
fps: Current FPS
|
||||||
label: Classification label or None
|
label: Classification label or None
|
||||||
confidence: Classification confidence or None
|
confidence: Classification confidence or None
|
||||||
|
pose_data: Pose data dictionary or None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Processed frame ready for display
|
Processed frame ready for display
|
||||||
@@ -187,6 +361,10 @@ class OpenCVVisualizer:
|
|||||||
self._draw_bbox(display_frame, bbox)
|
self._draw_bbox(display_frame, bbox)
|
||||||
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
|
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
|
||||||
|
|
||||||
|
# Draw pose skeleton and angles if available
|
||||||
|
self._draw_pose_skeleton(display_frame, pose_data)
|
||||||
|
self._draw_pose_angles(display_frame, pose_data)
|
||||||
|
|
||||||
return display_frame
|
return display_frame
|
||||||
|
|
||||||
def _upscale_silhouette(
|
def _upscale_silhouette(
|
||||||
@@ -521,6 +699,7 @@ class OpenCVVisualizer:
|
|||||||
label: str | None,
|
label: str | None,
|
||||||
confidence: float | None,
|
confidence: float | None,
|
||||||
fps: float,
|
fps: float,
|
||||||
|
pose_data: dict[str, object] | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update visualization with new frame data.
|
"""Update visualization with new frame data.
|
||||||
|
|
||||||
@@ -533,6 +712,7 @@ class OpenCVVisualizer:
|
|||||||
label: Classification label or None
|
label: Classification label or None
|
||||||
confidence: Classification confidence [0,1] or None
|
confidence: Classification confidence [0,1] or None
|
||||||
fps: Current FPS
|
fps: Current FPS
|
||||||
|
pose_data: Pose data dictionary or None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
False if user requested quit (pressed 'q'), True otherwise
|
False if user requested quit (pressed 'q'), True otherwise
|
||||||
@@ -541,7 +721,7 @@ class OpenCVVisualizer:
|
|||||||
|
|
||||||
# Prepare and show main window
|
# Prepare and show main window
|
||||||
main_display = self._prepare_main_frame(
|
main_display = self._prepare_main_frame(
|
||||||
frame, bbox, track_id, fps, label, confidence
|
frame, bbox, track_id, fps, label, confidence, pose_data
|
||||||
)
|
)
|
||||||
cv2.imshow(MAIN_WINDOW, main_display)
|
cv2.imshow(MAIN_WINDOW, main_display)
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import get_msg_mgr
|
from opengait.utils import get_msg_mgr
|
||||||
|
|
||||||
|
|
||||||
class CollateFn(object):
|
class CollateFn(object):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import pickle
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import torch.utils.data as tordata
|
import torch.utils.data as tordata
|
||||||
import json
|
import json
|
||||||
from utils import get_msg_mgr
|
from opengait.utils import get_msg_mgr
|
||||||
|
|
||||||
|
|
||||||
class DataSet(tordata.Dataset):
|
class DataSet(tordata.Dataset):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import torchvision.transforms as T
|
|||||||
import cv2
|
import cv2
|
||||||
import math
|
import math
|
||||||
from data import transform as base_transform
|
from data import transform as base_transform
|
||||||
from utils import is_list, is_dict, get_valid_args
|
from opengait.utils import is_list, is_dict, get_valid_args
|
||||||
|
|
||||||
|
|
||||||
class NoOperation():
|
class NoOperation():
|
||||||
|
|||||||
@@ -1,149 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
|
|
||||||
|
|
||||||
|
|
||||||
def _positive_float(value: str) -> float:
|
|
||||||
parsed = float(value)
|
|
||||||
if parsed <= 0:
|
|
||||||
raise argparse.ArgumentTypeError("target-fps must be positive")
|
|
||||||
return parsed
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")
|
|
||||||
parser.add_argument(
|
|
||||||
"--source", type=str, required=True, help="Video source path or camera ID"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint", type=str, required=True, help="Model checkpoint path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config",
|
|
||||||
type=str,
|
|
||||||
default="configs/sconet/sconet_scoliosis1k.yaml",
|
|
||||||
help="Config file path",
|
|
||||||
)
|
|
||||||
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
|
|
||||||
parser.add_argument(
|
|
||||||
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--window", type=int, default=30, help="Window size for classification"
|
|
||||||
)
|
|
||||||
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
|
|
||||||
parser.add_argument(
|
|
||||||
"--target-fps",
|
|
||||||
type=_positive_float,
|
|
||||||
default=15.0,
|
|
||||||
help="Target FPS for temporal downsampling before windowing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--window-mode",
|
|
||||||
type=str,
|
|
||||||
choices=["manual", "sliding", "chunked"],
|
|
||||||
default="manual",
|
|
||||||
help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-target-fps",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable temporal downsampling and use all frames",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-frames", type=int, default=None, help="Maximum frames to process"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--preprocess-only", action="store_true", help="Only preprocess silhouettes"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--silhouette-export-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to export silhouettes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--silhouette-export-format", type=str, default="pickle", help="Export format"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--silhouette-visualize-dir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Directory for silhouette visualizations",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--result-export-path", type=str, default=None, help="Path to export results"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--result-export-format", type=str, default="json", help="Result export format"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--visualize", action="store_true", help="Enable real-time visualization"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate preprocess-only mode requires silhouette export path
|
|
||||||
if args.preprocess_only and not args.silhouette_export_path:
|
|
||||||
print(
|
|
||||||
"Error: --silhouette-export-path is required when using --preprocess-only",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise SystemExit(2)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from .pipeline import validate_runtime_inputs
|
|
||||||
|
|
||||||
validate_runtime_inputs(
|
|
||||||
source=args.source, checkpoint=args.checkpoint, config=args.config
|
|
||||||
)
|
|
||||||
|
|
||||||
effective_stride = resolve_stride(
|
|
||||||
window=cast(int, args.window),
|
|
||||||
stride=cast(int, args.stride),
|
|
||||||
window_mode=cast(WindowMode, args.window_mode),
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = ScoliosisPipeline(
|
|
||||||
source=cast(str, args.source),
|
|
||||||
checkpoint=cast(str, args.checkpoint),
|
|
||||||
config=cast(str, args.config),
|
|
||||||
device=cast(str, args.device),
|
|
||||||
yolo_model=cast(str, args.yolo_model),
|
|
||||||
window=cast(int, args.window),
|
|
||||||
stride=effective_stride,
|
|
||||||
target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
|
|
||||||
nats_url=cast(str | None, args.nats_url),
|
|
||||||
nats_subject=cast(str, args.nats_subject),
|
|
||||||
max_frames=cast(int | None, args.max_frames),
|
|
||||||
preprocess_only=cast(bool, args.preprocess_only),
|
|
||||||
silhouette_export_path=cast(str | None, args.silhouette_export_path),
|
|
||||||
silhouette_export_format=cast(str, args.silhouette_export_format),
|
|
||||||
silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
|
|
||||||
result_export_path=cast(str | None, args.result_export_path),
|
|
||||||
result_export_format=cast(str, args.result_export_format),
|
|
||||||
visualize=cast(bool, args.visualize),
|
|
||||||
)
|
|
||||||
raise SystemExit(pipeline.run())
|
|
||||||
except ValueError as err:
|
|
||||||
print(f"Error: {err}", file=sys.stderr)
|
|
||||||
raise SystemExit(2) from err
|
|
||||||
except RuntimeError as err:
|
|
||||||
print(f"Runtime error: {err}", file=sys.stderr)
|
|
||||||
raise SystemExit(1) from err
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from time import strftime, localtime
|
from time import strftime, localtime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import get_msg_mgr, mkdir
|
from opengait.utils import get_msg_mgr, mkdir
|
||||||
|
|
||||||
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
|
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
|
||||||
from .re_rank import re_ranking
|
from .re_rank import re_ranking
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from utils import is_tensor
|
from opengait.utils import is_tensor
|
||||||
|
|
||||||
|
|
||||||
def cuda_dist(x, y, metric='euc'):
|
def cuda_dist(x, y, metric='euc'):
|
||||||
|
|||||||
+1
-1
@@ -4,7 +4,7 @@ import argparse
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from modeling import models
|
from modeling import models
|
||||||
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
|
from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Main program for opengait.')
|
parser = argparse.ArgumentParser(description='Main program for opengait.')
|
||||||
parser.add_argument('--local_rank', type=int, default=0,
|
parser.add_argument('--local_rank', type=int, default=0,
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ from data.transform import get_transform
|
|||||||
from data.collate_fn import CollateFn
|
from data.collate_fn import CollateFn
|
||||||
from data.dataset import DataSet
|
from data.dataset import DataSet
|
||||||
import data.sampler as Samplers
|
import data.sampler as Samplers
|
||||||
from utils import Odict, mkdir, ddp_all_gather
|
from opengait.utils import Odict, mkdir, ddp_all_gather
|
||||||
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||||
from evaluation import evaluator as eval_functions
|
from evaluation import evaluator as eval_functions
|
||||||
from utils import NoOp
|
from opengait.utils import NoOp
|
||||||
from utils import get_msg_mgr
|
from opengait.utils import get_msg_mgr
|
||||||
|
|
||||||
__all__ = ['BaseModel']
|
__all__ = ['BaseModel']
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from . import losses
|
from . import losses
|
||||||
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
from opengait.utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
||||||
from utils import Odict
|
from opengait.utils import Odict
|
||||||
from utils import get_msg_mgr
|
from opengait.utils import get_msg_mgr
|
||||||
|
|
||||||
|
|
||||||
class LossAggregator(nn.Module):
|
class LossAggregator(nn.Module):
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from ctypes import ArgumentError
|
from ctypes import ArgumentError
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch
|
import torch
|
||||||
from utils import Odict
|
from opengait.utils import Odict
|
||||||
import functools
|
import functools
|
||||||
from utils import ddp_all_gather
|
from opengait.utils import ddp_all_gather
|
||||||
|
|
||||||
|
|
||||||
def gather_and_scale_wrapper(func):
|
def gather_and_scale_wrapper(func):
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class Post_ResNet9(ResNet):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||||
from ... import backbones
|
from ... import backbones
|
||||||
class Baseline(nn.Module):
|
class Baseline(nn.Module):
|
||||||
def __init__(self, model_cfg):
|
def __init__(self, model_cfg):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from ..base_model import BaseModel
|
|||||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc
|
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from utils import get_valid_args
|
from opengait.utils import get_valid_args
|
||||||
import warnings
|
import warnings
|
||||||
import random
|
import random
|
||||||
from torchvision.utils import flow_to_image
|
from torchvision.utils import flow_to_image
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ class Post_ResNet9(ResNet):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||||
from ... import backbones
|
from ... import backbones
|
||||||
class GaitBaseFusion_denoise(nn.Module):
|
class GaitBaseFusion_denoise(nn.Module):
|
||||||
def __init__(self, model_cfg):
|
def __init__(self, model_cfg):
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from ..base_model import BaseModel
|
|||||||
from .gaitgl import GaitGL
|
from .gaitgl import GaitGL
|
||||||
from ..modules import GaitAlign
|
from ..modules import GaitAlign
|
||||||
from torchvision.transforms import Resize
|
from torchvision.transforms import Resize
|
||||||
from utils import get_valid_args, get_attr_from, is_list_or_tuple
|
from opengait.utils import get_valid_args, get_attr_from, is_list_or_tuple
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ..base_model import BaseModel
|
from ..base_model import BaseModel
|
||||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
|
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
|
||||||
from utils import clones
|
from opengait.utils import clones
|
||||||
|
|
||||||
|
|
||||||
class BasicConv1d(nn.Module):
|
class BasicConv1d(nn.Module):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||||||
from ..base_model import BaseModel
|
from ..base_model import BaseModel
|
||||||
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
|
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
|
||||||
|
|
||||||
from utils import np2var, list2var, get_valid_args, ddp_all_gather
|
from opengait.utils import np2var, list2var, get_valid_args, ddp_all_gather
|
||||||
from data.transform import get_transform
|
from data.transform import get_transform
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ class GaitSSB_Pretrain(BaseModel):
|
|||||||
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import get_valid_args, list2var
|
from opengait.utils import get_valid_args, list2var
|
||||||
|
|
||||||
class no_grad(torch.no_grad):
|
class no_grad(torch.no_grad):
|
||||||
def __init__(self, enable=True):
|
def __init__(self, enable=True):
|
||||||
|
|||||||
@@ -782,7 +782,7 @@ from ..modules import BasicBlock2D, BasicBlockP3D
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from utils import get_valid_args, get_attr_from
|
from opengait.utils import get_valid_args, get_attr_from
|
||||||
|
|
||||||
class SwinGait(BaseModel):
|
class SwinGait(BaseModel):
|
||||||
def __init__(self, cfgs, training):
|
def __init__(self, cfgs, training):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from utils import clones, is_list_or_tuple
|
from opengait.utils import clones, is_list_or_tuple
|
||||||
from torchvision.ops import RoIAlign
|
from torchvision.ops import RoIAlign
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+5
-1
@@ -20,6 +20,7 @@ dependencies = [
|
|||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"cvmmap-client",
|
"cvmmap-client",
|
||||||
|
"nats-py",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -35,7 +36,10 @@ wandb = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
packages = ["opengait"]
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = [".", "opengait-studio"]
|
||||||
|
include = ["opengait", "opengait.*", "opengait_studio", "opengait_studio.*"]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ class TestNatsPublisherIntegration:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pytest.skip("nats-py not installed")
|
pytest.skip("nats-py not installed")
|
||||||
|
|
||||||
from opengait.demo.output import NatsPublisher, create_result
|
from opengait_studio.output import NatsPublisher, create_result
|
||||||
|
|
||||||
# Create publisher
|
# Create publisher
|
||||||
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
|
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
|
||||||
@@ -341,7 +341,7 @@ class TestNatsPublisherIntegration:
|
|||||||
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
|
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
|
||||||
"""Test that publisher handles missing server gracefully."""
|
"""Test that publisher handles missing server gracefully."""
|
||||||
try:
|
try:
|
||||||
from opengait.demo.output import NatsPublisher
|
from opengait_studio.output import NatsPublisher
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pytest.skip("output module not available")
|
pytest.skip("output module not available")
|
||||||
|
|
||||||
@@ -380,7 +380,7 @@ class TestNatsPublisherIntegration:
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import nats # type: ignore[import-untyped]
|
import nats # type: ignore[import-untyped]
|
||||||
from opengait.demo.output import NatsPublisher, create_result
|
from opengait_studio.output import NatsPublisher, create_result
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
pytest.skip(f"Required module not available: {e}")
|
pytest.skip(f"Required module not available: {e}")
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ from numpy.typing import NDArray
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
|
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
|
||||||
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
|
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
|
||||||
@@ -31,7 +31,7 @@ def _device_for_runtime() -> str:
|
|||||||
def _run_pipeline_cli(
|
def _run_pipeline_cli(
|
||||||
*args: str, timeout_seconds: int = 120
|
*args: str, timeout_seconds: int = 120
|
||||||
) -> subprocess.CompletedProcess[str]:
|
) -> subprocess.CompletedProcess[str]:
|
||||||
command = [sys.executable, "-m", "opengait.demo", *args]
|
command = [sys.executable, "-m", "opengait_studio", *args]
|
||||||
return subprocess.run(
|
return subprocess.run(
|
||||||
command,
|
command,
|
||||||
cwd=REPO_ROOT,
|
cwd=REPO_ROOT,
|
||||||
@@ -728,14 +728,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
|
|||||||
This is a regression test for the window freeze issue when no person is detected.
|
This is a regression test for the window freeze issue when no person is detected.
|
||||||
The window should refresh every frame to prevent freezing.
|
The window should refresh every frame to prevent freezing.
|
||||||
"""
|
"""
|
||||||
from opengait.demo.pipeline import ScoliosisPipeline
|
from opengait_studio.pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
# Create a minimal pipeline with mocked dependencies
|
# Create a minimal pipeline with mocked dependencies
|
||||||
with (
|
with (
|
||||||
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
|
mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
|
||||||
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
|
mock.patch("opengait_studio.pipeline.create_source") as mock_source,
|
||||||
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
|
mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
|
||||||
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
|
mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
|
||||||
):
|
):
|
||||||
# Setup mock detector that returns no detections (causing process_frame to return None)
|
# Setup mock detector that returns no detections (causing process_frame to return None)
|
||||||
mock_detector = mock.MagicMock()
|
mock_detector = mock.MagicMock()
|
||||||
@@ -791,16 +791,16 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
|
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
|
||||||
from opengait.demo.pipeline import ScoliosisPipeline
|
from opengait_studio.pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
# Create a minimal pipeline with mocked dependencies
|
# Create a minimal pipeline with mocked dependencies
|
||||||
with (
|
with (
|
||||||
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
|
mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
|
||||||
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
|
mock.patch("opengait_studio.pipeline.create_source") as mock_source,
|
||||||
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
|
mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
|
||||||
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
|
mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
|
||||||
mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
|
mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
|
||||||
mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
|
mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
|
||||||
):
|
):
|
||||||
# Create mock detection result for frames 0-1 (valid detection)
|
# Create mock detection result for frames 0-1 (valid detection)
|
||||||
mock_box = mock.MagicMock()
|
mock_box = mock.MagicMock()
|
||||||
@@ -919,7 +919,7 @@ def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_frame_pacer_emission_count_24_to_15() -> None:
|
def test_frame_pacer_emission_count_24_to_15() -> None:
|
||||||
from opengait.demo.pipeline import _FramePacer
|
from opengait_studio.pipeline import _FramePacer
|
||||||
|
|
||||||
pacer = _FramePacer(15.0)
|
pacer = _FramePacer(15.0)
|
||||||
interval_ns = int(1_000_000_000 / 24)
|
interval_ns = int(1_000_000_000 / 24)
|
||||||
@@ -928,7 +928,7 @@ def test_frame_pacer_emission_count_24_to_15() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_frame_pacer_requires_positive_target_fps() -> None:
|
def test_frame_pacer_requires_positive_target_fps() -> None:
|
||||||
from opengait.demo.pipeline import _FramePacer
|
from opengait_studio.pipeline import _FramePacer
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="target_fps must be positive"):
|
with pytest.raises(ValueError, match="target_fps must be positive"):
|
||||||
_FramePacer(0.0)
|
_FramePacer(0.0)
|
||||||
@@ -950,6 +950,6 @@ def test_resolve_stride_modes(
|
|||||||
mode: Literal["manual", "sliding", "chunked"],
|
mode: Literal["manual", "sliding", "chunked"],
|
||||||
expected: int,
|
expected: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
from opengait.demo.pipeline import resolve_stride
|
from opengait_studio.pipeline import resolve_stride
|
||||||
|
|
||||||
assert resolve_stride(window, stride, mode) == expected
|
assert resolve_stride(window, stride, mode) == expected
|
||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
from beartype.roar import BeartypeCallHintParamViolation
|
from beartype.roar import BeartypeCallHintParamViolation
|
||||||
from jaxtyping import TypeCheckError
|
from jaxtyping import TypeCheckError
|
||||||
|
|
||||||
from opengait.demo.preprocess import mask_to_silhouette
|
from opengait_studio.preprocess import mask_to_silhouette
|
||||||
|
|
||||||
|
|
||||||
class TestMaskToSilhouette:
|
class TestMaskToSilhouette:
|
||||||
@@ -18,7 +18,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
# Constants for test configuration
|
# Constants for test configuration
|
||||||
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
||||||
@@ -27,7 +27,7 @@ CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def demo() -> "ScoNetDemo":
|
def demo() -> "ScoNetDemo":
|
||||||
"""Create ScoNetDemo without loading checkpoint (CPU-only)."""
|
"""Create ScoNetDemo without loading checkpoint (CPU-only)."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
return ScoNetDemo(
|
return ScoNetDemo(
|
||||||
cfg_path=str(CONFIG_PATH),
|
cfg_path=str(CONFIG_PATH),
|
||||||
@@ -71,7 +71,7 @@ class TestScoNetDemoConstruction:
|
|||||||
|
|
||||||
def test_construction_from_config_no_checkpoint(self) -> None:
|
def test_construction_from_config_no_checkpoint(self) -> None:
|
||||||
"""Test construction with config only, no checkpoint."""
|
"""Test construction with config only, no checkpoint."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
demo = ScoNetDemo(
|
demo = ScoNetDemo(
|
||||||
cfg_path=str(CONFIG_PATH),
|
cfg_path=str(CONFIG_PATH),
|
||||||
@@ -87,7 +87,7 @@ class TestScoNetDemoConstruction:
|
|||||||
|
|
||||||
def test_construction_with_relative_path(self) -> None:
|
def test_construction_with_relative_path(self) -> None:
|
||||||
"""Test construction handles relative config path correctly."""
|
"""Test construction handles relative config path correctly."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
demo = ScoNetDemo(
|
demo = ScoNetDemo(
|
||||||
cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
|
cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
|
||||||
@@ -100,7 +100,7 @@ class TestScoNetDemoConstruction:
|
|||||||
|
|
||||||
def test_construction_invalid_config_raises(self) -> None:
|
def test_construction_invalid_config_raises(self) -> None:
|
||||||
"""Test construction raises with invalid config path."""
|
"""Test construction raises with invalid config path."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
with pytest.raises((FileNotFoundError, TypeError)):
|
with pytest.raises((FileNotFoundError, TypeError)):
|
||||||
_ = ScoNetDemo(
|
_ = ScoNetDemo(
|
||||||
@@ -180,7 +180,7 @@ class TestScoNetDemoPredict:
|
|||||||
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test predict returns (str, float) tuple with valid label."""
|
"""Test predict returns (str, float) tuple with valid label."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
result_raw = demo.predict(dummy_sils_single)
|
result_raw = demo.predict(dummy_sils_single)
|
||||||
result = cast(tuple[str, float], result_raw)
|
result = cast(tuple[str, float], result_raw)
|
||||||
@@ -217,7 +217,7 @@ class TestScoNetDemoNoDDP:
|
|||||||
|
|
||||||
def test_no_distributed_init_in_construction(self) -> None:
|
def test_no_distributed_init_in_construction(self) -> None:
|
||||||
"""Test that construction does not call torch.distributed."""
|
"""Test that construction does not call torch.distributed."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
with patch("torch.distributed.is_initialized") as mock_is_init:
|
with patch("torch.distributed.is_initialized") as mock_is_init:
|
||||||
with patch("torch.distributed.init_process_group") as mock_init_pg:
|
with patch("torch.distributed.init_process_group") as mock_init_pg:
|
||||||
@@ -327,14 +327,14 @@ class TestScoNetDemoLabelMap:
|
|||||||
|
|
||||||
def test_label_map_has_three_classes(self) -> None:
|
def test_label_map_has_three_classes(self) -> None:
|
||||||
"""Test LABEL_MAP has exactly 3 classes."""
|
"""Test LABEL_MAP has exactly 3 classes."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
assert len(ScoNetDemo.LABEL_MAP) == 3
|
assert len(ScoNetDemo.LABEL_MAP) == 3
|
||||||
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
|
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
|
||||||
|
|
||||||
def test_label_map_values_are_valid_strings(self) -> None:
|
def test_label_map_values_are_valid_strings(self) -> None:
|
||||||
"""Test LABEL_MAP values are valid non-empty strings."""
|
"""Test LABEL_MAP values are valid non-empty strings."""
|
||||||
from opengait.demo.sconet_demo import ScoNetDemo
|
from opengait_studio.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
for value in ScoNetDemo.LABEL_MAP.values():
|
for value in ScoNetDemo.LABEL_MAP.values():
|
||||||
assert isinstance(value, str)
|
assert isinstance(value, str)
|
||||||
@@ -7,14 +7,14 @@ from unittest import mock
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from opengait.demo.input import create_source
|
from opengait_studio.input import create_source
|
||||||
from opengait.demo.visualizer import (
|
from opengait_studio.visualizer import (
|
||||||
DISPLAY_HEIGHT,
|
DISPLAY_HEIGHT,
|
||||||
DISPLAY_WIDTH,
|
DISPLAY_WIDTH,
|
||||||
ImageArray,
|
ImageArray,
|
||||||
OpenCVVisualizer,
|
OpenCVVisualizer,
|
||||||
)
|
)
|
||||||
from opengait.demo.window import select_person
|
from opengait_studio.window import select_person
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||||
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
|
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
|
||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from opengait.demo.window import SilhouetteWindow, select_person
|
from opengait_studio.window import SilhouetteWindow, select_person
|
||||||
|
|
||||||
|
|
||||||
class TestSilhouetteWindow:
|
class TestSilhouetteWindow:
|
||||||
@@ -614,7 +614,7 @@ name = "cuda-bindings"
|
|||||||
version = "12.9.4"
|
version = "12.9.4"
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "cuda-pathfinder", marker = "sys_platform == 'linux'" },
|
{ name = "cuda-pathfinder", marker = "sys_platform != 'win32'" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" },
|
||||||
@@ -1611,7 +1611,7 @@ name = "nvidia-cudnn-cu12"
|
|||||||
version = "9.10.2.21"
|
version = "9.10.2.21"
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
||||||
@@ -1622,7 +1622,7 @@ name = "nvidia-cufft-cu12"
|
|||||||
version = "11.3.3.83"
|
version = "11.3.3.83"
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
||||||
@@ -1649,9 +1649,9 @@ name = "nvidia-cusolver-cu12"
|
|||||||
version = "11.7.3.90"
|
version = "11.7.3.90"
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-cublas-cu12", marker = "sys_platform != 'win32'" },
|
||||||
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-cusparse-cu12", marker = "sys_platform != 'win32'" },
|
||||||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
||||||
@@ -1662,7 +1662,7 @@ name = "nvidia-cusparse-cu12"
|
|||||||
version = "12.5.8.93"
|
version = "12.5.8.93"
|
||||||
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'win32'" },
|
||||||
]
|
]
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
{ url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
||||||
@@ -1737,6 +1737,7 @@ dependencies = [
|
|||||||
{ name = "imageio" },
|
{ name = "imageio" },
|
||||||
{ name = "kornia" },
|
{ name = "kornia" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib" },
|
||||||
|
{ name = "nats-py" },
|
||||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
|
{ name = "numpy", version = "2.2.6", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
|
{ name = "numpy", version = "2.4.2", source = { registry = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/" }, marker = "python_full_version >= '3.11'" },
|
||||||
{ name = "opencv-python" },
|
{ name = "opencv-python" },
|
||||||
@@ -1780,6 +1781,7 @@ requires-dist = [
|
|||||||
{ name = "imageio" },
|
{ name = "imageio" },
|
||||||
{ name = "kornia" },
|
{ name = "kornia" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib" },
|
||||||
|
{ name = "nats-py" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "opencv-python" },
|
{ name = "opencv-python" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow" },
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
uv run python -m opengait.demo \
|
uv run python -m opengait_studio \
|
||||||
--source "cvmmap://camera_5602" \
|
--source "cvmmap://camera_5602" \
|
||||||
--checkpoint "ckpt/ScoNet-20000.pt" \
|
--checkpoint "ckpt/ScoNet-20000.pt" \
|
||||||
--config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \
|
--config "configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml" \
|
||||||
|
|||||||
Reference in New Issue
Block a user