67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
|
|
import os
|
|
import argparse
|
|
import torch
|
|
import torch.nn as nn
|
|
from modeling import models
|
|
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
|
|
|
|
# Command-line interface of the entry script.
# NOTE: the arguments are parsed as soon as this module is executed, and the
# resulting namespace is kept in the module-level name `opt`.
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument(
    '--local_rank',
    type=int,
    default=0,
    help="passed by torch.distributed.launch module",
)
parser.add_argument(
    '--cfgs',
    type=str,
    default='config/default.yaml',
    help="path of config file",
)
parser.add_argument(
    '--phase',
    default='train',
    choices=['train', 'test'],
    help="choose train or test phase",
)
parser.add_argument(
    '--log_to_file',
    action='store_true',
    help="log to file, default path is: output/<dataset>/<model>/<save_name>/<logs>/<Datetime>.txt",
)
# No `type=` is given here, so a value supplied on the command line arrives as
# a str while the default stays the int 0.
parser.add_argument(
    '--iter',
    default=0,
    help="iter to restore",
)
opt = parser.parse_args()
|
|
|
|
|
|
def initialization(cfgs, training):
    """Initialise the message manager and seed this process.

    Args:
        cfgs: Configuration mapping. Reads ``data_cfg['dataset_name']``,
            ``model_cfg['model']`` and, from the engine section
            (``trainer_cfg`` when training, otherwise ``evaluator_cfg``),
            ``save_name``, ``restore_hint`` and (training only) ``log_iter``.
        training: Truthy for the train phase, falsy for the test phase.
    """
    msg_mgr = get_msg_mgr()

    section = 'trainer_cfg' if training else 'evaluator_cfg'
    engine_cfg = cfgs[section]

    # Resulting layout: output/<dataset_name>/<model>/<save_name>
    path_parts = (
        'output/',
        cfgs['data_cfg']['dataset_name'],
        cfgs['model_cfg']['model'],
        engine_cfg['save_name'],
    )
    output_path = os.path.join(*path_parts)

    # 'log_iter' is only looked up in the training phase; testing passes 0.
    if training:
        log_iter = engine_cfg['log_iter']
    else:
        log_iter = 0

    # Only an int restore hint is forwarded; any other value is passed as 0.
    restore_hint = engine_cfg['restore_hint']
    int_restore_hint = restore_hint if isinstance(restore_hint, int) else 0

    msg_mgr.init_manager(output_path, opt.log_to_file, log_iter, int_restore_hint)
    msg_mgr.log_info(engine_cfg)

    # Each process seeds with its own distributed rank.
    init_seeds(torch.distributed.get_rank())
|
|
|
|
|
|
def run_model(cfgs, training):
    """Instantiate the configured model, wrap it, and run train or test.

    Args:
        cfgs: Configuration mapping. ``model_cfg['model']`` names the class
            looked up on ``modeling.models``; ``trainer_cfg['sync_BN']`` is
            consulted only when ``training`` is truthy.
        training: Truthy to call ``run_train``, falsy to call ``run_test``.
    """
    msg_mgr = get_msg_mgr()
    model_cfg = cfgs['model_cfg']
    msg_mgr.log_info(model_cfg)

    # Resolve the model class by its configured name.
    model_cls = getattr(models, model_cfg['model'])
    model = model_cls(cfgs, training)

    # BatchNorm layers are converted to SyncBatchNorm only for training runs
    # that enable it in the trainer config.
    use_sync_bn = training and cfgs['trainer_cfg']['sync_BN']
    if use_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = get_ddp_module(model)

    msg_mgr.log_info(params_count(model))
    msg_mgr.log_info("Model Initialization Finished!")

    # The run routine is taken from the class and handed the wrapped model.
    entry_point = model_cls.run_train if training else model_cls.run_test
    entry_point(model)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: join the process group, validate the launch
    # configuration, load the config file, then initialise and run the model.
    torch.distributed.init_process_group('nccl', init_method='env://')

    # The script expects the world size to equal the number of visible GPUs.
    world_size = torch.distributed.get_world_size()
    gpu_count = torch.cuda.device_count()
    if world_size != gpu_count:
        # Fill the placeholders in the order the message names them:
        # GPU count first, world size second.
        raise AssertionError("Expect number of availuable GPUs({}) equals to the world size({}).".format(
            gpu_count, world_size))

    cfgs = config_loader(opt.cfgs)

    # `--iter` is declared without a type, so a command-line value is a str
    # and always differs from the int default 0; any explicitly supplied
    # value therefore overrides the restore hint of both engine configs.
    if opt.iter != 0:
        cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter)
        cfgs['trainer_cfg']['restore_hint'] = int(opt.iter)

    training = (opt.phase == 'train')
    initialization(cfgs, training)
    run_model(cfgs, training)
|