from __future__ import annotations

import functools
from collections.abc import Callable
from ctypes import ArgumentError
from typing import Any

import torch
import torch.nn as nn

from opengait.utils import Odict
from opengait.utils import ddp_all_gather


def gather_and_scale_wrapper(
    func: Callable[..., tuple[torch.Tensor | float, Odict]],
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
    """Internal wrapper: gather the inputs from multiple cards onto one card,
    and scale the loss by the number of cards.

    Only *keyword* arguments are passed through ``ddp_all_gather``; positional
    arguments (typically ``self``) are forwarded unchanged. The returned loss
    is multiplied by ``torch.distributed.get_world_size()``.

    Raises:
        ArgumentError: if gathering the inputs, calling ``func``, or scaling
            the loss fails with an ``Exception``. The original exception is
            chained as ``__cause__`` so its traceback is preserved.
            (``ctypes.ArgumentError`` is kept as the raised type for backward
            compatibility with existing callers.)
    """

    @functools.wraps(func)
    def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
        try:
            # Build a new mapping instead of mutating ``kwds`` while iterating it.
            gathered = {k: ddp_all_gather(v) for k, v in kwds.items()}
            loss, loss_info = func(*args, **gathered)
            # Out-of-place multiply: an in-place ``*=`` on a tensor that is part
            # of the autograd graph can fail (leaf tensors) or corrupt values
            # saved for the backward pass.
            loss = loss * torch.distributed.get_world_size()
        except Exception as err:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # propagate untouched; chain the cause so the real error is visible.
            name = getattr(func, "__qualname__", repr(func))
            raise ArgumentError(
                f"gather_and_scale_wrapper({name}) failed: {err!r}"
            ) from err
        return loss, loss_info

    return inner


class BaseLoss(nn.Module):
    """
    Base class for all losses.

    Your loss should also subclass this class.
    """

    def __init__(self, loss_term_weight: float = 1.0) -> None:
        """
        Initialize the base class.

        Args:
            loss_term_weight: the weight of the loss term.
        """
        super().__init__()
        self.loss_term_weight = loss_term_weight
        # Per-instance info container returned alongside the loss by forward().
        self.info = Odict()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """
        The default forward function.

        This function should be overridden by the subclass.

        Args:
            logits: the logits of the model.
            labels: the labels of the data (unused by this default).

        Returns:
            tuple of a zero scalar loss tensor (on ``logits.device``) and
            ``self.info``.
        """
        return torch.tensor(0.0, device=logits.device), self.info