Add DRF Scoliosis1K pipeline and optional wandb logging

This commit is contained in:
2026-03-07 17:49:19 +08:00
parent 654409ff50
commit 51eee70a4b
12 changed files with 1257 additions and 151 deletions
+192 -40
View File
@@ -1,41 +1,93 @@
from __future__ import annotations
import atexit
import logging
import os.path as osp
import time
import torch
from time import localtime, strftime
from typing import Any
import numpy as np
import torch
import torchvision.utils as vutils
import os.path as osp
from time import strftime, localtime
from torch.utils.tensorboard import SummaryWriter
from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp
import logging
try:
import wandb
except ImportError:
wandb = None
from .common import NoOp, Odict, is_list, is_tensor, mkdir, ts2np
class MessageManager:
    """Project-wide logging hub.

    Routes training/eval messages to the console, optional log files,
    TensorBoard (via ``SummaryWriter``) and, when enabled, Weights & Biases.
    A single shared instance (``msg_mgr``) is created at module import time.
    """

    def __init__(self) -> None:
        self.info_dict = Odict()                   # accumulated per-iteration info
        self.writer_hparams = ["image", "scalar"]  # summary categories we accept
        self.time = time.time()                    # anchor for per-interval cost
        self.logger = logging.getLogger("opengait")
        self.writer: SummaryWriter | None = None   # set up lazily by init_manager
        self.wandb_run: Any | None = None          # set up lazily by init_wandb
        self.iteration = 0
        self.log_iter = 1
        self._close_registered = False             # guards duplicate atexit hooks

    def init_manager(
        self,
        save_path: str,
        log_to_file: bool,
        log_iter: int,
        iteration: int = 0,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "train",
    ) -> None:
        """Initialize TensorBoard (unless disabled) and the text/wandb loggers.

        Args:
            save_path: Root directory under which summaries/logs are written.
            log_to_file: Whether to additionally write log records to a file.
            log_iter: Interval (in iterations) between training-info dumps.
            iteration: Starting iteration; used as the TensorBoard purge step.
            logger_cfg: Optional logger options (``use_tensorboard``,
                ``use_wandb``, ``wandb_*`` keys).
            config: Full experiment config, forwarded to wandb.
            phase: Phase label (e.g. "train"); used in default wandb run names.
        """
        self.iteration = iteration
        self.log_iter = log_iter
        logger_cfg = logger_cfg or {}
        if logger_cfg.get("use_tensorboard", True):
            mkdir(osp.join(save_path, "summary/"))
            self.writer = SummaryWriter(
                osp.join(save_path, "summary/"),
                purge_step=self.iteration,
            )
        else:
            self.writer = None
        self.init_logger(
            save_path,
            log_to_file,
            logger_cfg=logger_cfg,
            config=config,
            phase=phase,
        )

    def init_logger(
        self,
        save_path: str,
        log_to_file: bool,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "test",
    ) -> None:
        """(Re)configure the "opengait" logger and optionally start wandb."""
        self.logger = logging.getLogger("opengait")
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
        self.logger.handlers.clear()  # avoid duplicate handlers on re-init
        formatter = logging.Formatter(
            fmt="[%(asctime)s] [%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(
                    save_path,
                    "logs/",
                    strftime("%Y-%m-%d-%H-%M-%S", localtime()) + ".txt",
                )
            )
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)
        # NOTE(review): the console-handler creation lines fall outside the
        # visible diff hunk; reconstructed here as a standard StreamHandler —
        # confirm against the original file.
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)
        self.init_wandb(save_path, logger_cfg or {}, config, phase)

    def init_wandb(
        self,
        save_path: str,
        logger_cfg: dict[str, Any],
        config: dict[str, Any] | None,
        phase: str,
    ) -> None:
        """Start a wandb run when ``use_wandb`` is enabled; otherwise a no-op.

        Raises:
            ImportError: if wandb logging is requested but the package is
                not installed.
        """
        if not logger_cfg.get("use_wandb", False):
            self.wandb_run = None
            return
        if wandb is None:
            raise ImportError(
                "wandb logging is enabled but the package is not installed. "
                "Install it with `uv sync --extra wandb`."
            )
        data_cfg = (config or {}).get("data_cfg", {})
        model_cfg = (config or {}).get("model_cfg", {})
        default_name = "-".join(
            [
                str(data_cfg.get("dataset_name", "dataset")),
                str(model_cfg.get("model", "model")),
                phase,
            ]
        )
        self.wandb_run = wandb.init(
            project=logger_cfg.get("wandb_project", "OpenGait"),
            entity=logger_cfg.get("wandb_entity"),
            name=logger_cfg.get("wandb_name", default_name),
            group=logger_cfg.get("wandb_group"),
            job_type=logger_cfg.get("wandb_job_type", phase),
            tags=logger_cfg.get("wandb_tags", []),
            mode=logger_cfg.get("wandb_mode", "online"),
            resume=logger_cfg.get("wandb_resume", "allow"),
            id=logger_cfg.get("wandb_id"),
            dir=save_path,
            config=config,
            reinit=True,
        )
        if not self._close_registered:
            # Ensure the run is finished even on abrupt interpreter exit.
            atexit.register(self.close)
            self._close_registered = True

    def append(self, info) -> None:
        """Normalize values to lists (tensors -> numpy) and accumulate them."""
        for k, v in info.items():
            v = [v] if not is_list(v) else v
            v = [ts2np(_) if is_tensor(_) else _ for _ in v]
            info[k] = v
        self.info_dict.append(info)

    def flush(self) -> None:
        """Drop accumulated info and flush TensorBoard if it is active."""
        self.info_dict.clear()
        if self.writer is not None:
            self.writer.flush()

    def write_to_tensorboard(self, summary) -> None:
        """Write "scalar/..." and "image/..." summary entries to TensorBoard."""
        if self.writer is None:
            return
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Not Expected --Summary-- type [{}] appear!!!{}".format(
                        k, self.writer_hparams
                    )
                )
                continue
            board_name = k.replace(module_name + "/", "")
            writer_module = getattr(self.writer, "add_" + module_name)
            v = v.detach() if is_tensor(v) else v
            v = (
                vutils.make_grid(v, normalize=True, scale_each=True)
                if "image" in module_name
                else v
            )
            if module_name == "scalar":
                # Plain Python numbers have no .mean(); keep them as-is
                # (replaces the original blanket try/except Exception).
                v = v.mean() if hasattr(v, "mean") else v
            writer_module(board_name, v, self.iteration)

    def write_to_wandb(self, summary) -> None:
        """Mirror the summary dict into the active wandb run, if any."""
        if self.wandb_run is None:
            return
        wandb_summary = {}
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Not Expected --Summary-- type [{}] appear!!!{}".format(
                        k, self.writer_hparams
                    )
                )
                continue
            if is_tensor(v):
                v = v.detach().cpu()
            if module_name == "scalar":
                if is_tensor(v):
                    wandb_summary[k] = float(v.mean().item())
                elif isinstance(v, np.ndarray):
                    wandb_summary[k] = float(np.mean(v))
                else:
                    wandb_summary[k] = float(v)
                continue
            # Non-scalar entries are images: CHW grid -> HWC numpy for wandb.
            grid = vutils.make_grid(v, normalize=True, scale_each=True)
            wandb_summary[k] = wandb.Image(grid.permute(1, 2, 0).numpy())
        if wandb_summary:
            self.wandb_run.log(wandb_summary, step=self.iteration)

    def log_training_info(self) -> None:
        """Log the iteration index, elapsed time and mean scalar values."""
        now = time.time()
        # The original passed a dead ``end=`` kwarg to str.format (a leftover
        # from print-style code) and used enumerate only to compute an unused
        # ``end`` string; str.format silently ignores unused kwargs — removed.
        string = "Iteration {:0>5}, Cost {:.2f}s".format(self.iteration, now - self.time)
        for k, v in self.info_dict.items():
            if "scalar" not in k:
                continue
            k = k.replace("scalar/", "").replace("/", "_")
            string += ", {0}={1:.4f}".format(k, np.mean(v))
        self.log_info(string)
        self.reset_time()

    def reset_time(self) -> None:
        """Reset the cost-time anchor to the current wall clock."""
        self.time = time.time()

    def train_step(self, info, summary) -> None:
        """Advance one iteration; every ``log_iter`` steps, log and flush."""
        self.iteration += 1
        self.append(info)
        if self.iteration % self.log_iter == 0:
            self.log_training_info()
            self.flush()
            self.write_to_tensorboard(summary)
            self.write_to_wandb(summary)

    def log_debug(self, *args, **kwargs) -> None:
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs) -> None:
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs) -> None:
        self.logger.warning(*args, **kwargs)

    def close(self) -> None:
        """Release the TensorBoard writer and finish wandb (idempotent)."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        if self.wandb_run is not None:
            self.wandb_run.finish()
            self.wandb_run = None
# Shared module-level singletons: the rest of the project imports ``msg_mgr``
# for all logging.  ``noop`` is presumably a do-nothing stand-in used where a
# silent substitute for the manager is needed — verify against callers.
msg_mgr = MessageManager()
noop = NoOp()