73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from opengait.utils.common import Odict
|
|
|
|
from .base import BaseLoss
|
|
|
|
|
|
class CrossEntropyLoss(BaseLoss):
    """Part-wise cross-entropy over scaled logits, with optional label smoothing.

    Every part/strip ``p`` of the logits is supervised with the same sample
    label, and ``F.cross_entropy`` reduces over all ``n * p`` terms with its
    default ``'mean'`` reduction (a weighted mean when ``class_weight`` is set).

    Args:
        scale: Multiplier applied to the logits before the softmax.
        label_smooth: If ``True``, apply label smoothing of strength ``eps``.
        eps: Label-smoothing strength; ignored when ``label_smooth`` is ``False``.
        loss_term_weight: Forwarded unchanged to ``BaseLoss.__init__``.
        log_accuracy: If ``True``, also record the mean argmax accuracy in
            ``self.info`` under ``'accuracy'``.
        class_weight: Optional per-class rescaling weights, passed to
            ``F.cross_entropy`` as ``weight``. Stored as a float32 buffer
            named ``_class_weight`` (``None`` when not given).
    """

    scale: float
    label_smooth: bool
    eps: float
    log_accuracy: bool

    def __init__(
        self,
        scale: float = 2**4,
        label_smooth: bool = True,
        eps: float = 0.1,
        loss_term_weight: float = 1.0,
        log_accuracy: bool = False,
        class_weight: Sequence[float] | None = None,
    ) -> None:
        super().__init__(loss_term_weight)
        self.scale = scale
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_accuracy = log_accuracy
        weight_tensor = (
            None
            if class_weight is None
            else torch.as_tensor(class_weight, dtype=torch.float32)
        )
        # Registered as a buffer so it follows the module across .to()/.cuda().
        # A None buffer is permitted and simply means "no class weighting".
        self.register_buffer("_class_weight", weight_tensor)

    @property
    def class_weight(self) -> torch.Tensor | None:
        """The per-class weight buffer, or ``None`` if none was configured."""
        buffer = self._buffers.get("_class_weight")
        return buffer if isinstance(buffer, torch.Tensor) else None

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """Compute the cross-entropy loss.

        Args:
            logits: Tensor of shape ``[n, c, p]`` (batch, classes, parts).
            labels: Integer class indices of shape ``[n]``.

        Returns:
            ``(loss, info)``. ``info`` (the inherited ``self.info``) is updated
            with a detached copy of the loss under ``'loss'`` and, when
            ``log_accuracy`` is set, the mean accuracy under ``'accuracy'``.
        """
        _n, _c, p = logits.size()
        logits = logits.float()
        labels = labels.unsqueeze(1)  # [n, 1]; broadcasts against [n, p] below

        weight = self.class_weight
        if weight is not None:
            # The buffer follows module-level .half()/.double()/.to() calls, but
            # logits were just promoted to float32; cross_entropy requires the
            # weight to match the input's dtype and device.
            weight = weight.to(device=logits.device, dtype=logits.dtype)

        # label_smoothing=0.0 is F.cross_entropy's default, i.e. plain CE.
        smoothing = self.eps if self.label_smooth else 0.0
        loss = F.cross_entropy(
            logits * self.scale,
            labels.repeat(1, p),  # same label for every part -> [n, p]
            weight=weight,
            label_smoothing=smoothing,
        )
        self.info.update({'loss': loss.detach().clone()})

        if self.log_accuracy:
            pred = logits.argmax(dim=1)  # [n, p]
            accu = (pred == labels).float().mean()
            self.info.update({'accuracy': accu})
        return loss, self.info