diff --git a/README.md b/README.md index 877639e..69ab5a8 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ References: ```bash uv sync --group dev --group detection +uv run patch-mmdet-version-gate uv run pose-tracking-exp run_detection --config detection.toml camera0 camera1 uv run pose-tracking-exp run_detection --source video --output-dir data/detections --config detection.toml cam0=/data/cam0.mp4 cam1=/data/cam1.mp4 ``` @@ -58,6 +59,7 @@ The default backend is `yolo_rtmpose`, and the heavy runtime dependencies live i Checkpoint paths are explicit config fields; the code does not hardcode local checkpoint locations. The only inferred path is the MMPose config path, which is resolved relative to the installed `mmpose` package when `pose_config_path` is omitted. For offline video runs, the default sink is parquet and writes one `*_detected.parquet` file per source. `run_tracking` can consume that directory directly as replay input. +`uv run patch-mmdet-version-gate` is an idempotent local-environment patch for the current `mmdet` compatibility assert against the rebuilt `mmcv` wheel. Re-run it after `uv sync` if the environment is recreated. Example `detection.toml`: diff --git a/pyproject.toml b/pyproject.toml index 09ccd3c..193be60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ detection = [ [project.scripts] pose-tracking-exp = "pose_tracking_exp.cli:main" +patch-mmdet-version-gate = "pose_tracking_exp.detection.patch_mmdet:main" [tool.hatch.build.targets.wheel] packages = ["src/pose_tracking_exp"] diff --git a/src/pose_tracking_exp/detection/__init__.py b/src/pose_tracking_exp/detection/__init__.py index 0927a60..4d17de1 100644 --- a/src/pose_tracking_exp/detection/__init__.py +++ b/src/pose_tracking_exp/detection/__init__.py @@ -6,6 +6,11 @@ from pose_tracking_exp.detection.config import ( resolve_instances, ) from pose_tracking_exp.detection.factory import build_pose_shim +from pose_tracking_exp.detection.patch_mmdet import ( + patch_mmdet_init_file, + patch_mmdet_init_text, + resolve_mmdet_init_path, +) from pose_tracking_exp.detection.runner import ( SimpleMovingAverage, SourceSlot, @@ -39,6 +44,9 @@ __all__ = [ "ParquetPoseSink", "PoseBatchRequest", "PoseDetections", + "patch_mmdet_init_file", + "patch_mmdet_init_text", + "resolve_mmdet_init_path", "SimpleMovingAverage", "SourceFrame", "SourceSlot", diff --git a/src/pose_tracking_exp/detection/patch_mmdet.py b/src/pose_tracking_exp/detection/patch_mmdet.py new file mode 100644 index 0000000..61b98e6 --- /dev/null +++ b/src/pose_tracking_exp/detection/patch_mmdet.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import argparse +import importlib.util +from pathlib import Path +import re +import sys + +PATCH_SENTINEL = "Proceeding anyway because this environment intentionally uses the local build." + +MMCV_ASSERT_RE = re.compile( + r"assert\s*\(mmcv_version >= digit_version\(mmcv_minimum_version\)\s*" + r"and mmcv_version < digit_version\(mmcv_maximum_version\)\), \\\n" + r"\s*f'MMCV==\{mmcv\.__version__\} is used but incompatible\. ' \\\n" + r"\s*f'Please install mmcv>=\{mmcv_minimum_version\}, ' \\\n" + r"\s*f'<\{mmcv_maximum_version\}\.'", + re.MULTILINE, +) + +WARN_IMPORT = "from warnings import warn\n" +WARN_BLOCK = """if not ( + mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version) +): + warn( + f'MMCV=={mmcv.__version__} is outside the tested mmdet range ' + f'[{mmcv_minimum_version}, {mmcv_maximum_version}). ' + 'Proceeding anyway because this environment intentionally uses the local build.', + stacklevel=1, + )""" + + +def patch_mmdet_init_text(source: str) -> tuple[str, bool]: + if PATCH_SENTINEL in source: + return source, False + + if WARN_IMPORT not in source: + import_anchor = "from mmengine.utils import digit_version\n" + if import_anchor not in source: + raise RuntimeError("Could not find the mmengine digit_version import anchor in mmdet.__init__.") + source = source.replace(import_anchor, import_anchor + WARN_IMPORT, 1) + + patched, count = MMCV_ASSERT_RE.subn(WARN_BLOCK, source, count=1) + if count != 1: + raise RuntimeError("Could not find the expected mmcv compatibility assert in mmdet.__init__.") + return patched, True + + +def resolve_mmdet_init_path() -> Path: + spec = importlib.util.find_spec("mmdet") + if spec is None or spec.origin is None: + raise RuntimeError("Could not locate the installed `mmdet` package.") + path = Path(spec.origin) + if path.name != "__init__.py": + raise RuntimeError(f"Expected mmdet to resolve to __init__.py, got: {path}") + return path + + +def patch_mmdet_init_file(path: Path) -> bool: + source = path.read_text(encoding="utf-8") + patched, changed = patch_mmdet_init_text(source) + if changed: + path.write_text(patched, encoding="utf-8") + return changed + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Patch the installed mmdet version gate to allow the local mmcv 2.2.0 build.", + ) + parser.add_argument( + "--path", + type=Path, + default=None, + help="Optional explicit path to mmdet/__init__.py. Defaults to the installed package.", + ) + return parser + + +def main() -> int: + parser = build_parser() + args = parser.parse_args() + path = args.path if args.path is not None else resolve_mmdet_init_path() + changed = patch_mmdet_init_file(path) + if changed: + print(f"Patched mmdet version gate at {path}") + else: + print(f"mmdet version gate already patched at {path}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_patch_mmdet.py b/tests/test_patch_mmdet.py new file mode 100644 index 0000000..c1e65a8 --- /dev/null +++ b/tests/test_patch_mmdet.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from pose_tracking_exp.detection.patch_mmdet import PATCH_SENTINEL, patch_mmdet_init_file, patch_mmdet_init_text + + +UNPATCHED_INIT = """# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .version import __version__, version_info + +mmcv_minimum_version = '2.0.0rc4' +mmcv_maximum_version = '2.2.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.7.1' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \\ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \\ + f'Please install mmcv>={mmcv_minimum_version}, ' \\ + f'<{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \\ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \\ + f'Please install mmengine>={mmengine_minimum_version}, ' \\ + f'<{mmengine_maximum_version}.' + +__all__ = ['__version__', 'version_info', 'digit_version'] +""" + + +def test_patch_mmdet_init_text_rewrites_mmcv_assert() -> None: + patched, changed = patch_mmdet_init_text(UNPATCHED_INIT) + + assert changed is True + assert PATCH_SENTINEL in patched + assert "from warnings import warn" in patched + assert "assert (mmengine_version >= digit_version(mmengine_minimum_version)" in patched + + +def test_patch_mmdet_init_file_is_idempotent(tmp_path: Path) -> None: + init_path = tmp_path / "__init__.py" + init_path.write_text(UNPATCHED_INIT, encoding="utf-8") + + assert patch_mmdet_init_file(init_path) is True + once = init_path.read_text(encoding="utf-8") + assert PATCH_SENTINEL in once + + assert patch_mmdet_init_file(init_path) is False + twice = init_path.read_text(encoding="utf-8") + assert once == twice