feat: add scoliosis skeleton experiment tooling

This commit is contained in:
2026-03-10 15:03:53 +08:00
parent 2647398307
commit 44e62ae3ae
6 changed files with 39 additions and 14 deletions
+12
View File
@@ -353,6 +353,18 @@ class BaseModel(MetaModel, nn.Module):
weights_only=False,
)
model_state_dict = checkpoint['model']
current_state_keys = set(self.state_dict().keys())
stale_loss_keys = [
key for key in model_state_dict.keys()
if key.startswith("loss_aggregator.") and key not in current_state_keys
]
for key in stale_loss_keys:
model_state_dict.pop(key)
if stale_loss_keys:
self.msg_mgr.log_warning(
"Ignoring stale loss state from %s: %s"
% (save_name, stale_loss_keys)
)
if not load_ckpt_strict:
self.msg_mgr.log_info("-------- Restored Params List --------")
+7 -6
View File
@@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss):
label_smooth: bool
eps: float
log_accuracy: bool
class_weight: torch.Tensor | None = None
def __init__(
self,
@@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss):
if class_weight is None
else torch.as_tensor(class_weight, dtype=torch.float32)
)
if class_weight is None:
self.register_buffer("class_weight", weight_tensor)
else:
self.register_buffer("class_weight", weight_tensor)
self.register_buffer("_class_weight", weight_tensor)
@property
def class_weight(self) -> torch.Tensor | None:
buffer = self._buffers.get("_class_weight")
return buffer if isinstance(buffer, torch.Tensor) else None
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
@@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss):
_n, _c, p = logits.size()
logits = logits.float()
labels = labels.unsqueeze(1)
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
class_weight = self.class_weight
if self.label_smooth:
loss = F.cross_entropy(
logits * self.scale,