from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn.functional as F

from opengait.utils.common import Odict

from .base import BaseLoss


class CrossEntropyLoss(BaseLoss):
    scale: float
    label_smooth: bool
    eps: float
    log_accuracy: bool

    def __init__(
        self,
        scale: float = 2**4,
        label_smooth: bool = True,
        eps: float = 0.1,
        loss_term_weight: float = 1.0,
        log_accuracy: bool = False,
        class_weight: Sequence[float] | None = None,
    ) -> None:
        super().__init__(loss_term_weight)
        self.scale = scale
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_accuracy = log_accuracy
        # Register the optional per-class weight as a buffer so it follows the
        # module across devices without becoming a learnable parameter.
        # register_buffer accepts None, which keeps the unweighted case simple.
        weight_tensor = (
            None
            if class_weight is None
            else torch.as_tensor(class_weight, dtype=torch.float32)
        )
        self.register_buffer("_class_weight", weight_tensor)

    @property
    def class_weight(self) -> torch.Tensor | None:
        # The buffer may have been registered as None; narrow the type before
        # returning so callers get either a Tensor or None, never a surprise.
        buffer = self._buffers.get("_class_weight")
        return buffer if isinstance(buffer, torch.Tensor) else None

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """
        logits: [n, c, p] — n samples, c classes, p horizontal parts
        labels: [n]
        """
        _n, _c, p = logits.size()
        logits = logits.float()
        labels = labels.unsqueeze(1)  # [n, 1]; repeated to [n, p] below
        # F.cross_entropy handles the extra part dimension natively: input
        # [n, c, p] with target [n, p] yields a per-part loss averaged over
        # all parts. label_smoothing=0.0 is the kwarg's default, so a single
        # call covers both the smoothed and unsmoothed cases.
        loss = F.cross_entropy(
            logits * self.scale,
            labels.repeat(1, p),
            weight=self.class_weight,
            label_smoothing=self.eps if self.label_smooth else 0.0,
        )
        self.info.update({"loss": loss.detach().clone()})
        if self.log_accuracy:
            pred = logits.argmax(dim=1)  # [n, p]
            # labels is [n, 1], so the comparison broadcasts over the parts.
            accu = (pred == labels).float().mean()
            self.info.update({"accuracy": accu})
        return loss, self.info
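
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the training pipeline. It assumes
# BaseLoss subclasses nn.Module and that BaseLoss.__init__(loss_term_weight)
# initializes `self.info` as an Odict; the tensor shapes below are arbitrary
# illustrative values. Because this module uses package-relative imports, run
# it with `python -m <package>.<module>` from the repository root.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n, c, p = 4, 10, 16  # batch size, class count, horizontal parts
    criterion = CrossEntropyLoss(log_accuracy=True)
    dummy_logits = torch.randn(n, c, p)
    dummy_labels = torch.randint(0, c, (n,))
    loss, info = criterion(dummy_logits, dummy_labels)
    print(f"loss={loss.item():.4f} accuracy={info['accuracy'].item():.4f}")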