# OpenGait/opengait/utils/msg_manager.py
from __future__ import annotations
import atexit
import logging
import os.path as osp
import time
from time import localtime, strftime
from typing import Any
import numpy as np
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
try:
    import wandb
except ImportError:
    wandb = None

from .common import NoOp, Odict, is_list, is_tensor, mkdir, ts2np


class MessageManager:
    def __init__(self) -> None:
        self.info_dict = Odict()
        self.writer_hparams = ["image", "scalar"]
        self.time = time.time()
        self.logger = logging.getLogger("opengait")
        self.writer: SummaryWriter | None = None
        self.wandb_run: Any | None = None
        self.iteration = 0
        self.log_iter = 1
        self._close_registered = False

    def init_manager(
        self,
        save_path: str,
        log_to_file: bool,
        log_iter: int,
        iteration: int = 0,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "train",
    ) -> None:
        self.iteration = iteration
        self.log_iter = log_iter
        logger_cfg = logger_cfg or {}
        if logger_cfg.get("use_tensorboard", True):
            mkdir(osp.join(save_path, "summary/"))
            self.writer = SummaryWriter(
                osp.join(save_path, "summary/"),
                purge_step=self.iteration,
            )
        else:
            self.writer = None
        self.init_logger(
            save_path,
            log_to_file,
            logger_cfg=logger_cfg,
            config=config,
            phase=phase,
        )

    def init_logger(
        self,
        save_path: str,
        log_to_file: bool,
        logger_cfg: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        phase: str = "test",
    ) -> None:
        self.logger = logging.getLogger("opengait")
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
        self.logger.handlers.clear()
        formatter = logging.Formatter(
            fmt="[%(asctime)s] [%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(
                    save_path,
                    "logs/",
                    strftime("%Y-%m-%d-%H-%M-%S", localtime()) + ".txt",
                )
            )
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)
        self.init_wandb(save_path, logger_cfg or {}, config, phase)

    def init_wandb(
        self,
        save_path: str,
        logger_cfg: dict[str, Any],
        config: dict[str, Any] | None,
        phase: str,
    ) -> None:
        if not logger_cfg.get("use_wandb", False):
            self.wandb_run = None
            return
        if wandb is None:
            raise ImportError(
                "wandb logging is enabled but the package is not installed. "
                "Install it with `uv sync --extra wandb`."
            )
        data_cfg = (config or {}).get("data_cfg", {})
        model_cfg = (config or {}).get("model_cfg", {})
        default_name = "-".join(
            [
                str(data_cfg.get("dataset_name", "dataset")),
                str(model_cfg.get("model", "model")),
                phase,
            ]
        )
        self.wandb_run = wandb.init(
            project=logger_cfg.get("wandb_project", "OpenGait"),
            entity=logger_cfg.get("wandb_entity"),
            name=logger_cfg.get("wandb_name", default_name),
            group=logger_cfg.get("wandb_group"),
            job_type=logger_cfg.get("wandb_job_type", phase),
            tags=logger_cfg.get("wandb_tags", []),
            mode=logger_cfg.get("wandb_mode", "online"),
            resume=logger_cfg.get("wandb_resume", "allow"),
            id=logger_cfg.get("wandb_id"),
            dir=save_path,
            config=config,
            reinit=True,
        )
        if not self._close_registered:
            atexit.register(self.close)
            self._close_registered = True
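
    # Illustrative logger_cfg for the keys consumed above (values beyond the
    # .get() defaults are examples, not project defaults):
    #
    #   logger_cfg = {
    #       "use_tensorboard": True,
    #       "use_wandb": True,
    #       "wandb_project": "OpenGait",
    #       "wandb_entity": "my-team",   # hypothetical entity name
    #       "wandb_mode": "offline",
    #   }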

    def append(self, info) -> None:
        # Normalize every value to a list of numpy-compatible entries before
        # accumulating, so np.mean() works at logging time.
        for k, v in info.items():
            v = [v] if not is_list(v) else v
            v = [ts2np(_) if is_tensor(_) else _ for _ in v]
            info[k] = v
        self.info_dict.append(info)

    def flush(self) -> None:
        self.info_dict.clear()
        if self.writer is not None:
            self.writer.flush()

    def write_to_tensorboard(self, summary) -> None:
        if self.writer is None:
            return
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Unexpected summary key [{}]; expected a prefix from {}".format(
                        k, self.writer_hparams
                    )
                )
                continue
            board_name = k.replace(module_name + "/", "")
            writer_module = getattr(self.writer, "add_" + module_name)
            v = v.detach() if is_tensor(v) else v
            v = (
                vutils.make_grid(v, normalize=True, scale_each=True)
                if "image" in module_name
                else v
            )
            if module_name == "scalar":
                try:
                    # Reduce tensors/arrays to a single value for add_scalar.
                    v = v.mean()
                except Exception:
                    pass
            writer_module(board_name, v, self.iteration)

    def write_to_wandb(self, summary) -> None:
        if self.wandb_run is None:
            return
        wandb_summary = {}
        for k, v in summary.items():
            module_name = k.split("/")[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    "Unexpected summary key [{}]; expected a prefix from {}".format(
                        k, self.writer_hparams
                    )
                )
                continue
            if is_tensor(v):
                v = v.detach().cpu()
            if module_name == "scalar":
                if is_tensor(v):
                    wandb_summary[k] = float(v.mean().item())
                elif isinstance(v, np.ndarray):
                    wandb_summary[k] = float(np.mean(v))
                else:
                    wandb_summary[k] = float(v)
                continue
            # Image entries: tile into a grid and convert to HWC for wandb.
            grid = vutils.make_grid(v, normalize=True, scale_each=True)
            wandb_summary[k] = wandb.Image(grid.permute(1, 2, 0).numpy())
        if wandb_summary:
            self.wandb_run.log(wandb_summary, step=self.iteration)
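
    # Summary keys are routed by their prefix; a minimal illustrative dict:
    #
    #   summary = {
    #       "scalar/train/loss": loss,   # -> writer.add_scalar / wandb scalar
    #       "image/sils": sils,          # -> writer.add_image / wandb.Image grid
    #   }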

    def log_training_info(self) -> None:
        now = time.time()
        string = "Iteration {:0>5}, Cost {:.2f}s".format(
            self.iteration, now - self.time
        )
        for k, v in self.info_dict.items():
            if "scalar" not in k:
                continue
            k = k.replace("scalar/", "").replace("/", "_")
            string += ", {0}={1:.4f}".format(k, np.mean(v))
        self.log_info(string)
        self.reset_time()

    def reset_time(self) -> None:
        self.time = time.time()

    def train_step(self, info, summary) -> None:
        self.iteration += 1
        self.append(info)
        if self.iteration % self.log_iter == 0:
            self.log_training_info()
            self.flush()
            self.write_to_tensorboard(summary)
            self.write_to_wandb(summary)

    def log_debug(self, *args, **kwargs) -> None:
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs) -> None:
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs) -> None:
        self.logger.warning(*args, **kwargs)

    def close(self) -> None:
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        if self.wandb_run is not None:
            self.wandb_run.finish()
            self.wandb_run = None


msg_mgr = MessageManager()
noop = NoOp()


def get_msg_mgr():
    # Only rank 0 keeps the real manager; other ranks get a no-op proxy so
    # logging calls are safe to make unconditionally. Guarding on
    # is_initialized() keeps single-process runs from crashing in get_rank().
    if torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
        return noop
    return msg_mgr
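

# Minimal usage sketch (paths, cfg values, and loop variables are illustrative):
#
#   msg_mgr = get_msg_mgr()
#   msg_mgr.init_manager("output/demo", log_to_file=True, log_iter=100)
#   for _ in range(total_iters):
#       info, summary = ...  # produced by the training step
#       msg_mgr.train_step(info, summary)
#   msg_mgr.close()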