Add weighted CE proxy and fix loss imports

This commit is contained in:
2026-03-10 00:40:41 +08:00
parent 24381551f4
commit 5a02036318
5 changed files with 184 additions and 14 deletions
+44 -5
View File
@@ -1,29 +1,68 @@
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn.functional as F
from opengait.utils.common import Odict
from .base import BaseLoss
class CrossEntropyLoss(BaseLoss):
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
scale: float
label_smooth: bool
eps: float
log_accuracy: bool
class_weight: torch.Tensor | None = None
def __init__(
self,
scale: float = 2**4,
label_smooth: bool = True,
eps: float = 0.1,
loss_term_weight: float = 1.0,
log_accuracy: bool = False,
class_weight: Sequence[float] | None = None,
) -> None:
super(CrossEntropyLoss, self).__init__(loss_term_weight)
self.scale = scale
self.label_smooth = label_smooth
self.eps = eps
self.log_accuracy = log_accuracy
weight_tensor = (
None
if class_weight is None
else torch.as_tensor(class_weight, dtype=torch.float32)
)
if class_weight is None:
self.register_buffer("class_weight", weight_tensor)
else:
self.register_buffer("class_weight", weight_tensor)
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
logits: [n, c, p]
labels: [n]
"""
n, c, p = logits.size()
_n, _c, p = logits.size()
logits = logits.float()
labels = labels.unsqueeze(1)
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
if self.label_smooth:
loss = F.cross_entropy(
logits*self.scale, labels.repeat(1, p), label_smoothing=self.eps)
logits * self.scale,
labels.repeat(1, p),
weight=class_weight,
label_smoothing=self.eps,
)
else:
loss = F.cross_entropy(logits*self.scale, labels.repeat(1, p))
loss = F.cross_entropy(
logits * self.scale,
labels.repeat(1, p),
weight=class_weight,
)
self.info.update({'loss': loss.detach().clone()})
if self.log_accuracy:
pred = logits.argmax(dim=1) # [n, p]