Add weighted CE proxy and fix loss imports
This commit is contained in:
@@ -1,14 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from opengait.evaluation.metric import mean_iou
|
||||
from opengait.utils.common import Odict
|
||||
|
||||
from .base import BaseLoss
|
||||
from evaluation import mean_iou
|
||||
|
||||
|
||||
class BinaryCrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
|
||||
eps: float
|
||||
|
||||
def __init__(self, loss_term_weight: float = 1.0, eps: float = 1.0e-9) -> None:
|
||||
super(BinaryCrossEntropyLoss, self).__init__(loss_term_weight)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, logits, labels):
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
logits: [n, 1, h, w]
|
||||
labels: [n, 1, h, w]
|
||||
|
||||
Reference in New Issue
Block a user