# OpenGait/opengait/modeling/losses/ce.py

from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn.functional as F
from opengait.utils.common import Odict
from .base import BaseLoss
class CrossEntropyLoss(BaseLoss):
    scale: float
    label_smooth: bool
    eps: float
    log_accuracy: bool

    def __init__(
        self,
        scale: float = 2**4,
        label_smooth: bool = True,
        eps: float = 0.1,
        loss_term_weight: float = 1.0,
        log_accuracy: bool = False,
        class_weight: Sequence[float] | None = None,
    ) -> None:
        super().__init__(loss_term_weight)
        self.scale = scale
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_accuracy = log_accuracy
        # Register the optional per-class weight as a buffer so it moves with
        # the module across .to(device)/.cuda() calls and is serialized in
        # state_dict, without being treated as a learnable parameter.
        weight_tensor = (
            None
            if class_weight is None
            else torch.as_tensor(class_weight, dtype=torch.float32)
        )
        self.register_buffer("_class_weight", weight_tensor)

    @property
    def class_weight(self) -> torch.Tensor | None:
        buffer = self._buffers.get("_class_weight")
        return buffer if isinstance(buffer, torch.Tensor) else None

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """
        logits: [n, c, p]
        labels: [n]
        """
        _n, _c, p = logits.size()
        logits = logits.float()
        labels = labels.unsqueeze(1)  # [n] -> [n, 1]
        class_weight = self.class_weight
        if self.label_smooth:
            loss = F.cross_entropy(
                logits * self.scale,
                labels.repeat(1, p),  # replicate each label across all p parts: [n, p]
                weight=class_weight,
                label_smoothing=self.eps,
            )
        else:
            loss = F.cross_entropy(
                logits * self.scale,
                labels.repeat(1, p),
                weight=class_weight,
            )
        self.info.update({'loss': loss.detach().clone()})
        if self.log_accuracy:
            pred = logits.argmax(dim=1)  # [n, p]
            # labels is [n, 1], so the comparison broadcasts against pred [n, p]
            accu = (pred == labels).float().mean()
            self.info.update({'accuracy': accu})
        return loss, self.info
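

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The tensor shapes
# below are illustrative assumptions, and this assumes BaseLoss is an
# nn.Module whose __init__ populates `self.info` as an Odict.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n, c, p = 4, 10, 16  # batch size, number of classes, number of parts
    logits = torch.randn(n, c, p)
    labels = torch.randint(0, c, (n,))

    ce = CrossEntropyLoss(label_smooth=True, eps=0.1, log_accuracy=True)
    loss, info = ce(logits, labels)  # nn.Module __call__ dispatches to forward
    print(f"loss={loss.item():.4f}, accuracy={info['accuracy'].item():.4f}")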