Add proxy eval and skeleton experiment tooling

This commit is contained in:
2026-03-09 23:11:35 +08:00
parent 36aef46a0d
commit 6c8cd2950c
16 changed files with 1107 additions and 69 deletions
+55 -7
View File
@@ -1,10 +1,15 @@
import os
import pickle
import os.path as osp
import torch.utils.data as tordata
import json
import os
import os.path as osp
import pickle
import random
from typing import TypeVar
import torch.utils.data as tordata
from opengait.utils import get_msg_mgr
T = TypeVar("T")
class DataSet(tordata.Dataset):
def __init__(self, data_cfg, training):
@@ -66,6 +71,33 @@ class DataSet(tordata.Dataset):
for idx in range(len(self)):
self.__getitem__(idx)
@staticmethod
def _sample_items(
    items: list[T],
    subset_size: int | None,
    subset_seed: int,
    subset_name: str,
    msg_mgr=None,
) -> list[T]:
    """Pick a seeded random subset of ``items`` and return it sorted.

    Args:
        items: Candidate entries to draw from. The list is never mutated.
        subset_size: Number of entries to keep. ``None`` disables
            subsetting and hands back ``items`` itself.
        subset_seed: Seed for the private ``random.Random`` instance, so
            the same seed and input always yield the same subset.
        subset_name: Human-readable label used in the error and log text.
        msg_mgr: Optional object exposing ``log_info``; when given, it is
            told how many entries were kept out of the total.

    Returns:
        ``items`` unchanged (same object) when ``subset_size`` is ``None``
        or is at least ``len(items)``; otherwise a new sorted list holding
        ``subset_size`` sampled entries.

    Raises:
        ValueError: If ``subset_size`` is zero or negative.
    """
    # No subset requested: pass the input through untouched.
    if subset_size is None:
        return items

    # Validate before looking at the population so a bad size always fails,
    # even when the input is empty.
    if subset_size <= 0:
        raise ValueError(f"{subset_name} must be positive, got {subset_size}")

    total = len(items)
    if subset_size >= total:
        # Asking for everything (or more) is a no-op, not an error.
        return items

    # A dedicated generator keeps the draw independent of global RNG state.
    rng = random.Random(subset_seed)
    chosen = sorted(rng.sample(items, subset_size))

    if msg_mgr is None:
        return chosen

    msg_mgr.log_info(
        "Using %s subset: %d / %d items (seed=%d)",
        subset_name,
        len(chosen),
        total,
        subset_seed,
    )
    return chosen
def __dataset_parser(self, data_config, training):
dataset_root = data_config['dataset_root']
try:
@@ -80,9 +112,16 @@ class DataSet(tordata.Dataset):
label_list = os.listdir(dataset_root)
train_set = [label for label in train_set if label in label_list]
test_set = [label for label in test_set if label in label_list]
msg_mgr = get_msg_mgr()
train_set = self._sample_items(
train_set,
data_config.get("train_pid_subset_size"),
int(data_config.get("train_pid_subset_seed", 0)),
"train pid",
msg_mgr,
)
miss_pids = [label for label in label_list if label not in (
train_set + test_set)]
msg_mgr = get_msg_mgr()
def log_pid_list(pid_list):
if len(pid_list) >= 3:
@@ -121,5 +160,14 @@ class DataSet(tordata.Dataset):
'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
return seqs_info_list
self.seqs_info = get_seqs_info_list(
train_set) if training else get_seqs_info_list(test_set)
if training:
self.seqs_info = get_seqs_info_list(train_set)
else:
self.seqs_info = get_seqs_info_list(test_set)
self.seqs_info = self._sample_items(
self.seqs_info,
data_config.get("test_seq_subset_size"),
int(data_config.get("test_seq_subset_seed", 0)),
"test sequence",
msg_mgr,
)