feat: retain best checkpoints and support alternate output roots

2026-03-11 01:14:05 +08:00
parent 63e2ed1097
commit a0150c791f
14 changed files with 852 additions and 9 deletions
+9 -3
@@ -4,7 +4,14 @@ import argparse
 import torch
 import torch.nn as nn
 from modeling import models
-from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
+from opengait.utils import (
+    config_loader,
+    get_ddp_module,
+    get_msg_mgr,
+    init_seeds,
+    params_count,
+    resolve_output_path,
+)
 
 parser = argparse.ArgumentParser(description='Main program for opengait.')
 parser.add_argument('--local_rank', type=int, default=0,
@@ -25,8 +32,7 @@ def initialization(cfgs, training):
     msg_mgr = get_msg_mgr()
     engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
     logger_cfg = cfgs.get('logger_cfg', {})
-    output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'],
-                               cfgs['model_cfg']['model'], engine_cfg['save_name'])
+    output_path = resolve_output_path(cfgs, engine_cfg)
     if training:
         msg_mgr.init_manager(
             output_path,
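
Note: with no overrides configured, the new helper reproduces the layout the removed os.path.join produced. A minimal sketch, assuming resolve_output_path is in scope (e.g., imported from opengait.utils); the dataset, model, and save names below are hypothetical placeholders:

# Sketch: resolve_output_path with no overrides set (names are placeholders).
cfgs = {
    'data_cfg': {'dataset_name': 'CASIA-B'},
    'model_cfg': {'model': 'Baseline'},
}
engine_cfg = {'save_name': 'GaitBase'}

# With no 'output_root' keys and OPENGAIT_OUTPUT_ROOT unset, this returns
# 'output/CASIA-B/Baseline/GaitBase', matching the previous hard-coded join.
print(resolve_output_path(cfgs, engine_cfg))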
+133 -3
@@ -10,8 +10,10 @@ BaseModel.run_train(model)
 BaseModel.run_test(model)
 """
 import json
+import math
 import os
 import random
+import re
 from typing import Any
 
 import numpy as np
@@ -33,7 +35,7 @@ from data.transform import get_transform
 from data.collate_fn import CollateFn
 from data.dataset import DataSet
 import data.sampler as Samplers
-from opengait.utils import Odict, mkdir, ddp_all_gather
+from opengait.utils import Odict, mkdir, ddp_all_gather, resolve_output_path
 from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
 from evaluation import evaluator as eval_functions
 from opengait.utils import NoOp
@@ -144,8 +146,7 @@ class BaseModel(MetaModel, nn.Module):
         if training and self.engine_cfg['enable_float16']:
             self.Scaler = GradScaler()
 
-        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
-                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])
+        self.save_path = resolve_output_path(cfgs, self.engine_cfg)
 
         self.build_network(cfgs['model_cfg'])
         self.init_parameters()
@@ -317,6 +318,134 @@ class BaseModel(MetaModel, nn.Module):
                 return candidate
         return None
+
+    def _best_ckpt_cfg(self) -> dict[str, Any] | None:
+        best_ckpt_cfg = self.engine_cfg.get('best_ckpt_cfg')
+        if not isinstance(best_ckpt_cfg, dict):
+            return None
+        keep_n = int(best_ckpt_cfg.get('keep_n', 0))
+        metric_names = best_ckpt_cfg.get('metric_names', [])
+        if keep_n <= 0 or not isinstance(metric_names, list) or not metric_names:
+            return None
+        return best_ckpt_cfg
+
+    def _best_ckpt_root(self) -> str:
+        return osp.join(self._checkpoint_dir(), "best")
+
+    def _best_metric_dir(self, metric_name: str) -> str:
+        # Slugify the metric name so it is safe to use as a directory name.
+        metric_slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", metric_name).strip("._")
+        return osp.join(self._best_ckpt_root(), metric_slug)
+
+    def _best_metric_index_path(self, metric_name: str) -> str:
+        return osp.join(self._best_metric_dir(metric_name), "index.json")
+
+    def _load_best_metric_index(self, metric_name: str) -> list[dict[str, Any]]:
+        index_path = self._best_metric_index_path(metric_name)
+        if not osp.isfile(index_path):
+            return []
+        with open(index_path, "r", encoding="utf-8") as handle:
+            raw_entries = json.load(handle)
+        if not isinstance(raw_entries, list):
+            return []
+        entries: list[dict[str, Any]] = []
+        for entry in raw_entries:
+            if not isinstance(entry, dict):
+                continue
+            # Keep only entries whose checkpoint file still exists on disk.
+            path = entry.get("path")
+            if isinstance(path, str) and osp.isfile(path):
+                entries.append(entry)
+        return entries
+
+    def _write_best_metric_index(
+        self,
+        metric_name: str,
+        entries: list[dict[str, Any]],
+    ) -> None:
+        index_path = self._best_metric_index_path(metric_name)
+        mkdir(osp.dirname(index_path))
+        # Write atomically: dump to a temp file, then rename it over the index.
+        tmp_path = index_path + ".tmp"
+        with open(tmp_path, "w", encoding="utf-8") as handle:
+            json.dump(entries, handle, indent=2, sort_keys=True)
+        os.replace(tmp_path, index_path)
+
+    def _summary_scalar(self, value: Any) -> float | None:
+        # Collapse tensors and arrays to a mean scalar; non-numeric values yield None.
+        if isinstance(value, torch.Tensor):
+            return float(value.detach().float().mean().item())
+        if isinstance(value, np.ndarray):
+            return float(np.mean(value))
+        if isinstance(value, (float, int, np.floating, np.integer)):
+            return float(value)
+        return None
+
+    def _save_best_ckpts(
+        self,
+        iteration: int,
+        result_dict: dict[str, Any],
+    ) -> None:
+        if torch.distributed.get_rank() != 0:
+            return
+        best_ckpt_cfg = self._best_ckpt_cfg()
+        if best_ckpt_cfg is None:
+            return
+        keep_n = int(best_ckpt_cfg['keep_n'])
+        metric_names = [metric for metric in best_ckpt_cfg['metric_names'] if metric in result_dict]
+        if not metric_names:
+            return
+        checkpoint: dict[str, Any] | None = None
+        save_name = self.engine_cfg['save_name']
+        for metric_name in metric_names:
+            score = self._summary_scalar(result_dict.get(metric_name))
+            if score is None or not math.isfinite(score):
+                continue
+            entries = [
+                entry for entry in self._load_best_metric_index(metric_name)
+                if int(entry.get("iteration", -1)) != iteration
+            ]
+            # Rank prior entries plus the new candidate; higher scores are
+            # better, and the later iteration wins ties.
+            ranked_entries = sorted(
+                entries + [{"iteration": iteration, "score": score, "path": ""}],
+                key=lambda entry: (float(entry["score"]), int(entry["iteration"])),
+                reverse=True,
+            )
+            kept_entries = ranked_entries[:keep_n]
+            if not any(int(entry["iteration"]) == iteration for entry in kept_entries):
+                continue
+            metric_dir = self._best_metric_dir(metric_name)
+            mkdir(metric_dir)
+            metric_slug = osp.basename(metric_dir)
+            best_path = osp.join(
+                metric_dir,
+                f"{save_name}-iter-{iteration:0>5}-score-{score:.4f}-{metric_slug}.pt",
+            )
+            # Build the checkpoint payload once and reuse it for every metric.
+            if checkpoint is None:
+                checkpoint = self._build_checkpoint(iteration)
+            self._save_checkpoint_file(checkpoint, best_path)
+            refreshed_entries = []
+            for entry in kept_entries:
+                if int(entry["iteration"]) == iteration:
+                    refreshed_entries.append(
+                        {
+                            "iteration": iteration,
+                            "score": score,
+                            "path": best_path,
+                        }
+                    )
+                else:
+                    refreshed_entries.append(entry)
+            # Delete checkpoints that fell out of the top-keep_n set.
+            keep_paths = {entry["path"] for entry in refreshed_entries if isinstance(entry.get("path"), str)}
+            for stale_entry in entries:
+                stale_path = stale_entry.get("path")
+                if isinstance(stale_path, str) and stale_path not in keep_paths and osp.isfile(stale_path):
+                    os.remove(stale_path)
+            self._write_best_metric_index(metric_name, refreshed_entries)
 
     def save_ckpt(self, iteration):
         if torch.distributed.get_rank() == 0:
             save_name = self.engine_cfg['save_name']
@@ -589,6 +718,7 @@ class BaseModel(MetaModel, nn.Module):
                 if result_dict:
                     model.msg_mgr.write_to_tensorboard(result_dict)
                     model.msg_mgr.write_to_wandb(result_dict)
+                    model._save_best_ckpts(model.iteration, result_dict)
                 model.msg_mgr.reset_time()
             if model.iteration >= model.engine_cfg['total_iter']:
                 break
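
Note: the retention logic only engages when the engine config carries a valid best_ckpt_cfg block; _best_ckpt_cfg() returns None (disabling the feature) unless keep_n is positive and metric_names is a non-empty list. A minimal sketch of the expected shape; the metric name is hypothetical and must match a key in the result_dict that run_test returns:

# Sketch of the trainer_cfg fields _best_ckpt_cfg() validates.
# 'scalar/test_accuracy/NM' is a hypothetical result_dict key.
trainer_cfg = {
    'save_name': 'GaitBase',  # placeholder
    'best_ckpt_cfg': {
        'keep_n': 3,  # retain the top-3 checkpoints per metric
        'metric_names': ['scalar/test_accuracy/NM'],
    },
}

Per metric, checkpoints then land under the checkpoint directory in best/<metric_slug>/ (here the slug would be 'scalar_test_accuracy_NM'), with index.json recording the retained {iteration, score, path} entries.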
+2 -1
@@ -7,4 +7,5 @@ from .common import mkdir, clones
 from .common import MergeCfgsDict
 from .common import get_attr_from
 from .common import NoOp
-from .msg_manager import get_msg_mgr
+from .common import resolve_output_path
+from .msg_manager import get_msg_mgr
+17
@@ -2,6 +2,7 @@ import copy
 import os
 import inspect
 import logging
+from pathlib import Path
 import torch
 import numpy as np
 import torch.nn as nn
@@ -203,3 +204,19 @@ def get_ddp_module(module, find_unused_parameters=False, **kwargs):
 def params_count(net):
     n_parameters = sum(p.numel() for p in net.parameters())
     return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
+
+
+def resolve_output_path(cfgs, engine_cfg):
+    output_root = (
+        engine_cfg.get('output_root')
+        or cfgs.get('output_root')
+        or os.environ.get('OPENGAIT_OUTPUT_ROOT')
+        or 'output'
+    )
+    output_root = str(Path(output_root).expanduser())
+    return os.path.join(
+        output_root,
+        cfgs['data_cfg']['dataset_name'],
+        cfgs['model_cfg']['model'],
+        engine_cfg['save_name'],
+    )
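
Note: a short sketch of the precedence this helper implements; each later fallback applies only when the earlier sources are missing or empty. Names are the same hypothetical placeholders as above:

import os

# Resolution order: engine_cfg['output_root'] > cfgs['output_root']
# > $OPENGAIT_OUTPUT_ROOT > 'output'.
cfgs = {
    'output_root': '/data/experiments',  # global override for every engine
    'data_cfg': {'dataset_name': 'CASIA-B'},
    'model_cfg': {'model': 'Baseline'},
}
engine_cfg = {'save_name': 'GaitBase'}

os.environ['OPENGAIT_OUTPUT_ROOT'] = '/tmp/ignored'  # loses to cfgs['output_root']
print(resolve_output_path(cfgs, engine_cfg))
# -> /data/experiments/CASIA-B/Baseline/GaitBase

# A leading '~' is expanded via Path.expanduser(), so '~/runs' also works.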