from __future__ import annotations
import torch
from opengait.evaluation.metric import mean_iou
from opengait.utils.common import Odict
from .base import BaseLoss
class BinaryCrossEntropyLoss(BaseLoss):
    """Pixel-wise binary cross-entropy loss that also logs mean IoU.

    NOTE(review): despite the parameter name ``logits``, ``forward``
    takes the log of its first argument directly and thresholds it at
    0.5 for mIoU — it appears to expect probabilities in [0, 1]
    (an activation applied upstream). Confirm against the caller.
    """

    # Additive constant that keeps log() finite at exact 0 / 1 inputs.
    eps: float

    def __init__(self, loss_term_weight: float = 1.0, eps: float = 1.0e-9) -> None:
        """Store ``eps`` and forward the term weight to ``BaseLoss``."""
        super().__init__(loss_term_weight)
        self.eps = eps

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
        """
        logits: [n, 1, h, w]
        labels: [n, 1, h, w]
        """
        labels = labels.float()
        logits = logits.float()

        # Elementwise BCE: -(y*log(p) + (1-y)*log(1-p)), eps-guarded.
        pos_term = labels * torch.log(logits + self.eps)
        neg_term = (1 - labels) * torch.log(1. - logits + self.eps)
        loss = -(pos_term + neg_term)

        batch = loss.size(0)
        per_sample = loss.view(batch, -1)
        mean_loss = per_sample.mean()       # average over every pixel
        hard_loss = per_sample.max()        # single worst pixel
        miou = mean_iou((logits > 0.5).float(), labels)

        self.info.update({
            'loss': mean_loss.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'miou': miou.detach().clone()})

        return mean_loss, self.info
if __name__ == "__main__":
    # Smoke test. The loss takes log() of its first argument directly,
    # so the fake prediction must lie in (0, 1): squash random logits
    # through a sigmoid. (Raw randn values would be negative half the
    # time and log(negative) yields NaN, making the printed loss nan.)
    loss_func = BinaryCrossEntropyLoss()
    ipts = torch.sigmoid(torch.randn(1, 1, 128, 64))
    tags = (torch.randn(1, 1, 128, 64) > 0.).float()
    loss = loss_func(ipts, tags)
    print(loss)