feat: retain best checkpoints and support alternate output roots

This commit is contained in:
2026-03-11 01:14:05 +08:00
parent 63e2ed1097
commit a0150c791f
14 changed files with 852 additions and 9 deletions
+133 -3
View File
@@ -10,8 +10,10 @@ BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import json
import math
import os
import random
import re
from typing import Any
import numpy as np
@@ -33,7 +35,7 @@ from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from opengait.utils import Odict, mkdir, ddp_all_gather
from opengait.utils import Odict, mkdir, ddp_all_gather, resolve_output_path
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from opengait.utils import NoOp
@@ -144,8 +146,7 @@ class BaseModel(MetaModel, nn.Module):
if training and self.engine_cfg['enable_float16']:
self.Scaler = GradScaler()
self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
cfgs['model_cfg']['model'], self.engine_cfg['save_name'])
self.save_path = resolve_output_path(cfgs, self.engine_cfg)
self.build_network(cfgs['model_cfg'])
self.init_parameters()
@@ -317,6 +318,134 @@ class BaseModel(MetaModel, nn.Module):
return candidate
return None
def _best_ckpt_cfg(self) -> dict[str, Any] | None:
best_ckpt_cfg = self.engine_cfg.get('best_ckpt_cfg')
if not isinstance(best_ckpt_cfg, dict):
return None
keep_n = int(best_ckpt_cfg.get('keep_n', 0))
metric_names = best_ckpt_cfg.get('metric_names', [])
if keep_n <= 0 or not isinstance(metric_names, list) or not metric_names:
return None
return best_ckpt_cfg
def _best_ckpt_root(self) -> str:
return osp.join(self._checkpoint_dir(), "best")
def _best_metric_dir(self, metric_name: str) -> str:
metric_slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", metric_name).strip("._")
return osp.join(self._best_ckpt_root(), metric_slug)
def _best_metric_index_path(self, metric_name: str) -> str:
return osp.join(self._best_metric_dir(metric_name), "index.json")
def _load_best_metric_index(self, metric_name: str) -> list[dict[str, Any]]:
index_path = self._best_metric_index_path(metric_name)
if not osp.isfile(index_path):
return []
with open(index_path, "r", encoding="utf-8") as handle:
raw_entries = json.load(handle)
if not isinstance(raw_entries, list):
return []
entries: list[dict[str, Any]] = []
for entry in raw_entries:
if not isinstance(entry, dict):
continue
path = entry.get("path")
if isinstance(path, str) and osp.isfile(path):
entries.append(entry)
return entries
def _write_best_metric_index(
self,
metric_name: str,
entries: list[dict[str, Any]],
) -> None:
index_path = self._best_metric_index_path(metric_name)
mkdir(osp.dirname(index_path))
tmp_path = index_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(entries, handle, indent=2, sort_keys=True)
os.replace(tmp_path, index_path)
def _summary_scalar(self, value: Any) -> float | None:
if isinstance(value, torch.Tensor):
return float(value.detach().float().mean().item())
if isinstance(value, np.ndarray):
return float(np.mean(value))
if isinstance(value, (float, int, np.floating, np.integer)):
return float(value)
return None
def _save_best_ckpts(
self,
iteration: int,
result_dict: dict[str, Any],
) -> None:
if torch.distributed.get_rank() != 0:
return
best_ckpt_cfg = self._best_ckpt_cfg()
if best_ckpt_cfg is None:
return
keep_n = int(best_ckpt_cfg['keep_n'])
metric_names = [metric for metric in best_ckpt_cfg['metric_names'] if metric in result_dict]
if not metric_names:
return
checkpoint: dict[str, Any] | None = None
save_name = self.engine_cfg['save_name']
for metric_name in metric_names:
score = self._summary_scalar(result_dict.get(metric_name))
if score is None or not math.isfinite(score):
continue
entries = [
entry for entry in self._load_best_metric_index(metric_name)
if int(entry.get("iteration", -1)) != iteration
]
ranked_entries = sorted(
entries + [{"iteration": iteration, "score": score, "path": ""}],
key=lambda entry: (float(entry["score"]), int(entry["iteration"])),
reverse=True,
)
kept_entries = ranked_entries[:keep_n]
if not any(int(entry["iteration"]) == iteration for entry in kept_entries):
continue
metric_dir = self._best_metric_dir(metric_name)
mkdir(metric_dir)
metric_slug = osp.basename(metric_dir)
best_path = osp.join(
metric_dir,
f"{save_name}-iter-{iteration:0>5}-score-{score:.4f}-{metric_slug}.pt",
)
if checkpoint is None:
checkpoint = self._build_checkpoint(iteration)
self._save_checkpoint_file(checkpoint, best_path)
refreshed_entries = []
for entry in kept_entries:
if int(entry["iteration"]) == iteration:
refreshed_entries.append(
{
"iteration": iteration,
"score": score,
"path": best_path,
}
)
else:
refreshed_entries.append(entry)
keep_paths = {entry["path"] for entry in refreshed_entries if isinstance(entry.get("path"), str)}
for stale_entry in entries:
stale_path = stale_entry.get("path")
if isinstance(stale_path, str) and stale_path not in keep_paths and osp.isfile(stale_path):
os.remove(stale_path)
self._write_best_metric_index(metric_name, refreshed_entries)
def save_ckpt(self, iteration):
if torch.distributed.get_rank() == 0:
save_name = self.engine_cfg['save_name']
@@ -589,6 +718,7 @@ class BaseModel(MetaModel, nn.Module):
if result_dict:
model.msg_mgr.write_to_tensorboard(result_dict)
model.msg_mgr.write_to_wandb(result_dict)
model._save_best_ckpts(model.iteration, result_dict)
model.msg_mgr.reset_time()
if model.iteration >= model.engine_cfg['total_iter']:
break