"""Logging utilities for OpenGait: console/file logging, TensorBoard
summaries, and optional Weights & Biases tracking behind one manager."""

from __future__ import annotations

import atexit
import logging
import os.path as osp
import time
from time import localtime, strftime
from typing import Any

import numpy as np
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

try:
    # wandb is an optional dependency; everything else works without it.
    import wandb
except ImportError:
    wandb = None

from .common import NoOp, Odict, is_list, is_tensor, mkdir, ts2np


class MessageManager:
    """Accumulates per-iteration training info and fans it out to the
    console, a log file, TensorBoard, and (optionally) Weights & Biases."""

    def __init__(self) -> None:
        self.info_dict = Odict()
        # Summary keys must carry one of these prefixes, e.g. "scalar/loss".
        self.writer_hparams = ["image", "scalar"]
        self.time = time.time()
        self.logger = logging.getLogger("opengait")
        self.writer: SummaryWriter | None = None
        self.wandb_run: Any | None = None
        self.iteration = 0
        self.log_iter = 1
        self._close_registered = False

    def init_manager(
        self,
        save_path: str,
        log_to_file: bool,
        log_iter: int,
        iteration: int = 0,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "train",
    ) -> None:
        self.iteration = iteration
        self.log_iter = log_iter
        logger_cfg = logger_cfg or {}
        if logger_cfg.get("use_tensorboard", True):
            mkdir(osp.join(save_path, "summary/"))
            # purge_step drops stale events when resuming from a checkpoint.
            self.writer = SummaryWriter(
                osp.join(save_path, "summary/"),
                purge_step=self.iteration,
            )
        else:
            self.writer = None
        self.init_logger(
            save_path,
            log_to_file,
            logger_cfg=logger_cfg,
            config=config,
            phase=phase,
        )

    def init_logger(
        self,
        save_path: str,
        log_to_file: bool,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "test",
    ) -> None:
        self.logger = logging.getLogger("opengait")
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
        # Re-initialization must not stack duplicate handlers.
        self.logger.handlers.clear()
        formatter = logging.Formatter(
            fmt="[%(asctime)s] [%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(
                    save_path,
                    "logs/",
                    strftime("%Y-%m-%d-%H-%M-%S", localtime()) + ".txt",
                )
            )
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)
        self.init_wandb(save_path, logger_cfg or {}, config, phase)

    def init_wandb(
        self,
        save_path: str,
        logger_cfg: dict[str, Any],
        config: dict[str, Any] | None,
        phase: str,
    ) -> None:
        if not logger_cfg.get("use_wandb", False):
            self.wandb_run = None
            return
        if wandb is None:
            raise ImportError(
                "wandb logging is enabled but the package is not installed. "
                "Install it with `uv sync --extra wandb`."
            )
        data_cfg = (config or {}).get("data_cfg", {})
        model_cfg = (config or {}).get("model_cfg", {})
        # Fall back to "<dataset>-<model>-<phase>" when no run name is given.
        default_name = "-".join(
            [
                str(data_cfg.get("dataset_name", "dataset")),
                str(model_cfg.get("model", "model")),
                phase,
            ]
        )
        self.wandb_run = wandb.init(
            project=logger_cfg.get("wandb_project", "OpenGait"),
            entity=logger_cfg.get("wandb_entity"),
            name=logger_cfg.get("wandb_name", default_name),
            group=logger_cfg.get("wandb_group"),
            job_type=logger_cfg.get("wandb_job_type", phase),
            tags=logger_cfg.get("wandb_tags", []),
            mode=logger_cfg.get("wandb_mode", "online"),
            resume=logger_cfg.get("wandb_resume", "allow"),
            id=logger_cfg.get("wandb_id"),
            dir=save_path,
            config=config,
            reinit=True,
        )
        if not self._close_registered:
            # Make sure the run is finished even if training crashes.
            atexit.register(self.close)
            self._close_registered = True

    def append(self, info) -> None:
        # Normalize every value to a list of numpy-compatible entries
        # before accumulating it for the next console log.
        for k, v in info.items():
            v = [v] if not is_list(v) else v
            v = [ts2np(_) if is_tensor(_) else _ for _ in v]
            info[k] = v
        self.info_dict.append(info)

    def flush(self) -> None:
        self.info_dict.clear()
        if self.writer is not None:
            self.writer.flush()

    def write_to_tensorboard(self, summary) -> None:
        if self.writer is None:
            return
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Unexpected summary key [{}]; expected one of the "
                    "prefixes {}.".format(k, self.writer_hparams)
                )
                continue
            board_name = k.replace(module_name + "/", "")
            # Dispatch to SummaryWriter.add_scalar / add_image by prefix.
            writer_module = getattr(self.writer, "add_" + module_name)
            v = v.detach() if is_tensor(v) else v
            if module_name == "image":
                v = vutils.make_grid(v, normalize=True, scale_each=True)
            elif module_name == "scalar":
                # Tensors/arrays are reduced to their mean; plain numbers
                # pass through unchanged.
                v = v.mean() if hasattr(v, "mean") else v
            writer_module(board_name, v, self.iteration)

    def write_to_wandb(self, summary) -> None:
        if self.wandb_run is None:
            return
        wandb_summary = {}
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Unexpected summary key [{}]; expected one of the "
                    "prefixes {}.".format(k, self.writer_hparams)
                )
                continue
            if is_tensor(v):
                v = v.detach().cpu()
            if module_name == "scalar":
                if is_tensor(v):
                    wandb_summary[k] = float(v.mean().item())
                elif isinstance(v, np.ndarray):
                    wandb_summary[k] = float(np.mean(v))
                else:
                    wandb_summary[k] = float(v)
                continue
            # Image batches are tiled into one grid and converted to HWC.
            grid = vutils.make_grid(v, normalize=True, scale_each=True)
            wandb_summary[k] = wandb.Image(grid.permute(1, 2, 0).numpy())
        if wandb_summary:
            self.wandb_run.log(wandb_summary, step=self.iteration)

    def log_training_info(self) -> None:
        now = time.time()
        string = "Iteration {:0>5}, Cost {:.2f}s".format(
            self.iteration, now - self.time
        )
        for k, v in self.info_dict.items():
            if "scalar" not in k:
                continue
            k = k.replace("scalar/", "").replace("/", "_")
            string += ", {}={:.4f}".format(k, np.mean(v))
        self.log_info(string)
        self.reset_time()

    def reset_time(self) -> None:
        self.time = time.time()

    def train_step(self, info, summary) -> None:
        self.iteration += 1
        self.append(info)
        if self.iteration % self.log_iter == 0:
            self.log_training_info()
            self.flush()
            self.write_to_tensorboard(summary)
            self.write_to_wandb(summary)

    def log_debug(self, *args, **kwargs) -> None:
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs) -> None:
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs) -> None:
        self.logger.warning(*args, **kwargs)

    def close(self) -> None:
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        if self.wandb_run is not None:
            self.wandb_run.finish()
            self.wandb_run = None
msg_mgr = MessageManager()
noop = NoOp()


def get_msg_mgr():
    # Only rank 0 should write logs; other ranks get a no-op stand-in.
    # Fall back to the real manager when torch.distributed is not in use,
    # since get_rank() raises if the process group is uninitialized.
    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_rank() > 0
    ):
        return noop
    return msg_mgr
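
# Usage sketch (illustrative only): how a training loop might drive the
# manager. `save_path`, `total_iters`, and `train_one_step` are hypothetical;
# the "scalar/"-prefixed keys follow the prefix convention enforced by
# write_to_tensorboard/write_to_wandb above.
#
#   mgr = get_msg_mgr()
#   mgr.init_manager(save_path="./output/demo", log_to_file=True, log_iter=100)
#   for _ in range(total_iters):
#       loss = train_one_step()                    # hypothetical helper
#       mgr.train_step(
#           {"scalar/loss": loss},                 # accumulated for console
#           {"scalar/loss": loss},                 # written to TB / W&B
#       )
#   mgr.close()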