Add resumable ScoNet skeleton training diagnostics

This commit is contained in:
2026-03-09 15:57:13 +08:00
parent 4e0b0a18dc
commit 36aef46a0d
15 changed files with 1226 additions and 44 deletions
+150 -14
View File
@@ -9,8 +9,13 @@ Typical usage:
BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import torch
import json
import os
import random
from typing import Any
import numpy as np
import torch
import os.path as osp
import torch.nn as nn
import torch.optim as optim
@@ -169,6 +174,13 @@ class BaseModel(MetaModel, nn.Module):
restore_hint = self.engine_cfg['restore_hint']
if restore_hint != 0:
self.resume_ckpt(restore_hint)
elif training and self.engine_cfg.get('auto_resume_latest', False):
latest_ckpt = self._get_latest_resume_ckpt_path()
if latest_ckpt is not None:
self.msg_mgr.log_info(
"Auto-resuming from latest checkpoint %s", latest_ckpt
)
self.resume_ckpt(latest_ckpt)
def get_backbone(self, backbone_cfg):
"""Get the backbone of the model."""
@@ -234,23 +246,112 @@ class BaseModel(MetaModel, nn.Module):
scheduler = Scheduler(self.optimizer, **valid_arg)
return scheduler
def _build_checkpoint(self, iteration: int) -> dict[str, Any]:
checkpoint: dict[str, Any] = {
'model': self.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
'iteration': iteration,
'random_state': random.getstate(),
'numpy_random_state': np.random.get_state(),
'torch_random_state': torch.get_rng_state(),
}
if torch.cuda.is_available():
checkpoint['cuda_random_state_all'] = torch.cuda.get_rng_state_all()
if self.engine_cfg.get('enable_float16', False) and hasattr(self, 'Scaler'):
checkpoint['scaler'] = self.Scaler.state_dict()
return checkpoint
def _checkpoint_dir(self) -> str:
return osp.join(self.save_path, "checkpoints")
def _resume_dir(self) -> str:
return osp.join(self._checkpoint_dir(), "resume")
def _save_checkpoint_file(
    self,
    checkpoint: dict[str, Any],
    save_path: str,
) -> None:
    """Atomically persist *checkpoint* to *save_path*.

    The payload is first serialized to ``<save_path>.tmp`` and then moved
    into place with ``os.replace`` so readers never observe a partially
    written checkpoint file.
    """
    # Project helper; presumably creates the parent directory if missing
    # — confirm mkdir's semantics against its definition.
    mkdir(osp.dirname(save_path))
    staging_path = save_path + ".tmp"
    torch.save(checkpoint, staging_path)
    # os.replace is atomic when staging and target share a filesystem.
    os.replace(staging_path, save_path)
def _write_resume_meta(self, iteration: int, resume_path: str) -> None:
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
meta = {
"iteration": iteration,
"path": resume_path,
}
tmp_path = meta_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(meta, handle, indent=2, sort_keys=True)
os.replace(tmp_path, meta_path)
def _prune_resume_checkpoints(self, keep_count: int) -> None:
if keep_count <= 0:
return
resume_dir = self._resume_dir()
if not osp.isdir(resume_dir):
return
prefix = f"{self.engine_cfg['save_name']}-resume-"
resume_files = sorted(
file_name for file_name in os.listdir(resume_dir)
if file_name.startswith(prefix) and file_name.endswith(".pt")
)
stale_files = resume_files[:-keep_count]
for file_name in stale_files:
os.remove(osp.join(resume_dir, file_name))
def _get_latest_resume_ckpt_path(self) -> str | None:
latest_path = osp.join(self._checkpoint_dir(), "latest.pt")
if osp.isfile(latest_path):
return latest_path
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
if osp.isfile(meta_path):
with open(meta_path, "r", encoding="utf-8") as handle:
latest_meta = json.load(handle)
candidate = latest_meta.get("path")
if isinstance(candidate, str) and osp.isfile(candidate):
return candidate
return None
def save_ckpt(self, iteration):
    """Persist a numbered training checkpoint (rank 0 only).

    Builds the full resumable payload — model/optimizer/scheduler plus
    RNG and (optional) AMP scaler state via ``_build_checkpoint`` — and
    writes it atomically to
    ``<save_path>/checkpoints/<save_name>-<iteration:05>.pt``.
    Non-zero ranks return immediately so only one process touches the
    filesystem.

    Fix: the previous text built a second, reduced checkpoint dict
    (model/optimizer/scheduler/iteration only) and wrote it to the same
    path with a plain, non-atomic ``torch.save`` before the full payload
    was written — duplicated work and a crash window with a partially
    written file.  Only the atomic full-payload write is kept.
    """
    if torch.distributed.get_rank() != 0:
        return
    save_name = self.engine_cfg['save_name']
    checkpoint = self._build_checkpoint(iteration)
    ckpt_path = osp.join(
        self._checkpoint_dir(),
        '{}-{:0>5}.pt'.format(save_name, iteration),
    )
    # _save_checkpoint_file creates the parent directory and writes via a
    # tmp file + os.replace, so no explicit mkdir or direct torch.save here.
    self._save_checkpoint_file(checkpoint, ckpt_path)
def save_resume_ckpt(self, iteration: int) -> None:
    """Write a rolling resume snapshot for crash recovery (rank 0 only).

    The payload is stored three ways: a per-iteration file under the
    resume directory, an always-current ``latest.pt`` in the checkpoint
    directory, and the ``latest.json`` pointer.  Older per-iteration
    snapshots are then pruned down to ``engine_cfg['resume_keep']``
    (default 3).
    """
    if torch.distributed.get_rank() != 0:
        return
    payload = self._build_checkpoint(iteration)
    tag = self.engine_cfg['save_name']
    iter_path = osp.join(
        self._resume_dir(),
        f"{tag}-resume-{iteration:0>5}.pt",
    )
    self._save_checkpoint_file(payload, iter_path)
    self._save_checkpoint_file(
        payload, osp.join(self._checkpoint_dir(), "latest.pt")
    )
    self._write_resume_meta(iteration, iter_path)
    self._prune_resume_checkpoints(
        int(self.engine_cfg.get('resume_keep', 3))
    )
def _load_ckpt(self, save_name):
load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']
checkpoint = torch.load(save_name, map_location=torch.device(
"cuda", self.device))
checkpoint = torch.load(
save_name,
map_location=torch.device("cuda", self.device),
weights_only=False,
)
model_state_dict = checkpoint['model']
if not load_ckpt_strict:
@@ -271,6 +372,33 @@ class BaseModel(MetaModel, nn.Module):
else:
self.msg_mgr.log_warning(
"Restore NO Scheduler from %s !!!" % save_name)
if (
self.engine_cfg.get('enable_float16', False)
and hasattr(self, 'Scaler')
and 'scaler' in checkpoint
):
self.Scaler.load_state_dict(checkpoint['scaler'])
if 'random_state' in checkpoint:
random.setstate(checkpoint['random_state'])
if 'numpy_random_state' in checkpoint:
np.random.set_state(checkpoint['numpy_random_state'])
if 'torch_random_state' in checkpoint:
torch_random_state = checkpoint['torch_random_state']
if not isinstance(torch_random_state, torch.Tensor):
torch_random_state = torch.as_tensor(
torch_random_state,
dtype=torch.uint8,
)
torch.set_rng_state(torch_random_state.cpu())
if 'cuda_random_state_all' in checkpoint and torch.cuda.is_available():
cuda_random_state_all = checkpoint['cuda_random_state_all']
normalized_cuda_states = []
for state in cuda_random_state_all:
if not isinstance(state, torch.Tensor):
state = torch.as_tensor(state, dtype=torch.uint8)
normalized_cuda_states.append(state.cpu())
torch.cuda.set_rng_state_all(normalized_cuda_states)
self.iteration = int(checkpoint.get('iteration', self.iteration))
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
def resume_ckpt(self, restore_hint):
@@ -278,10 +406,15 @@ class BaseModel(MetaModel, nn.Module):
save_name = self.engine_cfg['save_name']
save_name = osp.join(
self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
self.iteration = restore_hint
elif isinstance(restore_hint, str):
save_name = restore_hint
self.iteration = 0
if restore_hint == 'latest':
save_name = self._get_latest_resume_ckpt_path()
if save_name is None:
raise FileNotFoundError(
f"No latest checkpoint found under {self._checkpoint_dir()}"
)
else:
save_name = restore_hint
else:
raise ValueError(
"Error type for -Restore_Hint-, supported: int or string.")
@@ -417,6 +550,9 @@ class BaseModel(MetaModel, nn.Module):
visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
model.msg_mgr.train_step(loss_info, visual_summary)
resume_every_iter = int(model.engine_cfg.get('resume_every_iter', 0))
if resume_every_iter > 0 and model.iteration % resume_every_iter == 0:
model.save_resume_ckpt(model.iteration)
if model.iteration % model.engine_cfg['save_iter'] == 0:
# save the checkpoint
model.save_ckpt(model.iteration)