# OpenGait/opengait/modeling/base_model.py
"""The base model definition.
This module defines the abstract meta model class and the base model class. The base model
implements the basic model functions, such as get_loader, build_network, and run_train.
The API of the base model consists of run_train and run_test, which are used in `opengait/main.py`.
Typical usage:
BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import json
import math
import os
import random
import re
from typing import Any
import numpy as np
import torch
import os.path as osp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from abc import ABCMeta
from abc import abstractmethod
from . import backbones
from .loss_aggregator import LossAggregator
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from opengait.utils import Odict, mkdir, ddp_all_gather, resolve_output_path
from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from opengait.utils import NoOp
from opengait.utils import get_msg_mgr
__all__ = ['BaseModel']
class MetaModel(metaclass=ABCMeta):
"""The necessary functions for the base model.
This class declares the functions a model must provide; BaseModel implements all of them.
"""
@abstractmethod
def get_loader(self, data_cfg):
"""Based on the given data_cfg, we get the data loader."""
raise NotImplementedError
@abstractmethod
def build_network(self, model_cfg):
"""Build your network here."""
raise NotImplementedError
@abstractmethod
def init_parameters(self):
"""Initialize the parameters of your network."""
raise NotImplementedError
@abstractmethod
def get_optimizer(self, optimizer_cfg):
"""Based on the given optimizer_cfg, we get the optimizer."""
raise NotImplementedError
@abstractmethod
def get_scheduler(self, scheduler_cfg):
"""Based on the given scheduler_cfg, we get the scheduler."""
raise NotImplementedError
@abstractmethod
def save_ckpt(self, iteration):
"""Save the checkpoint, including model parameter, optimizer and scheduler."""
raise NotImplementedError
@abstractmethod
def resume_ckpt(self, restore_hint):
"""Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
raise NotImplementedError
@abstractmethod
def inputs_pretreament(self, inputs):
"""Transform the input data based on transform setting."""
raise NotImplementedError
@abstractmethod
def train_step(self, loss_sum) -> bool:
"""Do one training step."""
raise NotImplementedError
@abstractmethod
def inference(self):
"""Do inference (calculate features.)."""
raise NotImplementedError
@abstractmethod
def run_train(model):
"""Run a whole train schedule."""
raise NotImplementedError
@abstractmethod
def run_test(model):
"""Run a whole test schedule."""
raise NotImplementedError
class BaseModel(MetaModel, nn.Module):
"""Base model.
This class inherits the MetaModel class and implements the basic model functions, like get_loader, build_network, etc.
Attributes:
msg_mgr: the message manager.
cfgs: the configs.
iteration: the current iteration of the model.
engine_cfg: the configs of the engine (train or test).
save_path: the path to save the checkpoints.
"""
def __init__(self, cfgs, training):
"""Initialize the base model.
Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, and the loss.
Args:
cfgs:
All of the configs.
training:
Whether the model is in training mode.
"""
super(BaseModel, self).__init__()
self.msg_mgr = get_msg_mgr()
self.cfgs = cfgs
self.iteration = 0
self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
if self.engine_cfg is None:
raise Exception("Initialize a model without -Engine-Cfgs-")
if training and self.engine_cfg['enable_float16']:
self.Scaler = GradScaler()
self.save_path = resolve_output_path(cfgs, self.engine_cfg)
self.build_network(cfgs['model_cfg'])
self.init_parameters()
self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
self.msg_mgr.log_info(cfgs['data_cfg'])
if training:
self.train_loader = self.get_loader(
cfgs['data_cfg'], train=True)
if not training or self.engine_cfg['with_test']:
self.test_loader = self.get_loader(
cfgs['data_cfg'], train=False)
self.evaluator_trfs = get_transform(
cfgs['evaluator_cfg']['transform'])
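# One process per GPU under DistributedDataParallel: bind this process to
# the CUDA device whose index matches its distributed rank.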
self.device = torch.distributed.get_rank()
torch.cuda.set_device(self.device)
self.to(device=torch.device(
"cuda", self.device))
if training:
self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
self.train(training)
restore_hint = self.engine_cfg['restore_hint']
if restore_hint != 0:
self.resume_ckpt(restore_hint)
elif training and self.engine_cfg.get('auto_resume_latest', False):
latest_ckpt = self._get_latest_resume_ckpt_path()
if latest_ckpt is not None:
self.msg_mgr.log_info(
"Auto-resuming from latest checkpoint %s", latest_ckpt
)
self.resume_ckpt(latest_ckpt)
def get_backbone(self, backbone_cfg):
"""Get the backbone of the model."""
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
if is_list(backbone_cfg):
Backbone = nn.ModuleList([self.get_backbone(cfg)
for cfg in backbone_cfg])
return Backbone
raise ValueError(
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
def build_network(self, model_cfg):
if 'backbone_cfg' in model_cfg.keys():
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
def init_parameters(self):
for m in self.modules():
if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
if m.affine:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0.0)
def get_loader(self, data_cfg, train=True):
sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
dataset = DataSet(data_cfg, train)
Sampler = get_attr_from([Samplers], sampler_cfg['type'])
valid_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
'sample_type', 'type'])
sampler = Sampler(dataset, **valid_args)
loader = tordata.DataLoader(
dataset=dataset,
batch_sampler=sampler,
collate_fn=CollateFn(dataset.label_set, sampler_cfg),
num_workers=data_cfg['num_workers'])
return loader
def get_optimizer(self, optimizer_cfg):
self.msg_mgr.log_info(optimizer_cfg)
optimizer = get_attr_from([optim], optimizer_cfg['solver'])
valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
optimizer = optimizer(
filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
return optimizer
def get_scheduler(self, scheduler_cfg):
self.msg_mgr.log_info(scheduler_cfg)
Scheduler = get_attr_from(
[optim.lr_scheduler], scheduler_cfg['scheduler'])
valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
scheduler = Scheduler(self.optimizer, **valid_arg)
return scheduler
def _build_checkpoint(self, iteration: int) -> dict[str, Any]:
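# Snapshot model/optimizer/scheduler state together with every RNG state
# (Python, NumPy, Torch CPU, and all CUDA devices) so that a resumed run
# can continue sampling and augmenting exactly where it left off.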
checkpoint: dict[str, Any] = {
'model': self.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
'iteration': iteration,
'random_state': random.getstate(),
'numpy_random_state': np.random.get_state(),
'torch_random_state': torch.get_rng_state(),
}
if torch.cuda.is_available():
checkpoint['cuda_random_state_all'] = torch.cuda.get_rng_state_all()
if self.engine_cfg.get('enable_float16', False) and hasattr(self, 'Scaler'):
checkpoint['scaler'] = self.Scaler.state_dict()
return checkpoint
def _checkpoint_dir(self) -> str:
return osp.join(self.save_path, "checkpoints")
def _resume_dir(self) -> str:
return osp.join(self._checkpoint_dir(), "resume")
def _save_checkpoint_file(
self,
checkpoint: dict[str, Any],
save_path: str,
) -> None:
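# Write to a temporary file and atomically rename it into place, so a
# crash mid-save can never leave a truncated checkpoint behind.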
mkdir(osp.dirname(save_path))
tmp_path = save_path + ".tmp"
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, save_path)
def _write_resume_meta(self, iteration: int, resume_path: str) -> None:
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
meta = {
"iteration": iteration,
"path": resume_path,
}
tmp_path = meta_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(meta, handle, indent=2, sort_keys=True)
os.replace(tmp_path, meta_path)
def _prune_resume_checkpoints(self, keep_count: int) -> None:
if keep_count <= 0:
return
resume_dir = self._resume_dir()
if not osp.isdir(resume_dir):
return
prefix = f"{self.engine_cfg['save_name']}-resume-"
resume_files = sorted(
file_name for file_name in os.listdir(resume_dir)
if file_name.startswith(prefix) and file_name.endswith(".pt")
)
stale_files = resume_files[:-keep_count]
for file_name in stale_files:
os.remove(osp.join(resume_dir, file_name))
def _get_latest_resume_ckpt_path(self) -> str | None:
latest_path = osp.join(self._checkpoint_dir(), "latest.pt")
if osp.isfile(latest_path):
return latest_path
meta_path = osp.join(self._checkpoint_dir(), "latest.json")
if osp.isfile(meta_path):
with open(meta_path, "r", encoding="utf-8") as handle:
latest_meta = json.load(handle)
candidate = latest_meta.get("path")
if isinstance(candidate, str) and osp.isfile(candidate):
return candidate
return None
def _best_ckpt_cfg(self) -> dict[str, Any] | None:
best_ckpt_cfg = self.engine_cfg.get('best_ckpt_cfg')
if not isinstance(best_ckpt_cfg, dict):
return None
keep_n = int(best_ckpt_cfg.get('keep_n', 0))
metric_names = best_ckpt_cfg.get('metric_names', [])
if keep_n <= 0 or not isinstance(metric_names, list) or not metric_names:
return None
return best_ckpt_cfg
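# A minimal sketch of the evaluator_cfg entry consumed above; the key names
# match the checks in _best_ckpt_cfg, but the metric name is only an
# illustrative placeholder and must match a key of the evaluator's result_dict:
#
# best_ckpt_cfg:
#   keep_n: 3
#   metric_names: ["scalar/test_accuracy/Rank-1"]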
def _best_ckpt_root(self) -> str:
return osp.join(self._checkpoint_dir(), "best")
def _best_metric_dir(self, metric_name: str) -> str:
metric_slug = re.sub(r"[^A-Za-z0-9_.-]+", "_", metric_name).strip("._")
return osp.join(self._best_ckpt_root(), metric_slug)
def _best_metric_index_path(self, metric_name: str) -> str:
return osp.join(self._best_metric_dir(metric_name), "index.json")
def _load_best_metric_index(self, metric_name: str) -> list[dict[str, Any]]:
index_path = self._best_metric_index_path(metric_name)
if not osp.isfile(index_path):
return []
with open(index_path, "r", encoding="utf-8") as handle:
raw_entries = json.load(handle)
if not isinstance(raw_entries, list):
return []
entries: list[dict[str, Any]] = []
for entry in raw_entries:
if not isinstance(entry, dict):
continue
path = entry.get("path")
if isinstance(path, str) and osp.isfile(path):
entries.append(entry)
return entries
def _write_best_metric_index(
self,
metric_name: str,
entries: list[dict[str, Any]],
) -> None:
index_path = self._best_metric_index_path(metric_name)
mkdir(osp.dirname(index_path))
tmp_path = index_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as handle:
json.dump(entries, handle, indent=2, sort_keys=True)
os.replace(tmp_path, index_path)
def _summary_scalar(self, value: Any) -> float | None:
if isinstance(value, torch.Tensor):
return float(value.detach().float().mean().item())
if isinstance(value, np.ndarray):
return float(np.mean(value))
if isinstance(value, (float, int, np.floating, np.integer)):
return float(value)
return None
def _save_best_ckpts(
self,
iteration: int,
result_dict: dict[str, Any],
) -> None:
if torch.distributed.get_rank() != 0:
return
best_ckpt_cfg = self._best_ckpt_cfg()
if best_ckpt_cfg is None:
return
keep_n = int(best_ckpt_cfg['keep_n'])
metric_names = [metric for metric in best_ckpt_cfg['metric_names'] if metric in result_dict]
if not metric_names:
return
checkpoint: dict[str, Any] | None = None
save_name = self.engine_cfg['save_name']
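# For each tracked metric: merge this iteration's score into the persisted
# leaderboard, keep the top keep_n entries (higher score wins, later
# iteration breaks ties), and only serialize a checkpoint if the current
# iteration made the cut.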
for metric_name in metric_names:
score = self._summary_scalar(result_dict.get(metric_name))
if score is None or not math.isfinite(score):
continue
entries = [
entry for entry in self._load_best_metric_index(metric_name)
if int(entry.get("iteration", -1)) != iteration
]
ranked_entries = sorted(
entries + [{"iteration": iteration, "score": score, "path": ""}],
key=lambda entry: (float(entry["score"]), int(entry["iteration"])),
reverse=True,
)
kept_entries = ranked_entries[:keep_n]
if not any(int(entry["iteration"]) == iteration for entry in kept_entries):
continue
metric_dir = self._best_metric_dir(metric_name)
mkdir(metric_dir)
metric_slug = osp.basename(metric_dir)
best_path = osp.join(
metric_dir,
f"{save_name}-iter-{iteration:0>5}-score-{score:.4f}-{metric_slug}.pt",
)
if checkpoint is None:
checkpoint = self._build_checkpoint(iteration)
self._save_checkpoint_file(checkpoint, best_path)
refreshed_entries = []
for entry in kept_entries:
if int(entry["iteration"]) == iteration:
refreshed_entries.append(
{
"iteration": iteration,
"score": score,
"path": best_path,
}
)
else:
refreshed_entries.append(entry)
keep_paths = {entry["path"] for entry in refreshed_entries if isinstance(entry.get("path"), str)}
for stale_entry in entries:
stale_path = stale_entry.get("path")
if isinstance(stale_path, str) and stale_path not in keep_paths and osp.isfile(stale_path):
os.remove(stale_path)
self._write_best_metric_index(metric_name, refreshed_entries)
def save_ckpt(self, iteration):
if torch.distributed.get_rank() == 0:
save_name = self.engine_cfg['save_name']
checkpoint = self._build_checkpoint(iteration)
ckpt_path = osp.join(
self._checkpoint_dir(),
'{}-{:0>5}.pt'.format(save_name, iteration),
)
self._save_checkpoint_file(checkpoint, ckpt_path)
def save_resume_ckpt(self, iteration: int) -> None:
if torch.distributed.get_rank() != 0:
return
checkpoint = self._build_checkpoint(iteration)
save_name = self.engine_cfg['save_name']
resume_path = osp.join(
self._resume_dir(),
f"{save_name}-resume-{iteration:0>5}.pt",
)
latest_path = osp.join(self._checkpoint_dir(), "latest.pt")
self._save_checkpoint_file(checkpoint, resume_path)
self._save_checkpoint_file(checkpoint, latest_path)
self._write_resume_meta(iteration, resume_path)
self._prune_resume_checkpoints(
int(self.engine_cfg.get('resume_keep', 3))
)
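# Trainer-side knobs read by the resume logic (the values shown are only
# illustrative; resume_keep defaults to 3 in the call above):
#
# auto_resume_latest: true   # resume from latest.pt/latest.json at startup
# resume_every_iter: 1000    # cadence checked in run_train
# resume_keep: 3             # rolling resume checkpoints to retain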
def _load_ckpt(self, save_name):
load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']
checkpoint = torch.load(
save_name,
map_location=torch.device("cuda", self.device),
weights_only=False,
)
model_state_dict = checkpoint['model']
current_state_keys = set(self.state_dict().keys())
stale_loss_keys = [
key for key in model_state_dict.keys()
if key.startswith("loss_aggregator.") and key not in current_state_keys
]
for key in stale_loss_keys:
model_state_dict.pop(key)
if stale_loss_keys:
self.msg_mgr.log_warning(
"Ignoring stale loss state from %s: %s"
% (save_name, stale_loss_keys)
)
if not load_ckpt_strict:
self.msg_mgr.log_info("-------- Restored Params List --------")
self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
set(self.state_dict().keys()))))
self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
if self.training:
if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
else:
self.msg_mgr.log_warning(
"Restore NO Optimizer from %s !!!" % save_name)
if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
self.scheduler.load_state_dict(
checkpoint['scheduler'])
else:
self.msg_mgr.log_warning(
"Restore NO Scheduler from %s !!!" % save_name)
if (
self.engine_cfg.get('enable_float16', False)
and hasattr(self, 'Scaler')
and 'scaler' in checkpoint
):
self.Scaler.load_state_dict(checkpoint['scaler'])
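# Restore every captured RNG state so sampling and augmentation continue
# exactly where the interrupted run stopped; states deserialized as plain
# arrays are defensively converted back to CPU uint8 tensors below.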
if 'random_state' in checkpoint:
random.setstate(checkpoint['random_state'])
if 'numpy_random_state' in checkpoint:
np.random.set_state(checkpoint['numpy_random_state'])
if 'torch_random_state' in checkpoint:
torch_random_state = checkpoint['torch_random_state']
if not isinstance(torch_random_state, torch.Tensor):
torch_random_state = torch.as_tensor(
torch_random_state,
dtype=torch.uint8,
)
torch.set_rng_state(torch_random_state.cpu())
if 'cuda_random_state_all' in checkpoint and torch.cuda.is_available():
cuda_random_state_all = checkpoint['cuda_random_state_all']
normalized_cuda_states = []
for state in cuda_random_state_all:
if not isinstance(state, torch.Tensor):
state = torch.as_tensor(state, dtype=torch.uint8)
normalized_cuda_states.append(state.cpu())
torch.cuda.set_rng_state_all(normalized_cuda_states)
self.iteration = int(checkpoint.get('iteration', self.iteration))
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
def resume_ckpt(self, restore_hint):
if isinstance(restore_hint, int):
save_name = self.engine_cfg['save_name']
save_name = osp.join(
self._checkpoint_dir(), '{}-{:0>5}.pt'.format(save_name, restore_hint))
elif isinstance(restore_hint, str):
if restore_hint == 'latest':
save_name = self._get_latest_resume_ckpt_path()
if save_name is None:
raise FileNotFoundError(
f"No latest checkpoint found under {self._checkpoint_dir()}"
)
else:
save_name = restore_hint
else:
raise ValueError(
"Error type for -Restore_Hint-, supported: int or string.")
self._load_ckpt(save_name)
def fix_BN(self):
for module in self.modules():
classname = module.__class__.__name__
if 'BatchNorm' in classname:
module.eval()
def inputs_pretreament(self, inputs):
"""Conduct transforms on input data.
Args:
inputs: the input data.
Returns:
tuple: training data including inputs, labels, and some meta data.
"""
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
if len(seqs_batch) != len(seq_trfs):
raise ValueError(
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
requires_grad = bool(self.training)
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
for trf, seq in zip(seq_trfs, seqs_batch)]
typs = typs_batch
vies = vies_batch
labs = list2var(labs_batch).long()
if seqL_batch is not None:
seqL_batch = np2var(seqL_batch).int()
seqL = seqL_batch
if seqL is not None:
seqL_sum = int(seqL.sum().data.cpu().numpy())
ipts = [_[:, :seqL_sum] for _ in seqs]
else:
ipts = seqs
del seqs
return ipts, labs, typs, vies, seqL
def train_step(self, loss_sum) -> bool:
"""Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
Args:
loss_sum:The loss of the current batch.
Returns:
bool: True if the training is finished, False otherwise.
"""
self.optimizer.zero_grad()
if loss_sum <= 1e-9:
self.msg_mgr.log_warning(
"The loss sum is less than 1e-9, but training will continue!")
if self.engine_cfg['enable_float16']:
self.Scaler.scale(loss_sum).backward()
self.Scaler.step(self.optimizer)
scale = self.Scaler.get_scale()
self.Scaler.update()
# Warning caused by optimizer skip when NaN
# https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
if scale != self.Scaler.get_scale():
self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
scale, self.Scaler.get_scale()))
return False
else:
loss_sum.backward()
self.optimizer.step()
self.iteration += 1
self.scheduler.step()
return True
def inference(self, rank):
"""Inference all the test data.
Args:
rank: the rank of the current process.
Returns:
Odict: contains the inference results.
"""
total_size = len(self.test_loader)
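# Note: len(test_loader) counts samples here, not batches (the inference
# batch sampler reports the dataset size), which is why the gathered
# results are trimmed back to total_size below.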
if rank == 0:
pbar = tqdm(total=total_size, desc='Transforming')
else:
pbar = NoOp()
batch_size = self.test_loader.batch_sampler.batch_size
rest_size = total_size
info_dict = Odict()
for inputs in self.test_loader:
ipts = self.inputs_pretreament(inputs)
with autocast(enabled=self.engine_cfg['enable_float16']):
retval = self.forward(ipts)
inference_feat = retval['inference_feat']
for k, v in inference_feat.items():
inference_feat[k] = ddp_all_gather(v, requires_grad=False)
del retval
for k, v in inference_feat.items():
inference_feat[k] = ts2np(v)
info_dict.append(inference_feat)
rest_size -= batch_size
if rest_size >= 0:
update_size = batch_size
else:
update_size = total_size % batch_size
pbar.update(update_size)
pbar.close()
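# ddp_all_gather may pad the per-rank results, so trim the concatenation
# back to the true sample count.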
for k, v in info_dict.items():
v = np.concatenate(v)[:total_size]
info_dict[k] = v
return info_dict
@staticmethod
def run_train(model):
"""Accept the instance object(model) here, and then run the train loop."""
for inputs in model.train_loader:
ipts = model.inputs_pretreament(inputs)
with autocast(enabled=model.engine_cfg['enable_float16']):
retval = model(ipts)
training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
del retval
loss_sum, loss_info = model.loss_aggregator(training_feat)
ok = model.train_step(loss_sum)
if not ok:
continue
visual_summary.update(loss_info)
visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
model.msg_mgr.train_step(loss_info, visual_summary)
resume_every_iter = int(model.engine_cfg.get('resume_every_iter', 0))
if resume_every_iter > 0 and model.iteration % resume_every_iter == 0:
model.save_resume_ckpt(model.iteration)
save_iter = int(model.engine_cfg['save_iter'])
eval_iter = int(model.engine_cfg.get('eval_iter', 0))
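# eval_iter decouples evaluation cadence from checkpointing; when it is
# unset or 0, evaluation piggybacks on the save_iter schedule.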
should_save = save_iter > 0 and model.iteration % save_iter == 0
should_eval = False
if model.engine_cfg['with_test']:
if eval_iter > 0:
should_eval = model.iteration % eval_iter == 0
else:
should_eval = should_save
if should_save:
# save the checkpoint
model.save_ckpt(model.iteration)
if should_eval:
model.msg_mgr.log_info("Running test...")
model.eval()
result_dict = BaseModel.run_test(model)
model.train()
if model.cfgs['trainer_cfg']['fix_BN']:
model.fix_BN()
if result_dict:
model.msg_mgr.write_to_tensorboard(result_dict)
model.msg_mgr.write_to_wandb(result_dict)
model._save_best_ckpts(model.iteration, result_dict)
model.msg_mgr.reset_time()
if model.iteration >= model.engine_cfg['total_iter']:
break
@staticmethod
def run_test(model):
"""Accept the instance object(model) here, and then run the test loop."""
evaluator_cfg = model.cfgs['evaluator_cfg']
if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
rank = torch.distributed.get_rank()
with torch.no_grad():
info_dict = model.inference(rank)
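# Only rank 0 aggregates the features and computes metrics; other ranks
# fall through and implicitly return None.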
if rank == 0:
loader = model.test_loader
label_list = loader.dataset.label_list
types_list = loader.dataset.types_list
views_list = loader.dataset.views_list
info_dict.update({
'labels': label_list, 'types': types_list, 'views': views_list})
eval_func = evaluator_cfg.get('eval_func', 'identification')
eval_func = getattr(eval_functions, eval_func)
valid_args = get_valid_args(
eval_func, evaluator_cfg, ['metric'])
try:
dataset_name = model.cfgs['data_cfg']['test_dataset_name']
except KeyError:
dataset_name = model.cfgs['data_cfg']['dataset_name']
return eval_func(info_dict, dataset_name, **valid_args)