Add weighted CE proxy and fix loss imports

This commit is contained in:
2026-03-10 00:40:41 +08:00
parent 24381551f4
commit 5a02036318
5 changed files with 184 additions and 14 deletions
+14 -6
View File
@@ -1,17 +1,25 @@
from __future__ import annotations
from collections.abc import Callable
from ctypes import ArgumentError
import torch.nn as nn
from typing import Any
import torch
import torch.nn as nn
from opengait.utils import Odict
import functools
from opengait.utils import ddp_all_gather
def gather_and_scale_wrapper(func):
def gather_and_scale_wrapper(
func: Callable[..., tuple[torch.Tensor | float, Odict]],
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
"""Internal wrapper: gather the input from multiple cards to one card, and scale the loss by the number of cards.
"""
@functools.wraps(func)
def inner(*args, **kwds):
def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
try:
for k, v in kwds.items():
@@ -32,7 +40,7 @@ class BaseLoss(nn.Module):
Your loss should also subclass this class.
"""
def __init__(self, loss_term_weight=1.0):
def __init__(self, loss_term_weight: float = 1.0) -> None:
"""
Initialize the base class.
@@ -43,7 +51,7 @@ class BaseLoss(nn.Module):
self.loss_term_weight = loss_term_weight
self.info = Odict()
def forward(self, logits, labels):
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
The default forward function.
@@ -56,4 +64,4 @@ class BaseLoss(nn.Module):
Returns:
tuple of loss and info.
"""
return .0, self.info
return torch.tensor(0.0, device=logits.device), self.info