"""Entry point for OpenGait: launches distributed training or evaluation.

Expected to be started via ``torch.distributed.launch`` / ``torchrun`` with
one process per visible GPU; each process parses the same CLI arguments.
"""
import argparse
import os

import torch
import torch.nn as nn

from modeling import models
from opengait.utils import (config_loader, get_ddp_module, get_msg_mgr,
                            init_seeds, params_count)

parser = argparse.ArgumentParser(description='Main program for opengait.')
# Both spellings are registered because torch.distributed.launch passes
# "--local_rank" (pytorch < 2.0) while torchrun passes "--local-rank".
# argparse maps both option strings onto the same dest ("local_rank").
parser.add_argument('--local_rank', type=int, default=0,
                    help="passed by torch.distributed.launch module")
parser.add_argument('--local-rank', type=int, default=0,
                    help="passed by torch.distributed.launch module, for pytorch >=2.0")
parser.add_argument('--cfgs', type=str, default='config/default.yaml',
                    help="path of config file")
parser.add_argument('--phase', default='train',
                    choices=['train', 'test'], help="choose train or test phase")
parser.add_argument('--log_to_file', action='store_true',
                    help="log to file, default path is: output/////.txt")
# FIX: was untyped (str from CLI), so "opt.iter != 0" compared str to int.
# type=int keeps the default 0 and makes the comparison meaningful.
parser.add_argument('--iter', type=int, default=0, help="iter to restore")
opt = parser.parse_args()


def initialization(cfgs, training):
    """Initialize the message manager, output paths and per-rank RNG seeds.

    Args:
        cfgs (dict): Parsed configuration; must contain 'data_cfg',
            'model_cfg' and 'trainer_cfg'/'evaluator_cfg' sections.
        training (bool): True for the train phase, False for test.
    """
    msg_mgr = get_msg_mgr()
    engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
    logger_cfg = cfgs.get('logger_cfg', {})
    output_path = os.path.join(
        'output/', cfgs['data_cfg']['dataset_name'],
        cfgs['model_cfg']['model'], engine_cfg['save_name'])
    if training:
        msg_mgr.init_manager(
            output_path,
            opt.log_to_file,
            engine_cfg['log_iter'],
            # restore_hint may be a checkpoint path (str); only an int
            # counts as an iteration offset for the manager.
            engine_cfg['restore_hint'] if isinstance(
                engine_cfg['restore_hint'], int) else 0,
            logger_cfg=logger_cfg,
            config=cfgs,
            phase='train',
        )
    else:
        msg_mgr.init_logger(
            output_path,
            opt.log_to_file,
            logger_cfg=logger_cfg,
            config=cfgs,
            phase='test',
        )
    msg_mgr.log_info(engine_cfg)

    # Each rank gets a distinct, deterministic seed derived from its rank.
    seed = torch.distributed.get_rank()
    init_seeds(seed)


def run_model(cfgs, training):
    """Build the model, wrap it in DDP, and run the train or test loop.

    Args:
        cfgs (dict): Parsed configuration.
        training (bool): True runs Model.run_train, False Model.run_test.
    """
    msg_mgr = get_msg_mgr()
    model_cfg = cfgs['model_cfg']
    msg_mgr.log_info(model_cfg)
    Model = getattr(models, model_cfg['model'])
    model = Model(cfgs, training)
    if training and cfgs['trainer_cfg']['sync_BN']:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if cfgs['trainer_cfg']['fix_BN']:
        model.fix_BN()
    model = get_ddp_module(
        model, cfgs['trainer_cfg']['find_unused_parameters'])
    msg_mgr.log_info(params_count(model))
    msg_mgr.log_info("Model Initialization Finished!")

    if training:
        Model.run_train(model)
    else:
        result_dict = Model.run_test(model)
        if result_dict:
            msg_mgr.write_to_tensorboard(result_dict)
            msg_mgr.write_to_wandb(result_dict)


if __name__ == '__main__':
    torch.distributed.init_process_group('nccl', init_method='env://')
    # One process per GPU is assumed; anything else is a launch error.
    if torch.distributed.get_world_size() != torch.cuda.device_count():
        raise ValueError(
            "Expect number of available GPUs({}) equals to the world size({}).".format(
                torch.cuda.device_count(), torch.distributed.get_world_size()))
    cfgs = config_loader(opt.cfgs)
    if opt.iter != 0:
        # A non-zero --iter overrides the restore checkpoint for both phases.
        cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter)
        cfgs['trainer_cfg']['restore_hint'] = int(opt.iter)

    training = (opt.phase == 'train')
    initialization(cfgs, training)
    run_model(cfgs, training)