feat: retain best checkpoints and support alternate output roots
This commit is contained in:
@@ -7,4 +7,5 @@ from .common import mkdir, clones
|
||||
from .common import MergeCfgsDict
|
||||
from .common import get_attr_from
|
||||
from .common import NoOp
|
||||
from .msg_manager import get_msg_mgr
|
||||
from .common import resolve_output_path
|
||||
from .msg_manager import get_msg_mgr
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import os
|
||||
import inspect
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
@@ -203,3 +204,19 @@ def get_ddp_module(module, find_unused_parameters=False, **kwargs):
|
||||
def params_count(net):
|
||||
n_parameters = sum(p.numel() for p in net.parameters())
|
||||
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
|
||||
|
||||
|
||||
def resolve_output_path(cfgs, engine_cfg):
|
||||
output_root = (
|
||||
engine_cfg.get('output_root')
|
||||
or cfgs.get('output_root')
|
||||
or os.environ.get('OPENGAIT_OUTPUT_ROOT')
|
||||
or 'output'
|
||||
)
|
||||
output_root = str(Path(output_root).expanduser())
|
||||
return os.path.join(
|
||||
output_root,
|
||||
cfgs['data_cfg']['dataset_name'],
|
||||
cfgs['model_cfg']['model'],
|
||||
engine_cfg['save_name'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user