# OpenGait/opengait/modeling/losses/base.py
from __future__ import annotations
from collections.abc import Callable
from ctypes import ArgumentError
from typing import Any
import torch
import torch.nn as nn
from opengait.utils import Odict
import functools
from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(
func: Callable[..., tuple[torch.Tensor | float, Odict]],
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
"""Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
"""
@functools.wraps(func)
def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
try:
for k, v in kwds.items():
kwds[k] = ddp_all_gather(v)
loss, loss_info = func(*args, **kwds)
loss *= torch.distributed.get_world_size()
return loss, loss_info
except:
raise ArgumentError
return inner
class BaseLoss(nn.Module):
    """Common base for all loss modules.

    Concrete losses should subclass this and override :meth:`forward`.
    """

    def __init__(self, loss_term_weight: float = 1.0) -> None:
        """Set up state shared by all losses.

        Args:
            loss_term_weight: multiplicative weight applied to this
                loss term when combined with others.
        """
        super().__init__()
        self.loss_term_weight = loss_term_weight
        # Per-forward diagnostic values exposed to the training loop.
        self.info = Odict()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """Placeholder forward pass; subclasses must override.

        Args:
            logits: model outputs.
            labels: ground-truth labels.

        Returns:
            A ``(loss, info)`` pair; here the loss is a zero scalar
            placed on the same device as ``logits``.
        """
        zero = torch.tensor(0.0, device=logits.device)
        return zero, self.info