Add weighted CE proxy and fix loss imports

This commit is contained in:
2026-03-10 00:40:41 +08:00
parent 24381551f4
commit 5a02036318
5 changed files with 184 additions and 14 deletions
+10 -3
View File
@@ -1,14 +1,21 @@
from __future__ import annotations
import torch
from opengait.evaluation.metric import mean_iou
from opengait.utils.common import Odict
from .base import BaseLoss
from evaluation import mean_iou
class BinaryCrossEntropyLoss(BaseLoss):
def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
eps: float
def __init__(self, loss_term_weight: float = 1.0, eps: float = 1.0e-9) -> None:
super(BinaryCrossEntropyLoss, self).__init__(loss_term_weight)
self.eps = eps
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
logits: [n, 1, h, w]
labels: [n, 1, h, w]