feat: retain best checkpoints and support alternate output roots
This commit is contained in:
+9
-3
@@ -4,7 +4,14 @@ import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from modeling import models
|
||||
from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
|
||||
from opengait.utils import (
|
||||
config_loader,
|
||||
get_ddp_module,
|
||||
get_msg_mgr,
|
||||
init_seeds,
|
||||
params_count,
|
||||
resolve_output_path,
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Main program for opengait.')
|
||||
parser.add_argument('--local_rank', type=int, default=0,
|
||||
@@ -25,8 +32,7 @@ def initialization(cfgs, training):
|
||||
msg_mgr = get_msg_mgr()
|
||||
engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
|
||||
logger_cfg = cfgs.get('logger_cfg', {})
|
||||
output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'],
|
||||
cfgs['model_cfg']['model'], engine_cfg['save_name'])
|
||||
output_path = resolve_output_path(cfgs, engine_cfg)
|
||||
if training:
|
||||
msg_mgr.init_manager(
|
||||
output_path,
|
||||
|
||||
Reference in New Issue
Block a user