Add weighted CE proxy and fix loss imports
This commit is contained in:
@@ -1,17 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from ctypes import ArgumentError
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from opengait.utils import Odict
|
||||
import functools
|
||||
from opengait.utils import ddp_all_gather
|
||||
|
||||
|
||||
def gather_and_scale_wrapper(func):
|
||||
def gather_and_scale_wrapper(
|
||||
func: Callable[..., tuple[torch.Tensor | float, Odict]],
|
||||
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
|
||||
"""Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
|
||||
try:
|
||||
|
||||
for k, v in kwds.items():
|
||||
@@ -32,7 +40,7 @@ class BaseLoss(nn.Module):
|
||||
Your loss should also subclass this class.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_term_weight=1.0):
|
||||
def __init__(self, loss_term_weight: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize the base class.
|
||||
|
||||
@@ -43,7 +51,7 @@ class BaseLoss(nn.Module):
|
||||
self.loss_term_weight = loss_term_weight
|
||||
self.info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
The default forward function.
|
||||
|
||||
@@ -56,4 +64,4 @@ class BaseLoss(nn.Module):
|
||||
Returns:
|
||||
tuple of loss and info.
|
||||
"""
|
||||
return .0, self.info
|
||||
return torch.tensor(0.0, device=logits.device), self.info
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from opengait.evaluation.metric import mean_iou
|
||||
from opengait.utils.common import Odict
|
||||
|
||||
from .base import BaseLoss
|
||||
from evaluation import mean_iou
|
||||
|
||||
|
||||
class BinaryCrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
|
||||
eps: float
|
||||
|
||||
def __init__(self, loss_term_weight: float = 1.0, eps: float = 1.0e-9) -> None:
|
||||
super(BinaryCrossEntropyLoss, self).__init__(loss_term_weight)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, logits, labels):
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
logits: [n, 1, h, w]
|
||||
labels: [n, 1, h, w]
|
||||
|
||||
@@ -1,29 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from opengait.utils.common import Odict
|
||||
|
||||
from .base import BaseLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
|
||||
scale: float
|
||||
label_smooth: bool
|
||||
eps: float
|
||||
log_accuracy: bool
|
||||
class_weight: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: float = 2**4,
|
||||
label_smooth: bool = True,
|
||||
eps: float = 0.1,
|
||||
loss_term_weight: float = 1.0,
|
||||
log_accuracy: bool = False,
|
||||
class_weight: Sequence[float] | None = None,
|
||||
) -> None:
|
||||
super(CrossEntropyLoss, self).__init__(loss_term_weight)
|
||||
self.scale = scale
|
||||
self.label_smooth = label_smooth
|
||||
self.eps = eps
|
||||
self.log_accuracy = log_accuracy
|
||||
weight_tensor = (
|
||||
None
|
||||
if class_weight is None
|
||||
else torch.as_tensor(class_weight, dtype=torch.float32)
|
||||
)
|
||||
if class_weight is None:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
else:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
|
||||
def forward(self, logits, labels):
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
logits: [n, c, p]
|
||||
labels: [n]
|
||||
"""
|
||||
n, c, p = logits.size()
|
||||
_n, _c, p = logits.size()
|
||||
logits = logits.float()
|
||||
labels = labels.unsqueeze(1)
|
||||
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
|
||||
if self.label_smooth:
|
||||
loss = F.cross_entropy(
|
||||
logits*self.scale, labels.repeat(1, p), label_smoothing=self.eps)
|
||||
logits * self.scale,
|
||||
labels.repeat(1, p),
|
||||
weight=class_weight,
|
||||
label_smoothing=self.eps,
|
||||
)
|
||||
else:
|
||||
loss = F.cross_entropy(logits*self.scale, labels.repeat(1, p))
|
||||
loss = F.cross_entropy(
|
||||
logits * self.scale,
|
||||
labels.repeat(1, p),
|
||||
weight=class_weight,
|
||||
)
|
||||
self.info.update({'loss': loss.detach().clone()})
|
||||
if self.log_accuracy:
|
||||
pred = logits.argmax(dim=1) # [n, p]
|
||||
|
||||
Reference in New Issue
Block a user