Add weighted CE proxy and fix loss imports

This commit is contained in:
2026-03-10 00:40:41 +08:00
parent 24381551f4
commit 5a02036318
5 changed files with 184 additions and 14 deletions
+14 -6
View File
@@ -1,17 +1,25 @@
from __future__ import annotations
from collections.abc import Callable
from ctypes import ArgumentError
import torch.nn as nn
from typing import Any
import torch
import torch.nn as nn
from opengait.utils import Odict
import functools
from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(func):
def gather_and_scale_wrapper(
func: Callable[..., tuple[torch.Tensor | float, Odict]],
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
"""Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
"""
@functools.wraps(func)
def inner(*args, **kwds):
def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
try:
for k, v in kwds.items():
@@ -32,7 +40,7 @@ class BaseLoss(nn.Module):
Your loss should also subclass this class.
"""
def __init__(self, loss_term_weight=1.0):
def __init__(self, loss_term_weight: float = 1.0) -> None:
"""
Initialize the base class.
@@ -43,7 +51,7 @@ class BaseLoss(nn.Module):
self.loss_term_weight = loss_term_weight
self.info = Odict()
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
The default forward function.
@@ -56,4 +64,4 @@ class BaseLoss(nn.Module):
Returns:
tuple of loss and info.
"""
return .0, self.info
return torch.tensor(0.0, device=logits.device), self.info
+10 -3
View File
@@ -1,14 +1,21 @@
from __future__ import annotations
import torch
from opengait.evaluation.metric import mean_iou
from opengait.utils.common import Odict
from .base import BaseLoss
from evaluation import mean_iou
class BinaryCrossEntropyLoss(BaseLoss):
def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
eps: float
def __init__(self, loss_term_weight: float = 1.0, eps: float = 1.0e-9) -> None:
    """Set up the binary cross-entropy loss term.

    Args:
        loss_term_weight: Weight of this term in the total loss.
        eps: Small constant stored as ``self.eps`` for numerical
            stability (consumed in ``forward``; usage not shown here).
    """
    super().__init__(loss_term_weight)
    self.eps = eps
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
logits: [n, 1, h, w]
labels: [n, 1, h, w]
+44 -5
View File
@@ -1,29 +1,68 @@
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn.functional as F
from opengait.utils.common import Odict
from .base import BaseLoss
class CrossEntropyLoss(BaseLoss):
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
# Class-level attribute annotations (documentation for type checkers).
# NOTE: `class_weight` must stay annotation-only (no `= None` default):
# assigning a class attribute named `class_weight` makes
# `nn.Module.register_buffer("class_weight", ...)` in __init__ raise
# KeyError("attribute 'class_weight' already exists"), because PyTorch
# refuses to register a buffer whose name is already an attribute.
scale: float
label_smooth: bool
eps: float
log_accuracy: bool
class_weight: torch.Tensor | None
def __init__(
    self,
    scale: float = 2**4,
    label_smooth: bool = True,
    eps: float = 0.1,
    loss_term_weight: float = 1.0,
    log_accuracy: bool = False,
    class_weight: Sequence[float] | None = None,
) -> None:
    """Cross-entropy loss over part-wise logits.

    Args:
        scale: Multiplier applied to the logits before cross-entropy.
        label_smooth: Whether to apply label smoothing in ``forward``.
        eps: Label-smoothing factor (used only when ``label_smooth``).
        loss_term_weight: Weight of this term in the total loss.
        log_accuracy: If True, ``forward`` also logs prediction accuracy.
        class_weight: Optional per-class rescaling weights; converted to a
            float32 tensor and registered as a buffer so it moves with the
            module across devices/dtypes.
    """
    super(CrossEntropyLoss, self).__init__(loss_term_weight)
    self.scale = scale
    self.label_smooth = label_smooth
    self.eps = eps
    self.log_accuracy = log_accuracy
    weight_tensor = (
        None
        if class_weight is None
        else torch.as_tensor(class_weight, dtype=torch.float32)
    )
    # Register unconditionally (register_buffer accepts None). The original
    # if/else called register_buffer with identical arguments in both
    # branches, so the conditional was dead code.
    # NOTE(review): a class-level `class_weight = None` attribute would make
    # this register_buffer call raise KeyError in PyTorch — the class
    # annotation must not carry a default value; verify against the class body.
    self.register_buffer("class_weight", weight_tensor)
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
logits: [n, c, p]
labels: [n]
"""
n, c, p = logits.size()
_n, _c, p = logits.size()
logits = logits.float()
labels = labels.unsqueeze(1)
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
if self.label_smooth:
loss = F.cross_entropy(
logits*self.scale, labels.repeat(1, p), label_smoothing=self.eps)
logits * self.scale,
labels.repeat(1, p),
weight=class_weight,
label_smoothing=self.eps,
)
else:
loss = F.cross_entropy(logits*self.scale, labels.repeat(1, p))
loss = F.cross_entropy(
logits * self.scale,
labels.repeat(1, p),
weight=class_weight,
)
self.info.update({'loss': loss.detach().clone()})
if self.log_accuracy:
pred = logits.argmax(dim=1) # [n, p]