feat: add scoliosis skeleton experiment tooling

This commit is contained in:
2026-03-10 15:03:53 +08:00
parent 2647398307
commit 44e62ae3ae
6 changed files with 39 additions and 14 deletions
+7 -6
View File
@@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss):
label_smooth: bool
eps: float
log_accuracy: bool
class_weight: torch.Tensor | None = None
def __init__(
self,
@@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss):
if class_weight is None
else torch.as_tensor(class_weight, dtype=torch.float32)
)
if class_weight is None:
self.register_buffer("class_weight", weight_tensor)
else:
self.register_buffer("class_weight", weight_tensor)
self.register_buffer("_class_weight", weight_tensor)
@property
def class_weight(self) -> torch.Tensor | None:
buffer = self._buffers.get("_class_weight")
return buffer if isinstance(buffer, torch.Tensor) else None
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
@@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss):
_n, _c, p = logits.size()
logits = logits.float()
labels = labels.unsqueeze(1)
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
class_weight = self.class_weight
if self.label_smooth:
loss = F.cross_entropy(
logits * self.scale,