feat: extract opengait_studio monorepo module

Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
2026-03-03 17:16:17 +08:00
parent 5c6bef1ca1
commit 00fcda4fe3
39 changed files with 359 additions and 270 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
import math
import random
import numpy as np
from utils import get_msg_mgr
from opengait.utils import get_msg_mgr
class CollateFn(object):
+1 -1
View File
@@ -3,7 +3,7 @@ import pickle
import os.path as osp
import torch.utils.data as tordata
import json
from utils import get_msg_mgr
from opengait.utils import get_msg_mgr
class DataSet(tordata.Dataset):
+1 -1
View File
@@ -4,7 +4,7 @@ import torchvision.transforms as T
import cv2
import math
from data import transform as base_transform
from utils import is_list, is_dict, get_valid_args
from opengait.utils import is_list, is_dict, get_valid_args
class NoOperation():
View File
-149
View File
@@ -1,149 +0,0 @@
from __future__ import annotations
import argparse
import logging
import sys
from typing import cast
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
def _positive_float(value: str) -> float:
parsed = float(value)
if parsed <= 0:
raise argparse.ArgumentTypeError("target-fps must be positive")
return parsed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")
parser.add_argument(
"--source", type=str, required=True, help="Video source path or camera ID"
)
parser.add_argument(
"--checkpoint", type=str, required=True, help="Model checkpoint path"
)
parser.add_argument(
"--config",
type=str,
default="configs/sconet/sconet_scoliosis1k.yaml",
help="Config file path",
)
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
parser.add_argument(
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
)
parser.add_argument(
"--window", type=int, default=30, help="Window size for classification"
)
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
parser.add_argument(
"--target-fps",
type=_positive_float,
default=15.0,
help="Target FPS for temporal downsampling before windowing",
)
parser.add_argument(
"--window-mode",
type=str,
choices=["manual", "sliding", "chunked"],
default="manual",
help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
)
parser.add_argument(
"--no-target-fps",
action="store_true",
help="Disable temporal downsampling and use all frames",
)
parser.add_argument(
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
)
parser.add_argument(
"--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
)
parser.add_argument(
"--max-frames", type=int, default=None, help="Maximum frames to process"
)
parser.add_argument(
"--preprocess-only", action="store_true", help="Only preprocess silhouettes"
)
parser.add_argument(
"--silhouette-export-path",
type=str,
default=None,
help="Path to export silhouettes",
)
parser.add_argument(
"--silhouette-export-format", type=str, default="pickle", help="Export format"
)
parser.add_argument(
"--silhouette-visualize-dir",
type=str,
default=None,
help="Directory for silhouette visualizations",
)
parser.add_argument(
"--result-export-path", type=str, default=None, help="Path to export results"
)
parser.add_argument(
"--result-export-format", type=str, default="json", help="Result export format"
)
parser.add_argument(
"--visualize", action="store_true", help="Enable real-time visualization"
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Validate preprocess-only mode requires silhouette export path
if args.preprocess_only and not args.silhouette_export_path:
print(
"Error: --silhouette-export-path is required when using --preprocess-only",
file=sys.stderr,
)
raise SystemExit(2)
try:
# Import here to avoid circular imports
from .pipeline import validate_runtime_inputs
validate_runtime_inputs(
source=args.source, checkpoint=args.checkpoint, config=args.config
)
effective_stride = resolve_stride(
window=cast(int, args.window),
stride=cast(int, args.stride),
window_mode=cast(WindowMode, args.window_mode),
)
pipeline = ScoliosisPipeline(
source=cast(str, args.source),
checkpoint=cast(str, args.checkpoint),
config=cast(str, args.config),
device=cast(str, args.device),
yolo_model=cast(str, args.yolo_model),
window=cast(int, args.window),
stride=effective_stride,
target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
nats_url=cast(str | None, args.nats_url),
nats_subject=cast(str, args.nats_subject),
max_frames=cast(int | None, args.max_frames),
preprocess_only=cast(bool, args.preprocess_only),
silhouette_export_path=cast(str | None, args.silhouette_export_path),
silhouette_export_format=cast(str, args.silhouette_export_format),
silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
result_export_path=cast(str | None, args.result_export_path),
result_export_format=cast(str, args.result_export_format),
visualize=cast(bool, args.visualize),
)
raise SystemExit(pipeline.run())
except ValueError as err:
print(f"Error: {err}", file=sys.stderr)
raise SystemExit(2) from err
except RuntimeError as err:
print(f"Runtime error: {err}", file=sys.stderr)
raise SystemExit(1) from err
-219
View File
@@ -1,219 +0,0 @@
"""
Input adapters for OpenGait demo.
Provides generator-based interfaces for video sources:
- OpenCV (video files, cameras)
- cv-mmap (shared memory streams)
"""
from collections.abc import AsyncIterator, Generator, Iterable
from typing import Protocol, cast
import logging
import numpy as np
logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
class _FrameMetadata(Protocol):
frame_count: int
timestamp_ns: int
# Protocol for cv-mmap client (needed at runtime for cast)
class _CvMmapClient(Protocol):
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
def opencv_source(
path: str | int, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
"""
Generator that yields frames from an OpenCV video source.
Parameters
----------
path : str | int
Video file path or camera index (e.g., 0 for default camera)
max_frames : int | None, optional
Maximum number of frames to yield. None means unlimited.
Yields
------
tuple[np.ndarray, dict[str, object]]
(frame_array, metadata_dict) where metadata includes:
- frame_count: frame index (0-based)
- timestamp_ns: monotonic timestamp in nanoseconds (if available)
- source: the path/int provided
"""
import time
import cv2
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise RuntimeError(f"Failed to open video source: {path}")
is_file_source = isinstance(path, str)
source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
fallback_fps = source_fps if fps_valid else 30.0
fallback_interval_ns = int(1_000_000_000 / fallback_fps)
start_ns = time.monotonic_ns()
frame_idx = 0
try:
while max_frames is None or frame_idx < max_frames:
ret, frame = cap.read()
if not ret:
# End of stream
break
if is_file_source:
pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
if np.isfinite(pos_msec) and pos_msec > 0.0:
timestamp_ns = start_ns + int(pos_msec * 1_000_000)
else:
timestamp_ns = start_ns + frame_idx * fallback_interval_ns
else:
timestamp_ns = time.monotonic_ns()
metadata: dict[str, object] = {
"frame_count": frame_idx,
"timestamp_ns": timestamp_ns,
"source": path,
}
if fps_valid:
metadata["source_fps"] = source_fps
yield frame, metadata
frame_idx += 1
finally:
cap.release()
logger.debug(f"OpenCV source closed: {path}")
def cvmmap_source(
name: str, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
"""
Generator that yields frames from a cv-mmap shared memory stream.
Bridges async cv-mmap client to synchronous generator using asyncio.run().
Parameters
----------
name : str
Base name of the cv-mmap source (e.g., "default")
max_frames : int | None, optional
Maximum number of frames to yield. None means unlimited.
Yields
------
tuple[np.ndarray, dict[str, object]]
(frame_array, metadata_dict) where metadata includes:
- frame_count: frame index from cv-mmap
- timestamp_ns: timestamp in nanoseconds from cv-mmap
- source: the cv-mmap name
Raises
------
ImportError
If cvmmap package is not available
RuntimeError
If cv-mmap stream disconnects or errors
"""
import asyncio
# Import cvmmap only when function is called
# Use try/except for runtime import check
try:
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
except ImportError as e:
raise ImportError(
"cvmmap package is required for cv-mmap sources. "
+ "Install from: https://github.com/crosstyan/cv-mmap"
) from e
# Cast to protocol type for type checking
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
frame_count = 0
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
"""Async generator wrapper."""
async for frame, meta in client:
yield frame, meta
# Bridge async to sync using asyncio.run()
# We process frames one at a time to keep it simple and robust
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
agen = _async_generator().__aiter__()
while max_frames is None or frame_count < max_frames:
try:
frame, meta = loop.run_until_complete(agen.__anext__())
except StopAsyncIteration:
break
metadata: dict[str, object] = {
"frame_count": meta.frame_count,
"timestamp_ns": meta.timestamp_ns,
"source": f"cvmmap://{name}",
}
yield frame, metadata
frame_count += 1
finally:
loop.close()
logger.debug(f"cv-mmap source closed: {name}")
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
"""
Factory function to create a frame source from a string specification.
Parameters
----------
source : str
Source specification:
- '0', '1', etc. -> Camera index (OpenCV)
- 'cvmmap://name' -> cv-mmap shared memory stream
- Any other string -> Video file path (OpenCV)
max_frames : int | None, optional
Maximum number of frames to yield. None means unlimited.
Returns
-------
FrameStream
Generator yielding (frame, metadata) tuples
Examples
--------
>>> for frame, meta in create_source('0'): # Camera 0
... process(frame)
>>> for frame, meta in create_source('cvmmap://default'): # cv-mmap
... process(frame)
>>> for frame, meta in create_source('/path/to/video.mp4'):
... process(frame)
"""
# Check for cv-mmap protocol
if source.startswith("cvmmap://"):
name = source[len("cvmmap://") :]
return cvmmap_source(name, max_frames)
# Check for camera index (single digit string)
if source.isdigit():
return opencv_source(int(source), max_frames)
# Otherwise treat as file path
return opencv_source(source, max_frames)
-383
View File
@@ -1,383 +0,0 @@
"""
Output publishers for OpenGait demo results.
Provides pluggable result publishing:
- ConsolePublisher: JSONL to stdout
- NatsPublisher: NATS message broker integration
"""
from __future__ import annotations
import asyncio
import json
import logging
import sys
import threading
import time
from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
if TYPE_CHECKING:
from types import TracebackType
logger = logging.getLogger(__name__)
class DemoResult(TypedDict):
"""Typed result dictionary for demo pipeline output.
Contains classification result with frame metadata.
"""
frame: int
track_id: int
label: str
confidence: float
window: int
timestamp_ns: int
@runtime_checkable
class ResultPublisher(Protocol):
"""Protocol for result publishers."""
def publish(self, result: DemoResult) -> None:
"""
Publish a result dictionary.
Parameters
----------
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
...
class ConsolePublisher:
"""Publisher that outputs JSON Lines to stdout."""
_output: TextIO
def __init__(self, output: TextIO = sys.stdout) -> None:
"""
Initialize console publisher.
Parameters
----------
output : TextIO
File-like object to write to (default: sys.stdout)
"""
self._output = output
def publish(self, result: DemoResult) -> None:
"""
Publish result as JSON line.
Parameters
----------
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
try:
json_line = json.dumps(result, ensure_ascii=False, default=str)
_ = self._output.write(json_line + "\n")
self._output.flush()
except Exception as e:
logger.warning(f"Failed to publish to console: {e}")
def close(self) -> None:
"""Close the publisher (no-op for console)."""
pass
def __enter__(self) -> ConsolePublisher:
"""Context manager entry."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Context manager exit."""
self.close()
class _NatsClient(Protocol):
"""Protocol for connected NATS client."""
async def publish(self, subject: str, payload: bytes) -> object: ...
async def close(self) -> object: ...
async def flush(self) -> object: ...
class NatsPublisher:
"""
Publisher that sends results to NATS message broker.
This is a sync-friendly wrapper around the async nats-py client.
Uses a background thread with dedicated event loop to bridge sync
publish calls to async NATS operations, making it safe to use in
both sync and async contexts.
"""
_nats_url: str
_subject: str
_nc: _NatsClient | None
_connected: bool
_loop: asyncio.AbstractEventLoop | None
_thread: threading.Thread | None
_lock: threading.Lock
def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
"""
Initialize NATS publisher.
Parameters
----------
nats_url : str
NATS server URL (e.g., "nats://localhost:4222")
subject : str
NATS subject to publish to (default: "scoliosis.result")
"""
self._nats_url = nats_url
self._subject = subject
self._nc = None
self._connected = False
self._loop = None
self._thread = None
self._lock = threading.Lock()
def _start_background_loop(self) -> bool:
"""
Start background thread with event loop for async operations.
Returns
-------
bool
True if loop is running, False otherwise
"""
with self._lock:
if self._loop is not None and self._loop.is_running():
return True
try:
loop = asyncio.new_event_loop()
self._loop = loop
def run_loop() -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
self._thread = threading.Thread(target=run_loop, daemon=True)
self._thread.start()
return True
except Exception as e:
logger.warning(f"Failed to start background event loop: {e}")
return False
def _stop_background_loop(self) -> None:
"""Stop the background event loop and thread."""
with self._lock:
if self._loop is not None and self._loop.is_running():
_ = self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread is not None and self._thread.is_alive():
self._thread.join(timeout=2.0)
self._loop = None
self._thread = None
def _ensure_connected(self) -> bool:
"""
Ensure connection to NATS server.
Returns
-------
bool
True if connected, False otherwise
"""
with self._lock:
if self._connected and self._nc is not None:
return True
if not self._start_background_loop():
return False
try:
import nats
async def _connect() -> _NatsClient:
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
return cast(_NatsClient, nc)
# Run connection in background loop
future = asyncio.run_coroutine_threadsafe(
_connect(),
self._loop, # pyright: ignore[reportArgumentType]
)
self._nc = future.result(timeout=10.0)
self._connected = True
logger.info(f"Connected to NATS at {self._nats_url}")
return True
except ImportError:
logger.warning(
"nats-py package not installed. Install with: pip install nats-py"
)
return False
except Exception as e:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False
def publish(self, result: DemoResult) -> None:
"""
Publish result to NATS subject.
Parameters
----------
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
if not self._ensure_connected():
# Graceful degradation: log warning but don't crash
logger.debug(
f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
)
return
try:
async def _publish() -> None:
if self._nc is not None:
payload = json.dumps(
result, ensure_ascii=False, default=str
).encode("utf-8")
_ = await self._nc.publish(self._subject, payload)
_ = await self._nc.flush()
# Run publish in background loop
future = asyncio.run_coroutine_threadsafe(
_publish(),
self._loop, # pyright: ignore[reportArgumentType]
)
future.result(timeout=5.0) # Wait for publish to complete
except Exception as e:
logger.warning(f"Failed to publish to NATS: {e}")
self._connected = False # Mark for reconnection on next publish
def close(self) -> None:
"""Close NATS connection."""
with self._lock:
if self._nc is not None and self._connected and self._loop is not None:
try:
async def _close() -> None:
if self._nc is not None:
_ = await self._nc.close()
future = asyncio.run_coroutine_threadsafe(
_close(),
self._loop,
)
future.result(timeout=5.0)
except Exception as e:
logger.debug(f"Error closing NATS connection: {e}")
finally:
self._nc = None
self._connected = False
self._stop_background_loop()
def __enter__(self) -> NatsPublisher:
"""Context manager entry."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Context manager exit."""
self.close()
def create_publisher(
nats_url: str | None,
subject: str = "scoliosis.result",
) -> ResultPublisher:
"""
Factory function to create appropriate publisher.
Parameters
----------
nats_url : str | None
NATS server URL. If None or empty, returns ConsolePublisher.
subject : str
NATS subject to publish to (default: "scoliosis.result")
Returns
-------
ResultPublisher
NatsPublisher if nats_url provided, otherwise ConsolePublisher
Examples
--------
>>> # Console output (default)
>>> pub = create_publisher(None)
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
>>>
>>> # NATS output
>>> pub = create_publisher("nats://localhost:4222")
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
>>>
>>> # Context manager usage
>>> with create_publisher("nats://localhost:4222") as pub:
... pub.publish(result)
"""
if nats_url:
return NatsPublisher(nats_url, subject)
return ConsolePublisher()
def create_result(
frame: int,
track_id: int,
label: str,
confidence: float,
window: int | tuple[int, int],
timestamp_ns: int | None = None,
) -> DemoResult:
"""
Create a standardized result dictionary.
Parameters
----------
frame : int
Frame number
track_id : int
Track/person identifier
label : str
Classification label (e.g., "normal", "scoliosis")
confidence : float
Confidence score (0.0 to 1.0)
window : int | tuple[int, int]
Frame window as int (end frame) or tuple [start, end] that produced this result
Frame window [start, end] that produced this result
timestamp_ns : int | None
Timestamp in nanoseconds. If None, uses current time.
Returns
-------
DemoResult
Standardized result dictionary
"""
return {
"frame": frame,
"track_id": track_id,
"label": label,
"confidence": confidence,
"window": window if isinstance(window, int) else window[1],
"timestamp_ns": timestamp_ns
if timestamp_ns is not None
else time.monotonic_ns(),
}
-904
View File
@@ -1,904 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
from beartype import beartype
import click
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source
from .output import DemoResult, ResultPublisher, create_publisher, create_result
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
if TYPE_CHECKING:
from .visualizer import OpenCVVisualizer
logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
if window_mode == "manual":
return stride
if window_mode == "sliding":
return 1
return window
class _BoxesLike(Protocol):
@property
def xyxy(self) -> NDArray[np.float32] | object: ...
@property
def id(self) -> NDArray[np.int64] | object | None: ...
class _MasksLike(Protocol):
@property
def data(self) -> NDArray[np.float32] | object: ...
class _DetectionResultsLike(Protocol):
@property
def boxes(self) -> _BoxesLike: ...
@property
def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol):
def __call__(
self,
source: object,
*,
persist: bool = True,
verbose: bool = False,
device: str | None = None,
classes: list[int] | None = None,
) -> object: ...
class _SelectedSilhouette(TypedDict):
"""Selected silhouette payload produced from detector outputs.
Fields:
silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
mask_raw: Full-resolution binary person mask in mask/image space.
bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for visualization.
bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for cropping.
track_id: Tracking ID from detector, or `0` for fallback path.
"""
silhouette: Float[ndarray, "64 44"]
mask_raw: UInt8[ndarray, "h w"]
bbox_frame: BBoxXYXY
bbox_mask: BBoxXYXY
track_id: int
class _FramePacer:
_interval_ns: int
_next_emit_ns: int | None
def __init__(self, target_fps: float) -> None:
if target_fps <= 0:
raise ValueError(f"target_fps must be positive, got {target_fps}")
self._interval_ns = int(1_000_000_000 / target_fps)
self._next_emit_ns = None
def should_emit(self, timestamp_ns: int) -> bool:
if self._next_emit_ns is None:
self._next_emit_ns = timestamp_ns + self._interval_ns
return True
if timestamp_ns >= self._next_emit_ns:
while self._next_emit_ns <= timestamp_ns:
self._next_emit_ns += self._interval_ns
return True
return False
class ScoliosisPipeline:
_detector: object
_source: FrameStream
_window: SilhouetteWindow
_publisher: ResultPublisher
_classifier: ScoNetDemo
_device: str
_closed: bool
_preprocess_only: bool
_silhouette_export_path: Path | None
_silhouette_export_format: str
_silhouette_buffer: list[dict[str, object]]
_silhouette_visualize_dir: Path | None
_result_export_path: Path | None
_result_export_format: str
_result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
_frame_pacer: _FramePacer | None
def __init__(
self,
*,
source: str,
checkpoint: str,
config: str,
device: str,
yolo_model: str,
window: int,
stride: int,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
preprocess_only: bool = False,
silhouette_export_path: str | None = None,
silhouette_export_format: str = "pickle",
silhouette_visualize_dir: str | None = None,
result_export_path: str | None = None,
result_export_format: str = "json",
visualize: bool = False,
target_fps: float | None = 15.0,
) -> None:
self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames)
self._window = SilhouetteWindow(window_size=window, stride=stride)
self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
self._classifier = ScoNetDemo(
cfg_path=config,
checkpoint_path=checkpoint,
device=device,
)
self._device = device
self._closed = False
self._preprocess_only = preprocess_only
self._silhouette_export_path = (
Path(silhouette_export_path) if silhouette_export_path else None
)
self._silhouette_export_format = silhouette_export_format
# Normalize format alias: pkl -> pickle
if self._silhouette_export_format == "pkl":
self._silhouette_export_format = "pickle"
self._silhouette_buffer = []
self._silhouette_visualize_dir = (
Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
)
self._result_export_path = (
Path(result_export_path) if result_export_path else None
)
self._result_export_format = result_export_format
self._result_buffer = []
if visualize:
from .visualizer import OpenCVVisualizer
self._visualizer = OpenCVVisualizer()
else:
self._visualizer = None
self._last_viz_payload = None
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
value = meta.get(key)
if isinstance(value, int):
return value
return fallback
@staticmethod
def _extract_timestamp(meta: dict[str, object]) -> int:
value = meta.get("timestamp_ns")
if isinstance(value, int):
return value
return time.monotonic_ns()
@staticmethod
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
return cast(UInt8[ndarray, "h w"], binary)
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
if isinstance(detections, list):
return cast(_DetectionResultsLike, detections[0]) if detections else None
if isinstance(detections, tuple):
return cast(_DetectionResultsLike, detections[0]) if detections else None
return cast(_DetectionResultsLike, detections)
def _select_silhouette(
self,
result: _DetectionResultsLike,
) -> _SelectedSilhouette | None:
selected = select_person(result)
if selected is not None:
mask_raw, bbox_mask, bbox_frame, track_id = selected
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
)
if silhouette is not None:
return {
"silhouette": silhouette,
"mask_raw": mask_raw,
"bbox_frame": bbox_frame,
"bbox_mask": bbox_mask,
"track_id": int(track_id),
}
fallback = cast(
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
frame_to_person_mask(result),
)
if fallback is None:
return None
mask_u8, bbox_mask = fallback
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(mask_u8, bbox_mask),
)
if silhouette is None:
return None
# Convert mask-space bbox to frame-space for visualization
# Use result.orig_shape to get frame dimensions safely
orig_shape = getattr(result, "orig_shape", None)
if (
orig_shape is not None
and isinstance(orig_shape, (tuple, list))
and len(orig_shape) >= 2
):
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
scale_x = frame_w / mask_w
scale_y = frame_h / mask_h
bbox_frame = (
int(bbox_mask[0] * scale_x),
int(bbox_mask[1] * scale_y),
int(bbox_mask[2] * scale_x),
int(bbox_mask[3] * scale_y),
)
else:
# Fallback: use mask-space bbox if dimensions invalid
bbox_frame = bbox_mask
else:
# Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8
return {
"silhouette": silhouette,
"mask_raw": mask_u8,
"bbox_frame": bbox_frame,
"bbox_mask": bbox_mask,
"track_id": 0,
}
@jaxtyped(typechecker=beartype)
def process_frame(
self,
frame: UInt8[ndarray, "h w c"],
metadata: dict[str, object],
) -> dict[str, object] | None:
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
timestamp_ns = self._extract_timestamp(metadata)
track_fn_obj = getattr(self._detector, "track", None)
if not callable(track_fn_obj):
raise RuntimeError("YOLO detector does not expose a callable track()")
track_fn = cast(_TrackCallable, track_fn_obj)
detections = track_fn(
frame,
persist=True,
verbose=False,
device=self._device,
classes=[0],
)
first = self._first_result(detections)
if first is None:
return None
selected = self._select_silhouette(first)
if selected is None:
return None
silhouette = selected["silhouette"]
mask_raw = selected["mask_raw"]
bbox = selected["bbox_frame"]
bbox_mask = selected["bbox_mask"]
track_id = selected["track_id"]
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only:
self._silhouette_buffer.append(
{
"frame": frame_idx,
"track_id": track_id,
"timestamp_ns": timestamp_ns,
"silhouette": silhouette.copy(),
}
)
# Visualize silhouette if requested
if self._silhouette_visualize_dir is not None:
self._visualize_silhouette(silhouette, frame_idx, track_id)
if self._preprocess_only:
# Return visualization payload for display even in preprocess-only mode
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": None,
"track_id": track_id,
"label": None,
"confidence": None,
}
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
):
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": self._window.buffered_silhouettes,
"track_id": track_id,
"label": None,
"confidence": None,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify():
# Return visualization payload even when not classifying yet
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": None,
"confidence": None,
}
window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast(
tuple[str, float],
self._classifier.predict(window_tensor),
)
self._window.mark_classified()
window_start = self._window.window_start_frame
result = create_result(
frame=frame_idx,
track_id=track_id,
label=label,
confidence=float(confidence),
window=(max(0, window_start), frame_idx),
timestamp_ns=timestamp_ns,
)
# Store result for export if export path specified
if self._result_export_path is not None:
self._result_buffer.append(result)
self._publisher.publish(result)
# Return result with visualization payload
return {
"result": result,
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": label,
"confidence": confidence,
}
def run(self) -> int:
frame_count = 0
start_time = time.perf_counter()
# EMA FPS state (alpha=0.1 for smoothing)
ema_fps = 0.0
alpha = 0.1
prev_time = start_time
try:
for item in self._source:
frame, metadata = item
frame_u8 = np.asarray(frame, dtype=np.uint8)
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
frame_count += 1
# Compute per-frame EMA FPS
curr_time = time.perf_counter()
delta = curr_time - prev_time
prev_time = curr_time
if delta > 0:
instant_fps = 1.0 / delta
if ema_fps == 0.0:
ema_fps = instant_fps
else:
ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps
viz_payload = None
try:
viz_payload = self.process_frame(frame_u8, metadata)
except Exception as frame_error:
logger.warning(
"Skipping frame %d due to processing error: %s",
frame_idx,
frame_error,
)
# Update visualizer if enabled
if self._visualizer is not None:
# Cache valid payload for no-detection frames
if viz_payload is not None:
# Cache a copy to prevent mutation of original data
viz_payload_dict = cast(dict[str, object], viz_payload)
cached: dict[str, object] = {}
for k, v in viz_payload_dict.items():
copy_method = cast(
Callable[[], object] | None, getattr(v, "copy", None)
)
if copy_method is not None and callable(copy_method):
cached[k] = copy_method()
else:
cached[k] = v
self._last_viz_payload = cached
if viz_payload is not None:
viz_data = viz_payload
elif self._last_viz_payload is not None:
viz_data = dict(self._last_viz_payload)
viz_data["bbox"] = None
viz_data["bbox_mask"] = None
viz_data["label"] = None
viz_data["confidence"] = None
else:
viz_data = None
if viz_data is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask")
silhouette_obj = viz_dict.get("silhouette")
segmentation_input_obj = viz_dict.get("segmentation_input")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj)
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
segmentation_input = cast(
NDArray[np.float32] | None,
segmentation_input_obj,
)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
else:
# No detection and no cache - use default values
mask_raw = None
bbox = None
bbox_mask = None
track_id = 0
silhouette = None
segmentation_input = None
label = None
confidence = None
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time
fps = frame_count / elapsed if elapsed > 0 else 0.0
logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
return 0
except KeyboardInterrupt:
logger.info("Interrupted by user, shutting down cleanly.")
return 130
finally:
self.close()
def close(self) -> None:
if self._closed:
return
# Close visualizer if enabled
if self._visualizer is not None:
with suppress(Exception):
self._visualizer.close()
# Export silhouettes if requested
if self._silhouette_export_path is not None and self._silhouette_buffer:
self._export_silhouettes()
# Export results if requested
if self._result_export_path is not None and self._result_buffer:
self._export_results()
close_fn = getattr(self._publisher, "close", None)
if callable(close_fn):
with suppress(Exception):
_ = close_fn()
self._closed = True
def _export_silhouettes(self) -> None:
"""Export silhouettes to file in specified format."""
if self._silhouette_export_path is None:
return
self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)
if self._silhouette_export_format == "pickle":
import pickle
with open(self._silhouette_export_path, "wb") as f:
pickle.dump(self._silhouette_buffer, f)
logger.info(
"Exported %d silhouettes to %s",
len(self._silhouette_buffer),
self._silhouette_export_path,
)
elif self._silhouette_export_format == "parquet":
self._export_parquet_silhouettes()
else:
raise ValueError(
f"Unsupported silhouette export format: {self._silhouette_export_format}"
)
def _visualize_silhouette(
self,
silhouette: Float[ndarray, "64 44"],
frame_idx: int,
track_id: int,
) -> None:
"""Save silhouette as PNG image."""
if self._silhouette_visualize_dir is None:
return
self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)
# Convert float silhouette to uint8 (0-255)
silhouette_u8 = (silhouette * 255).astype(np.uint8)
# Create deterministic filename
filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
output_path = self._silhouette_visualize_dir / filename
# Save using PIL
from PIL import Image
Image.fromarray(silhouette_u8).save(output_path)
def _export_parquet_silhouettes(self) -> None:
"""Export silhouettes to parquet format."""
import importlib
try:
pa = importlib.import_module("pyarrow")
pq = importlib.import_module("pyarrow.parquet")
except ImportError as e:
raise RuntimeError(
"Parquet export requires pyarrow. Install with: pip install pyarrow"
) from e
# Convert silhouettes to columnar format
frames = []
track_ids = []
timestamps = []
silhouettes = []
for item in self._silhouette_buffer:
frames.append(item["frame"])
track_ids.append(item["track_id"])
timestamps.append(item["timestamp_ns"])
silhouette_array = cast(ndarray, item["silhouette"])
silhouettes.append(silhouette_array.flatten().tolist())
table = pa.table(
{
"frame": pa.array(frames, type=pa.int64()),
"track_id": pa.array(track_ids, type=pa.int64()),
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
"silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
}
)
pq.write_table(table, self._silhouette_export_path)
logger.info(
"Exported %d silhouettes to parquet: %s",
len(self._silhouette_buffer),
self._silhouette_export_path,
)
def _export_results(self) -> None:
"""Export results to file in specified format."""
if self._result_export_path is None:
return
self._result_export_path.parent.mkdir(parents=True, exist_ok=True)
if self._result_export_format == "json":
import json
with open(self._result_export_path, "w", encoding="utf-8") as f:
for result in self._result_buffer:
f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
logger.info(
"Exported %d results to JSON: %s",
len(self._result_buffer),
self._result_export_path,
)
elif self._result_export_format == "pickle":
import pickle
with open(self._result_export_path, "wb") as f:
pickle.dump(self._result_buffer, f)
logger.info(
"Exported %d results to pickle: %s",
len(self._result_buffer),
self._result_export_path,
)
elif self._result_export_format == "parquet":
self._export_parquet_results()
else:
raise ValueError(
f"Unsupported result export format: {self._result_export_format}"
)
def _export_parquet_results(self) -> None:
"""Export results to parquet format."""
import importlib
try:
pa = importlib.import_module("pyarrow")
pq = importlib.import_module("pyarrow.parquet")
except ImportError as e:
raise RuntimeError(
"Parquet export requires pyarrow. Install with: pip install pyarrow"
) from e
frames = []
track_ids = []
labels = []
confidences = []
windows = []
timestamps = []
for result in self._result_buffer:
frames.append(result["frame"])
track_ids.append(result["track_id"])
labels.append(result["label"])
confidences.append(result["confidence"])
windows.append(result["window"])
timestamps.append(result["timestamp_ns"])
table = pa.table(
{
"frame": pa.array(frames, type=pa.int64()),
"track_id": pa.array(track_ids, type=pa.int64()),
"label": pa.array(labels, type=pa.string()),
"confidence": pa.array(confidences, type=pa.float64()),
"window": pa.array(windows, type=pa.int64()),
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
}
)
pq.write_table(table, self._result_export_path)
logger.info(
"Exported %d results to parquet: %s",
len(self._result_buffer),
self._result_export_path,
)
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
if source.startswith("cvmmap://") or source.isdigit():
pass
else:
source_path = Path(source)
if not source_path.is_file():
raise ValueError(f"Video source not found: {source}")
checkpoint_path = Path(checkpoint)
if not checkpoint_path.is_file():
raise ValueError(f"Checkpoint not found: {checkpoint}")
config_path = Path(config)
if not config_path.is_file():
raise ValueError(f"Config not found: {config}")
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
"--config",
type=str,
default="configs/sconet/sconet_scoliosis1k.yaml",
show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option(
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
"--window-mode",
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
default="manual",
show_default=True,
help=(
"Window scheduling mode: manual uses --stride; "
"sliding forces stride=1; chunked forces stride=window"
),
)
@click.option(
"--target-fps",
type=click.FloatRange(min=0.1),
default=15.0,
show_default=True,
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
"--nats-subject",
type=str,
default="scoliosis.result",
show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
"--preprocess-only",
is_flag=True,
default=False,
help="Only preprocess silhouettes, skip classification.",
)
@click.option(
"--silhouette-export-path",
type=str,
default=None,
help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
"--silhouette-export-format",
type=click.Choice(["pickle", "parquet"]),
default="pickle",
show_default=True,
help="Format for silhouette export.",
)
@click.option(
"--result-export-path",
type=str,
default=None,
help="Path to export inference results.",
)
@click.option(
"--result-export-format",
type=click.Choice(["json", "pickle", "parquet"]),
default="json",
show_default=True,
help="Format for result export.",
)
@click.option(
"--silhouette-visualize-dir",
type=str,
default=None,
help="Directory to save silhouette PNG visualizations.",
)
def main(
source: str,
checkpoint: str,
config: str,
device: str,
yolo_model: str,
window: int,
stride: int,
window_mode: str,
target_fps: float | None,
no_target_fps: bool,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
preprocess_only: bool,
silhouette_export_path: str | None,
silhouette_export_format: str,
result_export_path: str | None,
result_export_format: str,
silhouette_visualize_dir: str | None,
) -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Validate preprocess-only mode requirements
if preprocess_only and not silhouette_export_path:
raise click.UsageError(
"--silhouette-export-path is required when using --preprocess-only"
)
try:
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
effective_stride = resolve_stride(
window=window,
stride=stride,
window_mode=cast(WindowMode, window_mode.lower()),
)
if effective_stride != stride:
logger.info(
"window_mode=%s overrides stride=%d -> effective_stride=%d",
window_mode,
stride,
effective_stride,
)
pipeline = ScoliosisPipeline(
source=source,
checkpoint=checkpoint,
config=config,
device=device,
yolo_model=yolo_model,
window=window,
stride=effective_stride,
target_fps=None if no_target_fps else target_fps,
nats_url=nats_url,
nats_subject=nats_subject,
max_frames=max_frames,
preprocess_only=preprocess_only,
silhouette_export_path=silhouette_export_path,
silhouette_export_format=silhouette_export_format,
silhouette_visualize_dir=silhouette_visualize_dir,
result_export_path=result_export_path,
result_export_format=result_export_format,
)
raise SystemExit(pipeline.run())
except ValueError as err:
click.echo(f"Error: {err}", err=True)
raise SystemExit(2) from err
except RuntimeError as err:
click.echo(f"Runtime error: {err}", err=True)
raise SystemExit(1) from err
-335
View File
@@ -1,335 +0,0 @@
from collections.abc import Callable
import math
from typing import cast
import cv2
from beartype import beartype
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
SIL_HEIGHT = 64
SIL_WIDTH = 44
SIL_FULL_WIDTH = 64
SIDE_CUT = 10
MIN_MASK_AREA = 500
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
BBoxXYXY = tuple[int, int, int, int]
"""
Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
"""
def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict):
dict_obj = cast(dict[object, object], container)
return dict_obj.get(key)
try:
return cast(object, object.__getattribute__(container, key))
except AttributeError:
return None
def _to_numpy_array(value: object) -> NDArray[np.generic]:
current: object = value
if isinstance(current, np.ndarray):
return current
detach_obj = _read_attr(current, "detach")
if callable(detach_obj):
detach_fn = cast(Callable[[], object], detach_obj)
current = detach_fn()
cpu_obj = _read_attr(current, "cpu")
if callable(cpu_obj):
cpu_fn = cast(Callable[[], object], cpu_obj)
current = cpu_fn()
numpy_obj = _read_attr(current, "numpy")
if callable(numpy_obj):
numpy_fn = cast(Callable[[], object], numpy_obj)
as_numpy = numpy_fn()
if isinstance(as_numpy, np.ndarray):
return as_numpy
return cast(NDArray[np.generic], np.asarray(current))
def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array:
mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
h, w = cast(tuple[int, int], mask_bin.shape)
if h <= 2 or w <= 2:
return mask_bin
seed_candidates = [(0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1)]
seed: tuple[int, int] | None = None
for x, y in seed_candidates:
if int(mask_bin[y, x]) == 0:
seed = (x, y)
break
if seed is None:
return mask_bin
flood = mask_bin.copy()
flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
_ = cv2.floodFill(flood, flood_mask, seed, 255)
holes = cv2.bitwise_not(flood)
filled = cv2.bitwise_or(mask_bin, holes)
return cast(UInt8Array, filled)
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
"""Extract bounding box from binary mask in XYXY format.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
Returns:
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
"""
mask_u8 = np.asarray(mask, dtype=np.uint8)
coords = np.argwhere(mask_u8 > 0)
if int(coords.size) == 0:
return None
ys = coords[:, 0].astype(np.int64)
xs = coords[:, 1].astype(np.int64)
x1 = int(np.min(xs))
x2 = int(np.max(xs)) + 1
y1 = int(np.min(ys))
y2 = int(np.max(ys)) + 1
if x2 <= x1 or y2 <= y1:
return None
return (x1, y1, x2, y2)
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
"""Sanitize bounding box to ensure it's within image bounds.
Args:
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
height: Image height.
width: Image width.
Returns:
Sanitized bounding box in XYXY format, or None if invalid.
"""
x1, y1, x2, y2 = bbox
x1c = max(0, min(int(x1), width - 1))
y1c = max(0, min(int(y1), height - 1))
x2c = max(0, min(int(x2), width))
y2c = max(0, min(int(y2), height))
if x2c <= x1c or y2c <= y1c:
return None
return (x1c, y1c, x2c, y2c)
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
"""Extract person mask and bounding box from detection result.
Args:
result: Detection results object with boxes and masks attributes.
min_area: Minimum mask area to consider valid.
Returns:
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
or None if no valid detections.
"""
masks_obj = _read_attr(result, "masks")
if masks_obj is None:
return None
masks_data_obj = _read_attr(masks_obj, "data")
if masks_data_obj is None:
return None
masks_raw = _to_numpy_array(masks_data_obj)
masks_float = np.asarray(masks_raw, dtype=np.float32)
if masks_float.ndim == 2:
masks_float = masks_float[np.newaxis, ...]
if masks_float.ndim != 3:
return None
mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
if mask_count <= 0:
return None
box_values: list[tuple[float, float, float, float]] | None = None
boxes_obj = _read_attr(result, "boxes")
if boxes_obj is not None:
xyxy_obj = _read_attr(boxes_obj, "xyxy")
if xyxy_obj is not None:
xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
x1f = cast(np.float64, xyxy_2d[0, 0])
y1f = cast(np.float64, xyxy_2d[0, 1])
x2f = cast(np.float64, xyxy_2d[0, 2])
y2f = cast(np.float64, xyxy_2d[0, 3])
box_values = [
(
float(x1f),
float(y1f),
float(x2f),
float(y2f),
)
]
elif xyxy_raw.ndim == 2:
shape_2d = cast(tuple[int, int], xyxy_raw.shape)
if int(shape_2d[1]) >= 4:
xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
box_values = []
for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
x1f = cast(np.float64, xyxy_2d[row_idx, 0])
y1f = cast(np.float64, xyxy_2d[row_idx, 1])
x2f = cast(np.float64, xyxy_2d[row_idx, 2])
y2f = cast(np.float64, xyxy_2d[row_idx, 3])
box_values.append(
(
float(x1f),
float(y1f),
float(x2f),
float(y2f),
)
)
best_area = -1
best_mask: UInt8[ndarray, "h w"] | None = None
best_bbox: BBoxXYXY | None = None
for idx in range(mask_count):
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
if mask_float.ndim != 2:
continue
mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)
area = int(np.count_nonzero(mask_u8))
if area < min_area:
continue
bbox: BBoxXYXY | None = None
shape_2d = cast(tuple[int, int], mask_binary.shape)
h = int(shape_2d[0])
w = int(shape_2d[1])
if box_values is not None:
box_count = len(box_values)
if idx >= box_count:
continue
row0, row1, row2, row3 = box_values[idx]
bbox_candidate = (
int(math.floor(row0)),
int(math.floor(row1)),
int(math.ceil(row2)),
int(math.ceil(row3)),
)
bbox = _sanitize_bbox(bbox_candidate, h, w)
if bbox is None:
bbox = _bbox_from_mask(mask_u8)
if bbox is None:
continue
if area > best_area:
best_area = area
best_mask = mask_u8
best_bbox = bbox
if best_mask is None or best_bbox is None:
return None
return best_mask, best_bbox
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
mask: UInt8[ndarray, "h w"],
bbox: BBoxXYXY,
) -> Float[ndarray, "64 44"] | None:
"""Convert mask to standardized silhouette using bounding box.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
Returns:
Standardized silhouette array of shape (64, 44) with dtype float32,
or None if conversion fails.
"""
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
mask_u8 = _fill_binary_holes(mask_u8)
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
return None
mask_shape = cast(tuple[int, int], mask_u8.shape)
h = int(mask_shape[0])
w = int(mask_shape[1])
bbox_sanitized = _sanitize_bbox(bbox, h, w)
if bbox_sanitized is None:
return None
x1, y1, x2, y2 = bbox_sanitized
cropped = mask_u8[y1:y2, x1:x2]
if cropped.size == 0:
return None
cropped_u8 = np.asarray(cropped, dtype=np.uint8)
row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64)
row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64)
if int(row_nonzero.size) == 0:
return None
top = int(cast(np.int64, row_nonzero[0]))
bottom = int(cast(np.int64, row_nonzero[-1])) + 1
tightened = cropped[top:bottom, :]
if tightened.size == 0:
return None
tight_shape = cast(tuple[int, int], tightened.shape)
tight_h = int(tight_shape[0])
tight_w = int(tight_shape[1])
if tight_h <= 0 or tight_w <= 0:
return None
resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
resized = np.asarray(
cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
dtype=np.uint8,
)
if resized_w >= SIL_FULL_WIDTH:
start = (resized_w - SIL_FULL_WIDTH) // 2
normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
else:
pad_left = (SIL_FULL_WIDTH - resized_w) // 2
pad_right = SIL_FULL_WIDTH - resized_w - pad_left
normalized_64 = np.pad(
resized,
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0,
)
silhouette = np.asarray(
normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
)
if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
return None
silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
np.float32
)
return cast(Float[ndarray, "64 44"], silhouette_norm)
-317
View File
@@ -1,317 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override
import torch
import torch.nn as nn
from beartype import beartype
from einops import rearrange
from jaxtyping import Float
import jaxtyping
from torch import Tensor
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
HorizontalPoolingPyramid,
PackSequenceWrapper as TemporalPool,
SeparateBNNecks,
SeparateFCs,
)
from opengait.utils import common as common_utils
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
class TemporalPoolLike(Protocol):
def __call__(
self,
seqs: Tensor,
seqL: object,
dim: int = 2,
options: dict[str, int] | None = None,
) -> object: ...
class HppLike(Protocol):
def __call__(self, x: Tensor) -> Tensor: ...
class FCsLike(Protocol):
def __call__(self, x: Tensor) -> Tensor: ...
class BNNecksLike(Protocol):
def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
class ScoNetDemo(nn.Module):
LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}
cfg_path: str
cfg: dict[str, object]
backbone: ResNet9
temporal_pool: TemporalPoolLike
hpp: HppLike
fcs: FCsLike
bn_necks: BNNecksLike
device: torch.device
@jaxtyped(typechecker=beartype)
def __init__(
self,
cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
checkpoint_path: str | Path | None = None,
device: str | torch.device | None = None,
) -> None:
super().__init__()
resolved_cfg = self._resolve_path(cfg_path)
self.cfg_path = str(resolved_cfg)
self.cfg = config_loader(self.cfg_path)
model_cfg = self._extract_model_cfg(self.cfg)
backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")
if backbone_cfg.get("type") != "ResNet9":
raise ValueError(
"ScoNetDemo currently supports backbone type ResNet9 only."
)
self.backbone = ResNet9(
block=self._extract_str(backbone_cfg, "block"),
channels=self._extract_int_list(backbone_cfg, "channels"),
in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
layers=self._extract_int_list(backbone_cfg, "layers"),
strides=self._extract_int_list(backbone_cfg, "strides"),
maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
)
fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
bin_num = self._extract_int_list(model_cfg, "bin_num")
self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
self.fcs = cast(
FCsLike,
SeparateFCs(
parts_num=self._extract_int(fcs_cfg, "parts_num"),
in_channels=self._extract_int(fcs_cfg, "in_channels"),
out_channels=self._extract_int(fcs_cfg, "out_channels"),
norm=self._extract_bool(fcs_cfg, "norm", default=False),
),
)
self.bn_necks = cast(
BNNecksLike,
SeparateBNNecks(
parts_num=self._extract_int(bn_cfg, "parts_num"),
in_channels=self._extract_int(bn_cfg, "in_channels"),
class_num=self._extract_int(bn_cfg, "class_num"),
norm=self._extract_bool(bn_cfg, "norm", default=True),
parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
),
)
self.device = (
torch.device(device) if device is not None else torch.device("cpu")
)
_ = self.to(self.device)
if checkpoint_path is not None:
_ = self.load_checkpoint(checkpoint_path)
_ = self.eval()
@staticmethod
def _resolve_path(path: str | Path) -> Path:
candidate = Path(path)
if candidate.is_file():
return candidate
if candidate.is_absolute():
return candidate
repo_root = Path(__file__).resolve().parents[2]
return repo_root / candidate
@staticmethod
def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
model_cfg_obj = cfg.get("model_cfg")
if not isinstance(model_cfg_obj, dict):
raise TypeError("model_cfg must be a dictionary.")
return cast(dict[str, object], model_cfg_obj)
@staticmethod
def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
value = container.get(key)
if not isinstance(value, dict):
raise TypeError(f"{key} must be a dictionary.")
return cast(dict[str, object], value)
@staticmethod
def _extract_str(container: dict[str, object], key: str) -> str:
value = container.get(key)
if not isinstance(value, str):
raise TypeError(f"{key} must be a string.")
return value
@staticmethod
def _extract_int(
container: dict[str, object], key: str, default: int | None = None
) -> int:
value = container.get(key, default)
if not isinstance(value, int):
raise TypeError(f"{key} must be an int.")
return value
@staticmethod
def _extract_bool(
container: dict[str, object], key: str, default: bool | None = None
) -> bool:
value = container.get(key, default)
if not isinstance(value, bool):
raise TypeError(f"{key} must be a bool.")
return value
@staticmethod
def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
value = container.get(key)
if not isinstance(value, list):
raise TypeError(f"{key} must be a list[int].")
values = cast(list[object], value)
if not all(isinstance(v, int) for v in values):
raise TypeError(f"{key} must be a list[int].")
return cast(list[int], values)
@staticmethod
def _normalize_state_dict(
state_dict_obj: dict[object, object],
) -> dict[str, Tensor]:
prefix_remap: tuple[tuple[str, str], ...] = (
("Backbone.forward_block.", "backbone."),
("FCs.", "fcs."),
("BNNecks.", "bn_necks."),
)
cleaned_state_dict: dict[str, Tensor] = {}
for key_obj, value_obj in state_dict_obj.items():
if not isinstance(key_obj, str):
raise TypeError("Checkpoint state_dict keys must be strings.")
if not isinstance(value_obj, Tensor):
raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
key = key_obj[7:] if key_obj.startswith("module.") else key_obj
for source_prefix, target_prefix in prefix_remap:
if key.startswith(source_prefix):
key = f"{target_prefix}{key[len(source_prefix) :]}"
break
if key in cleaned_state_dict:
raise RuntimeError(
f"Checkpoint key normalization collision detected for key '{key}'."
)
cleaned_state_dict[key] = value_obj
return cleaned_state_dict
@jaxtyped(typechecker=beartype)
def load_checkpoint(
self,
checkpoint_path: str | Path,
map_location: str | torch.device | None = None,
strict: bool = True,
) -> None:
resolved_ckpt = self._resolve_path(checkpoint_path)
checkpoint_obj = cast(
object,
torch.load(
str(resolved_ckpt),
map_location=map_location if map_location is not None else self.device,
),
)
state_dict_obj: object = checkpoint_obj
if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]
if not isinstance(state_dict_obj, dict):
raise TypeError("Unsupported checkpoint format.")
cleaned_state_dict = self._normalize_state_dict(
cast(dict[object, object], state_dict_obj)
)
try:
_ = self.load_state_dict(cleaned_state_dict, strict=strict)
except RuntimeError as exc:
raise RuntimeError(
f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
) from exc
_ = self.eval()
def _prepare_sils(self, sils: Tensor) -> Tensor:
if sils.ndim == 4:
sils = sils.unsqueeze(1)
elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
sils = rearrange(sils, "b s c h w -> b c s h w")
if sils.ndim != 5 or sils.shape[1] != 1:
raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")
return sils.float().to(self.device)
def _forward_backbone(self, sils: Tensor) -> Tensor:
batch, channels, seq, height, width = sils.shape
framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
frame_feats = cast(Tensor, self.backbone(framewise))
_, out_channels, out_h, out_w = frame_feats.shape
return (
frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
.transpose(1, 2)
.contiguous()
)
@override
@jaxtyped(typechecker=beartype)
def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
with torch.inference_mode():
prepared_sils = self._prepare_sils(sils)
outs = self._forward_backbone(prepared_sils)
pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
if (
not isinstance(pooled_obj, tuple)
or not pooled_obj
or not isinstance(pooled_obj[0], Tensor)
):
raise TypeError("TemporalPool output is invalid.")
pooled = pooled_obj[0]
feat = self.hpp(pooled)
embed_1 = self.fcs(feat)
_, logits = self.bn_necks(embed_1)
mean_logits = logits.mean(dim=-1)
pred_ids = torch.argmax(mean_logits, dim=-1)
probs = torch.softmax(mean_logits, dim=-1)
confidence = torch.gather(
probs, dim=-1, index=pred_ids.unsqueeze(-1)
).squeeze(-1)
return {"logits": logits, "label": pred_ids, "confidence": confidence}
@jaxtyped(typechecker=beartype)
def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
outputs = cast(dict[str, Tensor], self.forward(sils))
labels = outputs["label"]
confidence = outputs["confidence"]
if labels.numel() != 1:
raise ValueError("predict expects batch size 1.")
label_id = int(labels.item())
return self.LABEL_MAP[label_id], float(confidence.item())
-587
View File
@@ -1,587 +0,0 @@
"""OpenCV-based visualizer for demo pipeline.
Provides real-time visualization of detection, segmentation, and classification results
with interactive mode switching for mask display.
"""
from __future__ import annotations
import logging
from typing import cast
import cv2
import numpy as np
from numpy.typing import NDArray
from .preprocess import BBoxXYXY
logger = logging.getLogger(__name__)
# Window names
MAIN_WINDOW = "Scoliosis Detection"
SEG_WINDOW = "Normalized Silhouette"
RAW_WINDOW = "Raw Mask"
WINDOW_SEG_INPUT = "Segmentation Input"
# Silhouette dimensions (from preprocess.py)
SIL_HEIGHT = 64
SIL_WIDTH = 44
# Display dimensions for upscaled silhouette
DISPLAY_HEIGHT = 256
DISPLAY_WIDTH = 176
RAW_STATS_PAD = 54
MODE_LABEL_PAD = 26
# Colors (BGR)
COLOR_GREEN = (0, 255, 0)
COLOR_WHITE = (255, 255, 255)
COLOR_BLACK = (0, 0, 0)
COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255)
# Type alias for image arrays (NDArray or cv2.Mat)
ImageArray = NDArray[np.uint8]
class OpenCVVisualizer:
def __init__(self) -> None:
self.show_raw_window: bool = False
self.show_raw_debug: bool = False
self._windows_created: bool = False
self._raw_window_created: bool = False
def _ensure_windows(self) -> None:
if not self._windows_created:
cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
self._windows_created = True
def _ensure_raw_window(self) -> None:
if not self._raw_window_created:
cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
self._raw_window_created = True
def _hide_raw_window(self) -> None:
if self._raw_window_created:
cv2.destroyWindow(RAW_WINDOW)
self._raw_window_created = False
def _draw_bbox(
self,
frame: ImageArray,
bbox: BBoxXYXY | None,
) -> None:
"""Draw bounding box on frame if present.
Args:
frame: Input frame (H, W, 3) uint8 - modified in place
bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
"""
if bbox is None:
return
x1, y1, x2, y2 = bbox
# Draw rectangle with green color, thickness 2
_ = cv2.rectangle(frame, (x1, y1), (x2, y2), COLOR_GREEN, 2)
def _draw_text_overlay(
self,
frame: ImageArray,
track_id: int,
fps: float,
label: str | None,
confidence: float | None,
) -> None:
"""Draw text overlay with track info, FPS, label, and confidence.
Args:
frame: Input frame (H, W, 3) uint8 - modified in place
track_id: Tracking ID
fps: Current FPS
label: Classification label or None
confidence: Classification confidence or None
"""
# Prepare text lines
lines: list[str] = []
lines.append(f"ID: {track_id}")
lines.append(f"FPS: {fps:.1f}")
if label is not None:
if confidence is not None:
lines.append(f"{label}: {confidence:.2%}")
else:
lines.append(label)
# Draw text with background for readability
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
thickness = 1
line_height = 25
margin = 10
for i, text in enumerate(lines):
y_pos = margin + (i + 1) * line_height
# Draw background rectangle
(text_width, text_height), _ = cv2.getTextSize(
text, font, font_scale, thickness
)
_ = cv2.rectangle(
frame,
(margin, y_pos - text_height - 5),
(margin + text_width + 10, y_pos + 5),
COLOR_BLACK,
-1,
)
# Draw text
_ = cv2.putText(
frame,
text,
(margin + 5, y_pos),
font,
font_scale,
COLOR_WHITE,
thickness,
)
def _prepare_main_frame(
self,
frame: ImageArray,
bbox: BBoxXYXY | None,
track_id: int,
fps: float,
label: str | None,
confidence: float | None,
) -> ImageArray:
"""Prepare main display frame with bbox and text overlay.
Args:
frame: Input frame (H, W, C) uint8
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID
fps: Current FPS
label: Classification label or None
confidence: Classification confidence or None
Returns:
Processed frame ready for display
"""
# Ensure BGR format (convert grayscale if needed)
if len(frame.shape) == 2:
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
elif frame.shape[2] == 1:
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
elif frame.shape[2] == 3:
display_frame = frame.copy()
elif frame.shape[2] == 4:
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR))
else:
display_frame = frame.copy()
# Draw bbox and text (modifies in place)
self._draw_bbox(display_frame, bbox)
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
return display_frame
def _upscale_silhouette(
self,
silhouette: NDArray[np.float32] | NDArray[np.uint8],
) -> ImageArray:
"""Upscale silhouette to display size.
Args:
silhouette: Input silhouette (64, 44) float32 [0,1] or uint8 [0,255]
Returns:
Upscaled silhouette (256, 176) uint8
"""
# Normalize to uint8 if needed
if silhouette.dtype == np.float32 or silhouette.dtype == np.float64:
sil_u8 = (silhouette * 255).astype(np.uint8)
else:
sil_u8 = silhouette.astype(np.uint8)
# Upscale using nearest neighbor to preserve pixelation
upscaled = cast(
ImageArray,
cv2.resize(
sil_u8,
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
interpolation=cv2.INTER_NEAREST,
),
)
return upscaled
def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
mask_array = np.asarray(mask)
if mask_array.dtype == np.bool_:
bool_scaled = np.where(mask_array, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
return cast(ImageArray, bool_scaled)
if mask_array.dtype == np.uint8:
mask_array = cast(ImageArray, mask_array)
max_u8 = int(np.max(mask_array)) if mask_array.size > 0 else 0
if max_u8 <= 1:
scaled_u8 = np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
return cast(ImageArray, scaled_u8)
return cast(ImageArray, mask_array)
if np.issubdtype(mask_array.dtype, np.integer):
max_int = float(np.max(mask_array)) if mask_array.size > 0 else 0.0
if max_int <= 1.0:
return cast(
ImageArray, (mask_array.astype(np.float32) * 255.0).astype(np.uint8)
)
clipped = np.clip(mask_array, 0, 255).astype(np.uint8)
return cast(ImageArray, clipped)
mask_float = np.asarray(mask_array, dtype=np.float32)
max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
if max_val <= 0.0:
return np.zeros(mask_float.shape, dtype=np.uint8)
normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0).astype(
np.uint8
)
return cast(ImageArray, normalized)
def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
if mask_raw is None:
return
mask = np.asarray(mask_raw)
if mask.size == 0:
return
stats = [
f"raw: {mask.dtype}",
f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
f"nnz: {int(np.count_nonzero(mask))}",
]
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.45
thickness = 1
line_h = 18
x0 = 8
y0 = 20
for i, txt in enumerate(stats):
y = y0 + i * line_h
(tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
_ = cv2.rectangle(
image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
)
_ = cv2.putText(
image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
)
def _prepare_segmentation_view(
self,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
bbox: BBoxXYXY | None,
) -> ImageArray:
_ = mask_raw
_ = bbox
return self._prepare_normalized_view(silhouette)
def _fit_gray_to_display(
self,
gray: ImageArray,
out_h: int = DISPLAY_HEIGHT,
out_w: int = DISPLAY_WIDTH,
) -> ImageArray:
src_h, src_w = gray.shape[:2]
if src_h <= 0 or src_w <= 0:
return np.zeros((out_h, out_w), dtype=np.uint8)
scale = min(out_w / src_w, out_h / src_h)
new_w = max(1, int(round(src_w * scale)))
new_h = max(1, int(round(src_h * scale)))
resized = cast(
ImageArray,
cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
)
canvas = np.zeros((out_h, out_w), dtype=np.uint8)
x0 = (out_w - new_w) // 2
y0 = (out_h - new_h) // 2
canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
return cast(ImageArray, canvas)
def _crop_mask_to_bbox(
self,
mask_gray: ImageArray,
bbox: BBoxXYXY | None,
) -> ImageArray:
if bbox is None:
return mask_gray
h, w = mask_gray.shape[:2]
x1, y1, x2, y2 = bbox
x1c = max(0, min(w, int(x1)))
x2c = max(0, min(w, int(x2)))
y1c = max(0, min(h, int(y1)))
y2c = max(0, min(h, int(y2)))
if x2c <= x1c or y2c <= y1c:
return mask_gray
cropped = mask_gray[y1c:y2c, x1c:x2c]
if cropped.size == 0:
return mask_gray
return cast(ImageArray, cropped)
def _prepare_segmentation_input_view(
self,
silhouettes: NDArray[np.float32] | None,
) -> ImageArray:
if silhouettes is None or silhouettes.size == 0:
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
return placeholder
n_frames = int(silhouettes.shape[0])
tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
rows = int(np.ceil(n_frames / tiles_per_row))
tile_h = DISPLAY_HEIGHT
tile_w = DISPLAY_WIDTH
grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
for idx in range(n_frames):
sil = silhouettes[idx]
tile = self._upscale_silhouette(sil)
r = idx // tiles_per_row
c = idx % tiles_per_row
y0, y1 = r * tile_h, (r + 1) * tile_h
x0, x1 = c * tile_w, (c + 1) * tile_w
grid[y0:y1, x0:x1] = tile
grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
for idx in range(n_frames):
r = idx // tiles_per_row
c = idx % tiles_per_row
y0 = r * tile_h
x0 = c * tile_w
cv2.putText(
grid_bgr,
str(idx),
(x0 + 8, y0 + 22),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(0, 255, 255),
2,
cv2.LINE_AA,
)
return grid_bgr
def _prepare_raw_view(
self,
mask_raw: ImageArray | None,
bbox: BBoxXYXY | None = None,
) -> ImageArray:
"""Prepare raw mask view.
Args:
mask_raw: Raw binary mask or None
Returns:
Displayable image with mode indicator
"""
if mask_raw is None:
# Create placeholder
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
self._draw_mode_indicator(placeholder, "Raw Mask (No Data)")
return placeholder
# Ensure single channel
if len(mask_raw.shape) == 3:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = cast(ImageArray, mask_raw)
mask_gray = self._normalize_mask_for_display(mask_gray)
mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)
debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
mask_resized = self._fit_gray_to_display(
mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
)
full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
full_mask[debug_pad : debug_pad + content_h, :] = mask_resized
# Convert to BGR for display
mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
if self.show_raw_debug:
self._draw_raw_stats(mask_bgr, mask_raw)
self._draw_mode_indicator(mask_bgr, "Raw Mask")
return mask_bgr
def _prepare_normalized_view(
self,
silhouette: NDArray[np.float32] | None,
) -> ImageArray:
"""Prepare normalized silhouette view.
Args:
silhouette: Normalized silhouette (64, 44) or None
Returns:
Displayable image with mode indicator
"""
if silhouette is None:
# Create placeholder
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
self._draw_mode_indicator(placeholder, "Normalized (No Data)")
return placeholder
# Upscale and convert
upscaled = self._upscale_silhouette(silhouette)
content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
sil_compact = self._fit_gray_to_display(
upscaled, out_h=content_h, out_w=DISPLAY_WIDTH
)
sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
sil_canvas[:content_h, :] = sil_compact
sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
self._draw_mode_indicator(sil_bgr, "Normalized")
return sil_bgr
def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
h, w = image.shape[:2]
mode_text = label
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
thickness = 1
# Get text size for background
(text_width, text_height), _ = cv2.getTextSize(
mode_text, font, font_scale, thickness
)
x_pos = 14
y_pos = h - 8
y_top = max(0, h - MODE_LABEL_PAD)
_ = cv2.rectangle(
image,
(0, y_top),
(w, h),
COLOR_DARK_GRAY,
-1,
)
_ = cv2.rectangle(
image,
(x_pos - 6, y_pos - text_height - 6),
(x_pos + text_width + 8, y_pos + 6),
COLOR_DARK_GRAY,
-1,
)
# Draw text
_ = cv2.putText(
image,
mode_text,
(x_pos, y_pos),
font,
font_scale,
COLOR_YELLOW,
thickness,
)
def update(
self,
frame: ImageArray,
bbox: BBoxXYXY | None,
bbox_mask: BBoxXYXY | None,
track_id: int,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
segmentation_input: NDArray[np.float32] | None,
label: str | None,
confidence: float | None,
fps: float,
) -> bool:
"""Update visualization with new frame data.
Args:
frame: Input frame (H, W, C) uint8
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID
mask_raw: Raw binary mask (H, W) uint8 or None
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
label: Classification label or None
confidence: Classification confidence [0,1] or None
fps: Current FPS
Returns:
False if user requested quit (pressed 'q'), True otherwise
"""
self._ensure_windows()
# Prepare and show main window
main_display = self._prepare_main_frame(
frame, bbox, track_id, fps, label, confidence
)
cv2.imshow(MAIN_WINDOW, main_display)
# Prepare and show segmentation window
seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
cv2.imshow(SEG_WINDOW, seg_display)
if self.show_raw_window:
self._ensure_raw_window()
raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
cv2.imshow(RAW_WINDOW, raw_display)
seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)
# Handle keyboard input
key = cv2.waitKey(1) & 0xFF
if key == ord("q"):
return False
elif key == ord("r"):
self.show_raw_window = not self.show_raw_window
if self.show_raw_window:
self._ensure_raw_window()
logger.debug("Raw mask window enabled")
else:
self._hide_raw_window()
logger.debug("Raw mask window disabled")
elif key == ord("d"):
self.show_raw_debug = not self.show_raw_debug
logger.debug(
"Raw mask debug overlay %s",
"enabled" if self.show_raw_debug else "disabled",
)
return True
def close(self) -> None:
if self._windows_created:
self._hide_raw_window()
cv2.destroyAllWindows()
self._windows_created = False
self._raw_window_created = False
-375
View File
@@ -1,375 +0,0 @@
"""Sliding window / ring buffer manager for real-time gait analysis.
This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
from collections import deque
from typing import TYPE_CHECKING, Protocol, cast, final
import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray
from .preprocess import BBoxXYXY
if TYPE_CHECKING:
from numpy.typing import NDArray
# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
# Type alias for array-like inputs
type _ArrayLike = torch.Tensor | ndarray
class _Boxes(Protocol):
"""Protocol for boxes with xyxy and id attributes."""
@property
def xyxy(self) -> "NDArray[np.float32] | object": ...
@property
def id(self) -> "NDArray[np.int64] | object | None": ...
class _Masks(Protocol):
"""Protocol for masks with data attribute."""
@property
def data(self) -> "NDArray[np.float32] | object": ...
class _DetectionResults(Protocol):
"""Protocol for detection results from Ultralytics-style objects."""
@property
def boxes(self) -> _Boxes: ...
@property
def masks(self) -> _Masks: ...
@final
class SilhouetteWindow:
"""Bounded sliding window for silhouette sequences.
Manages a fixed-size buffer of silhouettes with track ID tracking
and automatic reset on track changes or frame gaps.
Attributes:
window_size: Maximum number of frames in the buffer.
stride: Classification stride (frames between classifications).
gap_threshold: Maximum allowed frame gap before reset.
"""
window_size: int
stride: int
gap_threshold: int
_buffer: deque[Float[ndarray, "64 44"]]
_frame_indices: deque[int]
_track_id: int | None
_last_classified_frame: int
_frame_count: int
def __init__(
self,
window_size: int = 30,
stride: int = 1,
gap_threshold: int = 15,
) -> None:
"""Initialize the silhouette window.
Args:
window_size: Maximum buffer size (default 30).
stride: Frames between classifications (default 1).
gap_threshold: Max frame gap before reset (default 15).
"""
self.window_size = window_size
self.stride = stride
self.gap_threshold = gap_threshold
# Bounded storage via deque
self._buffer = deque(maxlen=window_size)
self._frame_indices = deque(maxlen=window_size)
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
"""Push a new silhouette into the window.
Automatically resets buffer on track ID change or frame gap
exceeding gap_threshold.
Args:
sil: Silhouette array of shape (64, 44), float32.
frame_idx: Current frame index for gap detection.
track_id: Track ID for the person.
"""
# Check for track ID change
if self._track_id is not None and track_id != self._track_id:
self.reset()
# Check for frame gap
if self._frame_indices:
last_frame = self._frame_indices[-1]
gap = frame_idx - last_frame
if gap > self.gap_threshold:
self.reset()
# Update track ID
self._track_id = track_id
# Validate and append silhouette
sil_array = np.asarray(sil, dtype=np.float32)
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
raise ValueError(
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
)
self._buffer.append(sil_array)
self._frame_indices.append(frame_idx)
self._frame_count += 1
def is_ready(self) -> bool:
"""Check if window has enough frames for classification.
Returns:
True if buffer is full (window_size frames).
"""
return len(self._buffer) >= self.window_size
def should_classify(self) -> bool:
"""Check if classification should run based on stride.
Returns:
True if enough frames have passed since last classification.
"""
if not self.is_ready():
return False
if self._last_classified_frame < 0:
return True
current_frame = self._frame_indices[-1]
frames_since = current_frame - self._last_classified_frame
return frames_since >= self.stride
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
"""Get window contents as a tensor for model input.
Args:
device: Target device for the tensor (default 'cpu').
Returns:
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
Raises:
ValueError: If buffer is not full.
"""
if not self.is_ready():
raise ValueError(
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
)
# Stack buffer into array [window_size, 64, 44]
stacked = np.stack(list(self._buffer), axis=0)
# Add batch and channel dims: [1, 1, window_size, 64, 44]
tensor = torch.from_numpy(stacked.astype(np.float32))
tensor = tensor.unsqueeze(0).unsqueeze(0)
return tensor.to(device)
def reset(self) -> None:
"""Reset the window, clearing all buffers and counters."""
self._buffer.clear()
self._frame_indices.clear()
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def mark_classified(self) -> None:
"""Mark current frame as classified, updating stride tracking."""
if self._frame_indices:
self._last_classified_frame = self._frame_indices[-1]
@property
def current_track_id(self) -> int | None:
"""Current track ID, or None if buffer is empty."""
return self._track_id
@property
def frame_count(self) -> int:
"""Total frames pushed since last reset."""
return self._frame_count
@property
def fill_level(self) -> float:
"""Fill ratio of the buffer (0.0 to 1.0)."""
return len(self._buffer) / self.window_size
@property
def window_start_frame(self) -> int:
if not self._frame_indices:
raise ValueError("Window is empty")
return int(self._frame_indices[0])
@property
def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
if not self._buffer:
return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
return cast(
Float[ndarray, "n 64 44"],
np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
)
def _to_numpy(obj: _ArrayLike) -> ndarray:
"""Safely convert array-like object to numpy array.
Handles torch tensors (CPU or CUDA) by detaching and moving to CPU first.
Falls back to np.asarray for other array-like objects.
Args:
obj: Array-like object (numpy array, torch tensor, or similar).
Returns:
Numpy array representation of the input.
"""
# Handle torch tensors (including CUDA tensors)
detach_fn = getattr(obj, "detach", None)
if detach_fn is not None and callable(detach_fn):
# It's a torch tensor
tensor = detach_fn()
cpu_fn = getattr(tensor, "cpu", None)
if cpu_fn is not None and callable(cpu_fn):
tensor = cpu_fn()
numpy_fn = getattr(tensor, "numpy", None)
if numpy_fn is not None and callable(numpy_fn):
return cast(ndarray, numpy_fn())
# Fall back to np.asarray for other array-like objects
return cast(ndarray, np.asarray(obj))
def select_person(
results: _DetectionResults,
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
"""Select the person with largest bounding box from detection results.
Args:
results: Detection results object with boxes and masks attributes.
Expected to have:
- boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
- masks.data: array of masks [N, H, W] in mask coordinates
- boxes.id: optional track IDs [N]
Returns:
Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
or None if no valid detections or track IDs unavailable.
- mask: the person's segmentation mask
- bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
- bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
- track_id: the person's track ID
"""
# Check for track IDs
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
if boxes_obj is None:
return None
track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
if track_ids_obj is None:
return None
track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj))
if track_ids.size == 0:
return None
# Get bounding boxes
xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
if xyxy_obj is None:
return None
bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
if bboxes.ndim == 1:
bboxes = bboxes.reshape(1, -1)
if bboxes.shape[0] == 0:
return None
# Get masks
masks_obj: _Masks | object = getattr(results, "masks", None)
if masks_obj is None:
return None
masks_data: ndarray | object = getattr(masks_obj, "data", None)
if masks_data is None:
return None
masks: ndarray = _to_numpy(cast(ndarray, masks_data))
if masks.ndim == 2:
masks = masks[np.newaxis, ...]
if masks.shape[0] != bboxes.shape[0]:
return None
# Find largest bbox by area
best_idx: int = -1
best_area: float = -1.0
for i in range(int(bboxes.shape[0])):
row: "NDArray[np.float32]" = bboxes[i][:4]
x1f: float = float(row[0])
y1f: float = float(row[1])
x2f: float = float(row[2])
y2f: float = float(row[3])
area: float = (x2f - x1f) * (y2f - y1f)
if area > best_area:
best_area = area
best_idx = i
if best_idx < 0:
return None
# Extract mask and bbox
mask: "NDArray[np.float32]" = masks[best_idx]
mask_shape = mask.shape
mask_h, mask_w = int(mask_shape[0]), int(mask_shape[1])
# Get original image dimensions from results (YOLO provides this)
orig_shape = getattr(results, "orig_shape", None)
# Validate orig_shape is a sequence of at least 2 numeric values
if (
orig_shape is not None
and isinstance(orig_shape, (tuple, list))
and len(orig_shape) >= 2
):
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
# Scale bbox from frame space to mask space
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
bbox_mask = (
int(float(bboxes[best_idx][0]) * scale_x),
int(float(bboxes[best_idx][1]) * scale_y),
int(float(bboxes[best_idx][2]) * scale_x),
int(float(bboxes[best_idx][3]) * scale_y),
)
bbox_frame = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
else:
# Fallback: use bbox as-is for both (assume same coordinate space)
bbox_mask = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
bbox_frame = bbox_mask
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
return mask, bbox_mask, bbox_frame, track_id
+1 -1
View File
@@ -1,7 +1,7 @@
import os
from time import strftime, localtime
import numpy as np
from utils import get_msg_mgr, mkdir
from opengait.utils import get_msg_mgr, mkdir
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
from .re_rank import re_ranking
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np
import torch.nn.functional as F
from utils import is_tensor
from opengait.utils import is_tensor
def cuda_dist(x, y, metric='euc'):
+1 -1
View File
@@ -4,7 +4,7 @@ import argparse
import torch
import torch.nn as nn
from modeling import models
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument('--local_rank', type=int, default=0,
+4 -4
View File
@@ -28,11 +28,11 @@ from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import Odict, mkdir, ddp_all_gather
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from utils import NoOp
from utils import get_msg_mgr
from opengait.utils import NoOp
from opengait.utils import get_msg_mgr
__all__ = ['BaseModel']
+3 -3
View File
@@ -3,9 +3,9 @@
import torch
import torch.nn as nn
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr
from opengait.utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from opengait.utils import Odict
from opengait.utils import get_msg_mgr
class LossAggregator(nn.Module):
+2 -2
View File
@@ -1,9 +1,9 @@
from ctypes import ArgumentError
import torch.nn as nn
import torch
from utils import Odict
from opengait.utils import Odict
import functools
from utils import ddp_all_gather
from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(func):
@@ -132,7 +132,7 @@ class Post_ResNet9(ResNet):
return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones
class Baseline(nn.Module):
def __init__(self, model_cfg):
+1 -1
View File
@@ -4,7 +4,7 @@ from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, FlowFunc
import torch.optim as optim
from einops import rearrange
from utils import get_valid_args
from opengait.utils import get_valid_args
import warnings
import random
from torchvision.utils import flow_to_image
@@ -161,7 +161,7 @@ class Post_ResNet9(ResNet):
return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones
class GaitBaseFusion_denoise(nn.Module):
def __init__(self, model_cfg):
+1 -1
View File
@@ -6,7 +6,7 @@ from ..base_model import BaseModel
from .gaitgl import GaitGL
from ..modules import GaitAlign
from torchvision.transforms import Resize
from utils import get_valid_args, get_attr_from, is_list_or_tuple
from opengait.utils import get_valid_args, get_attr_from, is_list_or_tuple
import os.path as osp
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
from utils import clones
from opengait.utils import clones
class BasicConv1d(nn.Module):
+2 -2
View File
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
from utils import np2var, list2var, get_valid_args, ddp_all_gather
from opengait.utils import np2var, list2var, get_valid_args, ddp_all_gather
from data.transform import get_transform
from einops import rearrange
@@ -143,7 +143,7 @@ class GaitSSB_Pretrain(BaseModel):
import torch.optim as optim
import numpy as np
from utils import get_valid_args, list2var
from opengait.utils import get_valid_args, list2var
class no_grad(torch.no_grad):
def __init__(self, enable=True):
+1 -1
View File
@@ -782,7 +782,7 @@ from ..modules import BasicBlock2D, BasicBlockP3D
import torch.optim as optim
import os.path as osp
from collections import OrderedDict
from utils import get_valid_args, get_attr_from
from opengait.utils import get_valid_args, get_attr_from
class SwinGait(BaseModel):
def __init__(self, cfgs, training):
+1 -1
View File
@@ -2,7 +2,7 @@ import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
from opengait.utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign