feat: retain best checkpoints and support alternate output roots
This commit is contained in:
@@ -10,8 +10,10 @@ BaseModel.run_train(model)
|
||||
BaseModel.run_test(model)
|
||||
"""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -33,7 +35,7 @@ from data.transform import get_transform
|
||||
from data.collate_fn import CollateFn
|
||||
from data.dataset import DataSet
|
||||
import data.sampler as Samplers
|
||||
from opengait.utils import Odict, mkdir, ddp_all_gather
|
||||
from opengait.utils import Odict, mkdir, ddp_all_gather, resolve_output_path
|
||||
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||
from evaluation import evaluator as eval_functions
|
||||
from opengait.utils import NoOp
|
||||
@@ -144,8 +146,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
if training and self.engine_cfg['enable_float16']:
|
||||
self.Scaler = GradScaler()
|
||||
self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
|
||||
cfgs['model_cfg']['model'], self.engine_cfg['save_name'])
|
||||
self.save_path = resolve_output_path(cfgs, self.engine_cfg)
|
||||
|
||||
self.build_network(cfgs['model_cfg'])
|
||||
self.init_parameters()
|
||||
@@ -317,6 +318,134 @@ class BaseModel(MetaModel, nn.Module):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def _best_ckpt_cfg(self) -> dict[str, Any] | None:
|
||||
best_ckpt_cfg = self.engine_cfg.get('best_ckpt_cfg')
|
||||
if not isinstance(best_ckpt_cfg, dict):
|
||||
return None
|
||||
keep_n = int(best_ckpt_cfg.get('keep_n', 0))
|
||||
metric_names = best_ckpt_cfg.get('metric_names', [])
|
||||
if keep_n <= 0 or not isinstance(metric_names, list) or not metric_names:
|
||||
return None
|
||||
return best_ckpt_cfg
|
||||
|
||||
def _best_ckpt_root(self) -> str:
|
||||
return osp.join(self._checkpoint_dir(), "best")
|
||||
|
||||
def _best_metric_dir(self, metric_name: str) -> str:
|
||||
metric_slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", metric_name).strip("._")
|
||||
return osp.join(self._best_ckpt_root(), metric_slug)
|
||||
|
||||
def _best_metric_index_path(self, metric_name: str) -> str:
|
||||
return osp.join(self._best_metric_dir(metric_name), "index.json")
|
||||
|
||||
def _load_best_metric_index(self, metric_name: str) -> list[dict[str, Any]]:
|
||||
index_path = self._best_metric_index_path(metric_name)
|
||||
if not osp.isfile(index_path):
|
||||
return []
|
||||
with open(index_path, "r", encoding="utf-8") as handle:
|
||||
raw_entries = json.load(handle)
|
||||
if not isinstance(raw_entries, list):
|
||||
return []
|
||||
entries: list[dict[str, Any]] = []
|
||||
for entry in raw_entries:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
path = entry.get("path")
|
||||
if isinstance(path, str) and osp.isfile(path):
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
def _write_best_metric_index(
    self,
    metric_name: str,
    entries: list[dict[str, Any]],
) -> None:
    """Atomically persist *entries* as the JSON index for *metric_name*.

    Writes to a sibling ``.tmp`` file first and then renames it over the
    target, so concurrent readers never observe a half-written index.
    """
    target = self._best_metric_index_path(metric_name)
    mkdir(osp.dirname(target))
    scratch = target + ".tmp"
    with open(scratch, "w", encoding="utf-8") as sink:
        json.dump(entries, sink, indent=2, sort_keys=True)
    # Atomic on POSIX: readers see either the old or the new index.
    os.replace(scratch, target)
|
||||
|
||||
def _summary_scalar(self, value: Any) -> float | None:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return float(value.detach().float().mean().item())
|
||||
if isinstance(value, np.ndarray):
|
||||
return float(np.mean(value))
|
||||
if isinstance(value, (float, int, np.floating, np.integer)):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
def _save_best_ckpts(
    self,
    iteration: int,
    result_dict: dict[str, Any],
) -> None:
    """Save this iteration's checkpoint under every metric it ranks top-N for.

    For each configured metric present in ``result_dict``, the current
    score is ranked (descending — larger values are treated as better,
    as established by the ``reverse=True`` sort below) against the
    persisted index; if it makes the top ``keep_n``, the checkpoint is
    written into that metric's directory, checkpoints that fell out of
    the top ``keep_n`` are deleted, and the index is rewritten
    atomically. Runs only on distributed rank 0.
    """
    # Only rank 0 touches the filesystem; other ranks return immediately.
    if torch.distributed.get_rank() != 0:
        return
    best_ckpt_cfg = self._best_ckpt_cfg()
    if best_ckpt_cfg is None:
        return

    keep_n = int(best_ckpt_cfg['keep_n'])
    # Only metrics actually reported this evaluation round are considered.
    metric_names = [metric for metric in best_ckpt_cfg['metric_names'] if metric in result_dict]
    if not metric_names:
        return

    # Built lazily on first need, then reused for every metric this
    # iteration qualifies under (one state_dict capture, many saves).
    checkpoint: dict[str, Any] | None = None
    save_name = self.engine_cfg['save_name']

    for metric_name in metric_names:
        score = self._summary_scalar(result_dict.get(metric_name))
        # Skip non-numeric, NaN, or infinite scores.
        if score is None or not math.isfinite(score):
            continue

        # Existing entries, excluding any stale record for this same
        # iteration (it will be re-inserted with the fresh score).
        entries = [
            entry for entry in self._load_best_metric_index(metric_name)
            if int(entry.get("iteration", -1)) != iteration
        ]
        # Rank candidates best-first; ties on score break toward the
        # newer iteration. The current iteration joins with a
        # placeholder path filled in only if it survives the cut.
        ranked_entries = sorted(
            entries + [{"iteration": iteration, "score": score, "path": ""}],
            key=lambda entry: (float(entry["score"]), int(entry["iteration"])),
            reverse=True,
        )
        kept_entries = ranked_entries[:keep_n]
        # Nothing to do for this metric if the current score didn't rank.
        if not any(int(entry["iteration"]) == iteration for entry in kept_entries):
            continue

        metric_dir = self._best_metric_dir(metric_name)
        mkdir(metric_dir)
        metric_slug = osp.basename(metric_dir)
        # File name encodes run name, zero-padded iteration, score, and metric.
        best_path = osp.join(
            metric_dir,
            f"{save_name}-iter-{iteration:0>5}-score-{score:.4f}-{metric_slug}.pt",
        )

        if checkpoint is None:
            checkpoint = self._build_checkpoint(iteration)
        self._save_checkpoint_file(checkpoint, best_path)

        # Replace the placeholder entry with the real saved path; keep
        # surviving old entries untouched.
        refreshed_entries = []
        for entry in kept_entries:
            if int(entry["iteration"]) == iteration:
                refreshed_entries.append(
                    {
                        "iteration": iteration,
                        "score": score,
                        "path": best_path,
                    }
                )
            else:
                refreshed_entries.append(entry)

        # Delete checkpoint files that were ranked out of the top keep_n.
        keep_paths = {entry["path"] for entry in refreshed_entries if isinstance(entry.get("path"), str)}
        for stale_entry in entries:
            stale_path = stale_entry.get("path")
            if isinstance(stale_path, str) and stale_path not in keep_paths and osp.isfile(stale_path):
                os.remove(stale_path)

        # Persist the new ranking only after files are in their final state.
        self._write_best_metric_index(metric_name, refreshed_entries)
|
||||
|
||||
def save_ckpt(self, iteration):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
save_name = self.engine_cfg['save_name']
|
||||
@@ -589,6 +718,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if result_dict:
|
||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||
model.msg_mgr.write_to_wandb(result_dict)
|
||||
model._save_best_ckpts(model.iteration, result_dict)
|
||||
model.msg_mgr.reset_time()
|
||||
if model.iteration >= model.engine_cfg['total_iter']:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user