Add weighted CE proxy and fix loss imports
This commit is contained in:
@@ -1,17 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from ctypes import ArgumentError
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from opengait.utils import Odict
|
||||
import functools
|
||||
from opengait.utils import ddp_all_gather
|
||||
|
||||
|
||||
def gather_and_scale_wrapper(func):
|
||||
def gather_and_scale_wrapper(
|
||||
func: Callable[..., tuple[torch.Tensor | float, Odict]],
|
||||
) -> Callable[..., tuple[torch.Tensor | float, Odict]]:
|
||||
"""Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
def inner(*args: Any, **kwds: Any) -> tuple[torch.Tensor | float, Odict]:
|
||||
try:
|
||||
|
||||
for k, v in kwds.items():
|
||||
@@ -32,7 +40,7 @@ class BaseLoss(nn.Module):
|
||||
Your loss should also subclass this class.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_term_weight=1.0):
|
||||
def __init__(self, loss_term_weight: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize the base class.
|
||||
|
||||
@@ -43,7 +51,7 @@ class BaseLoss(nn.Module):
|
||||
self.loss_term_weight = loss_term_weight
|
||||
self.info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
The default forward function.
|
||||
|
||||
@@ -56,4 +64,4 @@ class BaseLoss(nn.Module):
|
||||
Returns:
|
||||
tuple of loss and info.
|
||||
"""
|
||||
return .0, self.info
|
||||
return torch.tensor(0.0, device=logits.device), self.info
|
||||
|
||||
Reference in New Issue
Block a user