feat: retain best checkpoints and support alternate output roots

This commit is contained in:
2026-03-11 01:14:05 +08:00
parent 63e2ed1097
commit a0150c791f
14 changed files with 852 additions and 9 deletions
+9 -3
View File
@@ -4,7 +4,14 @@ import argparse
import torch
import torch.nn as nn
from modeling import models
from opengait.utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
from opengait.utils import (
config_loader,
get_ddp_module,
get_msg_mgr,
init_seeds,
params_count,
resolve_output_path,
)
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument('--local_rank', type=int, default=0,
@@ -25,8 +32,7 @@ def initialization(cfgs, training):
msg_mgr = get_msg_mgr()
engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
logger_cfg = cfgs.get('logger_cfg', {})
output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'],
cfgs['model_cfg']['model'], engine_cfg['save_name'])
output_path = resolve_output_path(cfgs, engine_cfg)
if training:
msg_mgr.init_manager(
output_path,