rename lib to opengait
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
import copy
|
||||
import os
|
||||
import inspect
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.autograd as autograd
|
||||
import yaml
|
||||
import random
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
|
||||
class NoOp:
|
||||
def __getattr__(self, *args):
|
||||
def no_op(*args, **kwargs): pass
|
||||
return no_op
|
||||
|
||||
|
||||
class Odict(OrderedDict):
|
||||
def append(self, odict):
|
||||
dst_keys = self.keys()
|
||||
for k, v in odict.items():
|
||||
if not is_list(v):
|
||||
v = [v]
|
||||
if k in dst_keys:
|
||||
if is_list(self[k]):
|
||||
self[k] += v
|
||||
else:
|
||||
self[k] = [self[k]] + v
|
||||
else:
|
||||
self[k] = v
|
||||
|
||||
|
||||
def Ntuple(description, keys, values):
|
||||
if not is_list_or_tuple(keys):
|
||||
keys = [keys]
|
||||
values = [values]
|
||||
Tuple = namedtuple(description, keys)
|
||||
return Tuple._make(values)
|
||||
|
||||
|
||||
def get_valid_args(obj, input_args, free_keys=[]):
|
||||
if inspect.isfunction(obj):
|
||||
expected_keys = inspect.getargspec(obj)[0]
|
||||
elif inspect.isclass(obj):
|
||||
expected_keys = inspect.getargspec(obj.__init__)[0]
|
||||
else:
|
||||
raise ValueError('Just support function and class object!')
|
||||
unexpect_keys = list()
|
||||
expected_args = {}
|
||||
for k, v in input_args.items():
|
||||
if k in expected_keys:
|
||||
expected_args[k] = v
|
||||
elif k in free_keys:
|
||||
pass
|
||||
else:
|
||||
unexpect_keys.append(k)
|
||||
if unexpect_keys != []:
|
||||
logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" %
|
||||
(', '.join(unexpect_keys), obj.__name__))
|
||||
return expected_args
|
||||
|
||||
|
||||
def get_attr_from(sources, name):
|
||||
try:
|
||||
return getattr(sources[0], name)
|
||||
except:
|
||||
return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name)
|
||||
|
||||
|
||||
def is_list_or_tuple(x):
|
||||
return isinstance(x, (list, tuple))
|
||||
|
||||
|
||||
def is_bool(x):
|
||||
return isinstance(x, bool)
|
||||
|
||||
|
||||
def is_str(x):
|
||||
return isinstance(x, str)
|
||||
|
||||
|
||||
def is_list(x):
|
||||
return isinstance(x, list) or isinstance(x, nn.ModuleList)
|
||||
|
||||
|
||||
def is_dict(x):
|
||||
return isinstance(x, dict) or isinstance(x, OrderedDict) or isinstance(x, Odict)
|
||||
|
||||
|
||||
def is_tensor(x):
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
|
||||
def is_array(x):
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
def ts2np(x):
|
||||
return x.cpu().data.numpy()
|
||||
|
||||
|
||||
def ts2var(x, **kwargs):
|
||||
return autograd.Variable(x, **kwargs).cuda()
|
||||
|
||||
|
||||
def np2var(x, **kwargs):
|
||||
return ts2var(torch.from_numpy(x), **kwargs)
|
||||
|
||||
|
||||
def list2var(x, **kwargs):
|
||||
return np2var(np.array(x), **kwargs)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def MergeCfgsDict(src, dst):
|
||||
for k, v in src.items():
|
||||
if (k not in dst.keys()) or (type(v) != type(dict())):
|
||||
dst[k] = v
|
||||
else:
|
||||
if is_dict(src[k]) and is_dict(dst[k]):
|
||||
MergeCfgsDict(src[k], dst[k])
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
|
||||
def clones(module, N):
|
||||
"Produce N identical layers."
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
||||
|
||||
|
||||
def config_loader(path):
|
||||
with open(path, 'r') as stream:
|
||||
src_cfgs = yaml.safe_load(stream)
|
||||
with open("./config/default.yaml", 'r') as stream:
|
||||
dst_cfgs = yaml.safe_load(stream)
|
||||
MergeCfgsDict(src_cfgs, dst_cfgs)
|
||||
return dst_cfgs
|
||||
|
||||
|
||||
def init_seeds(seed=0, cuda_deterministic=True):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
||||
if cuda_deterministic: # slower, more reproducible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
else: # faster, less reproducible
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def handler(signum, frame):
|
||||
logging.info('Ctrl+c/z pressed')
|
||||
os.system(
|
||||
"kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') ")
|
||||
logging.info('process group flush!')
|
||||
|
||||
|
||||
def ddp_all_gather(features, dim=0, requires_grad=True):
|
||||
'''
|
||||
inputs: [n, ...]
|
||||
'''
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
feature_list = [torch.ones_like(features) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(feature_list, features.contiguous())
|
||||
|
||||
if requires_grad:
|
||||
feature_list[rank] = features
|
||||
feature = torch.cat(feature_list, dim=dim)
|
||||
return feature
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/16885
|
||||
class DDPPassthrough(DDP):
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
def get_ddp_module(module, **kwargs):
|
||||
if len(list(module.parameters())) == 0:
|
||||
# for the case that loss module has not parameters.
|
||||
return module
|
||||
device = torch.cuda.current_device()
|
||||
module = DDPPassthrough(module, device_ids=[device], output_device=device,
|
||||
find_unused_parameters=False, **kwargs)
|
||||
return module
|
||||
|
||||
|
||||
def params_count(net):
|
||||
n_parameters = sum(p.numel() for p in net.parameters())
|
||||
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
|
||||
Reference in New Issue
Block a user