Add resumable ScoNet skeleton training diagnostics
This commit is contained in:
+150
-14
@@ -9,8 +9,13 @@ Typical usage:
|
||||
BaseModel.run_train(model)
|
||||
BaseModel.run_test(model)
|
||||
"""
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import os.path as osp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@@ -169,6 +174,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
restore_hint = self.engine_cfg['restore_hint']
|
||||
if restore_hint != 0:
|
||||
self.resume_ckpt(restore_hint)
|
||||
elif training and self.engine_cfg.get('auto_resume_latest', False):
|
||||
latest_ckpt = self._get_latest_resume_ckpt_path()
|
||||
if latest_ckpt is not None:
|
||||
self.msg_mgr.log_info(
|
||||
"Auto-resuming from latest checkpoint %s", latest_ckpt
|
||||
)
|
||||
self.resume_ckpt(latest_ckpt)
|
||||
|
||||
def get_backbone(self, backbone_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
@@ -234,23 +246,112 @@ class BaseModel(MetaModel, nn.Module):
|
||||
scheduler = Scheduler(self.optimizer, **valid_arg)
|
||||
return scheduler
|
||||
|
||||
def _build_checkpoint(self, iteration: int) -> dict[str, Any]:
|
||||
checkpoint: dict[str, Any] = {
|
||||
'model': self.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'scheduler': self.scheduler.state_dict(),
|
||||
'iteration': iteration,
|
||||
'random_state': random.getstate(),
|
||||
'numpy_random_state': np.random.get_state(),
|
||||
'torch_random_state': torch.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
checkpoint['cuda_random_state_all'] = torch.cuda.get_rng_state_all()
|
||||
if self.engine_cfg.get('enable_float16', False) and hasattr(self, 'Scaler'):
|
||||
checkpoint['scaler'] = self.Scaler.state_dict()
|
||||
return checkpoint
|
||||
|
||||
def _checkpoint_dir(self) -> str:
|
||||
return osp.join(self.save_path, "checkpoints")
|
||||
|
||||
def _resume_dir(self) -> str:
|
||||
return osp.join(self._checkpoint_dir(), "resume")
|
||||
|
||||
def _save_checkpoint_file(
    self,
    checkpoint: dict[str, Any],
    save_path: str,
) -> None:
    """Serialize *checkpoint* to *save_path* via a temp-file rename.

    Writing to a ``.tmp`` sibling first and then renaming means readers
    never observe a partially written checkpoint file.
    """
    mkdir(osp.dirname(save_path))
    staging_path = save_path + ".tmp"
    torch.save(checkpoint, staging_path)
    # os.replace overwrites the destination in a single rename step.
    os.replace(staging_path, save_path)
|
||||
|
||||
def _write_resume_meta(self, iteration: int, resume_path: str) -> None:
|
||||
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
|
||||
meta = {
|
||||
"iteration": iteration,
|
||||
"path": resume_path,
|
||||
}
|
||||
tmp_path = meta_path + ".tmp"
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(meta, handle, indent=2, sort_keys=True)
|
||||
os.replace(tmp_path, meta_path)
|
||||
|
||||
def _prune_resume_checkpoints(self, keep_count: int) -> None:
|
||||
if keep_count <= 0:
|
||||
return
|
||||
resume_dir = self._resume_dir()
|
||||
if not osp.isdir(resume_dir):
|
||||
return
|
||||
prefix = f"{self.engine_cfg['save_name']}-resume-"
|
||||
resume_files = sorted(
|
||||
file_name for file_name in os.listdir(resume_dir)
|
||||
if file_name.startswith(prefix) and file_name.endswith(".pt")
|
||||
)
|
||||
stale_files = resume_files[:-keep_count]
|
||||
for file_name in stale_files:
|
||||
os.remove(osp.join(resume_dir, file_name))
|
||||
|
||||
def _get_latest_resume_ckpt_path(self) -> str | None:
|
||||
latest_path = osp.join(self._checkpoint_dir(), "latest.pt")
|
||||
if osp.isfile(latest_path):
|
||||
return latest_path
|
||||
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
|
||||
if osp.isfile(meta_path):
|
||||
with open(meta_path, "r", encoding="utf-8") as handle:
|
||||
latest_meta = json.load(handle)
|
||||
candidate = latest_meta.get("path")
|
||||
if isinstance(candidate, str) and osp.isfile(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def save_ckpt(self, iteration):
    """Persist a numbered training checkpoint (rank 0 only).

    The checkpoint is assembled by ``_build_checkpoint`` (model,
    optimizer, scheduler, iteration, and RNG states) and written
    atomically by ``_save_checkpoint_file``. Non-zero ranks return
    without doing anything.

    Note: the superseded inline checkpoint-dict construction and
    direct ``torch.save`` that previously duplicated this work (and
    wrote the same file non-atomically, without RNG state) has been
    removed — ``_build_checkpoint`` is a strict superset of it.
    """
    if torch.distributed.get_rank() == 0:
        save_name = self.engine_cfg['save_name']
        checkpoint = self._build_checkpoint(iteration)
        # _checkpoint_dir() == osp.join(self.save_path, "checkpoints"),
        # matching the legacy 'checkpoints/{name}-{iter:0>5}.pt' layout.
        ckpt_path = osp.join(
            self._checkpoint_dir(),
            '{}-{:0>5}.pt'.format(save_name, iteration),
        )
        self._save_checkpoint_file(checkpoint, ckpt_path)
|
||||
|
||||
def save_resume_ckpt(self, iteration: int) -> None:
    """Write a rolling resume checkpoint (rank 0 only).

    Saves the snapshot twice — once under its iteration-stamped name in
    the resume directory and once as the ``latest.pt`` alias — then
    updates ``latest.json`` and prunes old resume files down to the
    ``resume_keep`` count (default 3).
    """
    if torch.distributed.get_rank() != 0:
        return
    save_name = self.engine_cfg['save_name']
    state = self._build_checkpoint(iteration)
    rolling_path = osp.join(
        self._resume_dir(),
        f"{save_name}-resume-{iteration:0>5}.pt",
    )
    alias_path = osp.join(self._checkpoint_dir(), "latest.pt")
    # Write the stamped file first, then the alias, so latest.pt never
    # points at an iteration whose stamped file is missing.
    self._save_checkpoint_file(state, rolling_path)
    self._save_checkpoint_file(state, alias_path)
    self._write_resume_meta(iteration, rolling_path)
    keep_count = int(self.engine_cfg.get('resume_keep', 3))
    self._prune_resume_checkpoints(keep_count)
|
||||
|
||||
def _load_ckpt(self, save_name):
|
||||
load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']
|
||||
|
||||
checkpoint = torch.load(save_name, map_location=torch.device(
|
||||
"cuda", self.device))
|
||||
checkpoint = torch.load(
|
||||
save_name,
|
||||
map_location=torch.device("cuda", self.device),
|
||||
weights_only=False,
|
||||
)
|
||||
model_state_dict = checkpoint['model']
|
||||
|
||||
if not load_ckpt_strict:
|
||||
@@ -271,6 +372,33 @@ class BaseModel(MetaModel, nn.Module):
|
||||
else:
|
||||
self.msg_mgr.log_warning(
|
||||
"Restore NO Scheduler from %s !!!" % save_name)
|
||||
if (
|
||||
self.engine_cfg.get('enable_float16', False)
|
||||
and hasattr(self, 'Scaler')
|
||||
and 'scaler' in checkpoint
|
||||
):
|
||||
self.Scaler.load_state_dict(checkpoint['scaler'])
|
||||
if 'random_state' in checkpoint:
|
||||
random.setstate(checkpoint['random_state'])
|
||||
if 'numpy_random_state' in checkpoint:
|
||||
np.random.set_state(checkpoint['numpy_random_state'])
|
||||
if 'torch_random_state' in checkpoint:
|
||||
torch_random_state = checkpoint['torch_random_state']
|
||||
if not isinstance(torch_random_state, torch.Tensor):
|
||||
torch_random_state = torch.as_tensor(
|
||||
torch_random_state,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
torch.set_rng_state(torch_random_state.cpu())
|
||||
if 'cuda_random_state_all' in checkpoint and torch.cuda.is_available():
|
||||
cuda_random_state_all = checkpoint['cuda_random_state_all']
|
||||
normalized_cuda_states = []
|
||||
for state in cuda_random_state_all:
|
||||
if not isinstance(state, torch.Tensor):
|
||||
state = torch.as_tensor(state, dtype=torch.uint8)
|
||||
normalized_cuda_states.append(state.cpu())
|
||||
torch.cuda.set_rng_state_all(normalized_cuda_states)
|
||||
self.iteration = int(checkpoint.get('iteration', self.iteration))
|
||||
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
||||
|
||||
def resume_ckpt(self, restore_hint):
|
||||
@@ -278,10 +406,15 @@ class BaseModel(MetaModel, nn.Module):
|
||||
save_name = self.engine_cfg['save_name']
|
||||
save_name = osp.join(
|
||||
self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
|
||||
self.iteration = restore_hint
|
||||
elif isinstance(restore_hint, str):
|
||||
save_name = restore_hint
|
||||
self.iteration = 0
|
||||
if restore_hint == 'latest':
|
||||
save_name = self._get_latest_resume_ckpt_path()
|
||||
if save_name is None:
|
||||
raise FileNotFoundError(
|
||||
f"No latest checkpoint found under {self._checkpoint_dir()}"
|
||||
)
|
||||
else:
|
||||
save_name = restore_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
"Error type for -Restore_Hint-, supported: int or string.")
|
||||
@@ -417,6 +550,9 @@ class BaseModel(MetaModel, nn.Module):
|
||||
visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
|
||||
|
||||
model.msg_mgr.train_step(loss_info, visual_summary)
|
||||
resume_every_iter = int(model.engine_cfg.get('resume_every_iter', 0))
|
||||
if resume_every_iter > 0 and model.iteration % resume_every_iter == 0:
|
||||
model.save_resume_ckpt(model.iteration)
|
||||
if model.iteration % model.engine_cfg['save_iter'] == 0:
|
||||
# save the checkpoint
|
||||
model.save_ckpt(model.iteration)
|
||||
|
||||
Reference in New Issue
Block a user