Add proxy eval and skeleton experiment tooling
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
import os
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import torch.utils.data as tordata
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import random
|
||||
from typing import TypeVar
|
||||
|
||||
import torch.utils.data as tordata
|
||||
from opengait.utils import get_msg_mgr
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DataSet(tordata.Dataset):
|
||||
def __init__(self, data_cfg, training):
|
||||
@@ -66,6 +71,33 @@ class DataSet(tordata.Dataset):
|
||||
for idx in range(len(self)):
|
||||
self.__getitem__(idx)
|
||||
|
||||
@staticmethod
def _sample_items(
    items: list[T],
    subset_size: int | None,
    subset_seed: int,
    subset_name: str,
    msg_mgr=None,
) -> list[T]:
    """Deterministically pick a sorted subset of ``items``.

    A ``subset_size`` of ``None`` disables sampling entirely, and a size
    that covers the whole list returns ``items`` unchanged; a
    non-positive size is rejected. Sampling is driven by a dedicated
    ``random.Random(subset_seed)`` instance, so repeated calls with the
    same arguments always produce the same subset.

    :param items: pool to draw from (elements must be sortable).
    :param subset_size: number of elements to keep, or ``None`` for all.
    :param subset_seed: seed for the private RNG used to sample.
    :param subset_name: human-readable label used in errors and logs.
    :param msg_mgr: optional message manager; when given, the chosen
        subset is reported via ``log_info`` with lazy %-style arguments.
    :returns: either ``items`` itself (no sampling) or a new sorted list.
    :raises ValueError: if ``subset_size`` is zero or negative.
    """
    # Guard order matters: a non-positive size must raise even when
    # the pool is empty, so the <= 0 check precedes the size check.
    if subset_size is None:
        return items
    if subset_size <= 0:
        raise ValueError(f"{subset_name} must be positive, got {subset_size}")
    if subset_size >= len(items):
        return items

    # Seeded private RNG keeps the draw reproducible and independent of
    # the global random state; sort for a stable, order-independent result.
    rng = random.Random(subset_seed)
    chosen = sorted(rng.sample(items, subset_size))
    if msg_mgr is not None:
        msg_mgr.log_info(
            "Using %s subset: %d / %d items (seed=%d)",
            subset_name,
            len(chosen),
            len(items),
            subset_seed,
        )
    return chosen
|
||||
|
||||
def __dataset_parser(self, data_config, training):
|
||||
dataset_root = data_config['dataset_root']
|
||||
try:
|
||||
@@ -80,9 +112,16 @@ class DataSet(tordata.Dataset):
|
||||
label_list = os.listdir(dataset_root)
|
||||
train_set = [label for label in train_set if label in label_list]
|
||||
test_set = [label for label in test_set if label in label_list]
|
||||
msg_mgr = get_msg_mgr()
|
||||
train_set = self._sample_items(
|
||||
train_set,
|
||||
data_config.get("train_pid_subset_size"),
|
||||
int(data_config.get("train_pid_subset_seed", 0)),
|
||||
"train pid",
|
||||
msg_mgr,
|
||||
)
|
||||
miss_pids = [label for label in label_list if label not in (
|
||||
train_set + test_set)]
|
||||
msg_mgr = get_msg_mgr()
|
||||
|
||||
def log_pid_list(pid_list):
|
||||
if len(pid_list) >= 3:
|
||||
@@ -121,5 +160,14 @@ class DataSet(tordata.Dataset):
|
||||
'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
|
||||
return seqs_info_list
|
||||
|
||||
self.seqs_info = get_seqs_info_list(
|
||||
train_set) if training else get_seqs_info_list(test_set)
|
||||
if training:
|
||||
self.seqs_info = get_seqs_info_list(train_set)
|
||||
else:
|
||||
self.seqs_info = get_seqs_info_list(test_set)
|
||||
self.seqs_info = self._sample_items(
|
||||
self.seqs_info,
|
||||
data_config.get("test_seq_subset_size"),
|
||||
int(data_config.get("test_seq_subset_seed", 0)),
|
||||
"test sequence",
|
||||
msg_mgr,
|
||||
)
|
||||
|
||||
@@ -553,22 +553,31 @@ class BaseModel(MetaModel, nn.Module):
|
||||
resume_every_iter = int(model.engine_cfg.get('resume_every_iter', 0))
|
||||
if resume_every_iter > 0 and model.iteration % resume_every_iter == 0:
|
||||
model.save_resume_ckpt(model.iteration)
|
||||
if model.iteration % model.engine_cfg['save_iter'] == 0:
|
||||
save_iter = int(model.engine_cfg['save_iter'])
|
||||
eval_iter = int(model.engine_cfg.get('eval_iter', 0))
|
||||
should_save = save_iter > 0 and model.iteration % save_iter == 0
|
||||
should_eval = False
|
||||
if model.engine_cfg['with_test']:
|
||||
if eval_iter > 0:
|
||||
should_eval = model.iteration % eval_iter == 0
|
||||
else:
|
||||
should_eval = should_save
|
||||
|
||||
if should_save:
|
||||
# save the checkpoint
|
||||
model.save_ckpt(model.iteration)
|
||||
|
||||
# run test if with_test = true
|
||||
if model.engine_cfg['with_test']:
|
||||
model.msg_mgr.log_info("Running test...")
|
||||
model.eval()
|
||||
result_dict = BaseModel.run_test(model)
|
||||
model.train()
|
||||
if model.cfgs['trainer_cfg']['fix_BN']:
|
||||
model.fix_BN()
|
||||
if result_dict:
|
||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||
model.msg_mgr.write_to_wandb(result_dict)
|
||||
model.msg_mgr.reset_time()
|
||||
if should_eval:
|
||||
model.msg_mgr.log_info("Running test...")
|
||||
model.eval()
|
||||
result_dict = BaseModel.run_test(model)
|
||||
model.train()
|
||||
if model.cfgs['trainer_cfg']['fix_BN']:
|
||||
model.fix_BN()
|
||||
if result_dict:
|
||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||
model.msg_mgr.write_to_wandb(result_dict)
|
||||
model.msg_mgr.reset_time()
|
||||
if model.iteration >= model.engine_cfg['total_iter']:
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user