Add proxy eval and skeleton experiment tooling
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
import os
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import torch.utils.data as tordata
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import random
|
||||
from typing import TypeVar
|
||||
|
||||
import torch.utils.data as tordata
|
||||
from opengait.utils import get_msg_mgr
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DataSet(tordata.Dataset):
|
||||
def __init__(self, data_cfg, training):
|
||||
@@ -66,6 +71,33 @@ class DataSet(tordata.Dataset):
|
||||
for idx in range(len(self)):
|
||||
self.__getitem__(idx)
|
||||
|
||||
@staticmethod
|
||||
def _sample_items(
|
||||
items: list[T],
|
||||
subset_size: int | None,
|
||||
subset_seed: int,
|
||||
subset_name: str,
|
||||
msg_mgr=None,
|
||||
) -> list[T]:
|
||||
if subset_size is None:
|
||||
return items
|
||||
if subset_size <= 0:
|
||||
raise ValueError(f"{subset_name} must be positive, got {subset_size}")
|
||||
if subset_size >= len(items):
|
||||
return items
|
||||
|
||||
sampled_items = random.Random(subset_seed).sample(items, subset_size)
|
||||
sampled_items.sort()
|
||||
if msg_mgr is not None:
|
||||
msg_mgr.log_info(
|
||||
"Using %s subset: %d / %d items (seed=%d)",
|
||||
subset_name,
|
||||
len(sampled_items),
|
||||
len(items),
|
||||
subset_seed,
|
||||
)
|
||||
return sampled_items
|
||||
|
||||
def __dataset_parser(self, data_config, training):
|
||||
dataset_root = data_config['dataset_root']
|
||||
try:
|
||||
@@ -80,9 +112,16 @@ class DataSet(tordata.Dataset):
|
||||
label_list = os.listdir(dataset_root)
|
||||
train_set = [label for label in train_set if label in label_list]
|
||||
test_set = [label for label in test_set if label in label_list]
|
||||
msg_mgr = get_msg_mgr()
|
||||
train_set = self._sample_items(
|
||||
train_set,
|
||||
data_config.get("train_pid_subset_size"),
|
||||
int(data_config.get("train_pid_subset_seed", 0)),
|
||||
"train pid",
|
||||
msg_mgr,
|
||||
)
|
||||
miss_pids = [label for label in label_list if label not in (
|
||||
train_set + test_set)]
|
||||
msg_mgr = get_msg_mgr()
|
||||
|
||||
def log_pid_list(pid_list):
|
||||
if len(pid_list) >= 3:
|
||||
@@ -121,5 +160,14 @@ class DataSet(tordata.Dataset):
|
||||
'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
|
||||
return seqs_info_list
|
||||
|
||||
self.seqs_info = get_seqs_info_list(
|
||||
train_set) if training else get_seqs_info_list(test_set)
|
||||
if training:
|
||||
self.seqs_info = get_seqs_info_list(train_set)
|
||||
else:
|
||||
self.seqs_info = get_seqs_info_list(test_set)
|
||||
self.seqs_info = self._sample_items(
|
||||
self.seqs_info,
|
||||
data_config.get("test_seq_subset_size"),
|
||||
int(data_config.get("test_seq_subset_seed", 0)),
|
||||
"test sequence",
|
||||
msg_mgr,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user