feat: add scoliosis skeleton experiment tooling
This commit is contained in:
@@ -353,6 +353,18 @@ class BaseModel(MetaModel, nn.Module):
|
||||
weights_only=False,
|
||||
)
|
||||
model_state_dict = checkpoint['model']
|
||||
current_state_keys = set(self.state_dict().keys())
|
||||
stale_loss_keys = [
|
||||
key for key in model_state_dict.keys()
|
||||
if key.startswith("loss_aggregator.") and key not in current_state_keys
|
||||
]
|
||||
for key in stale_loss_keys:
|
||||
model_state_dict.pop(key)
|
||||
if stale_loss_keys:
|
||||
self.msg_mgr.log_warning(
|
||||
"Ignoring stale loss state from %s: %s"
|
||||
% (save_name, stale_loss_keys)
|
||||
)
|
||||
|
||||
if not load_ckpt_strict:
|
||||
self.msg_mgr.log_info("-------- Restored Params List --------")
|
||||
|
||||
@@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss):
|
||||
label_smooth: bool
|
||||
eps: float
|
||||
log_accuracy: bool
|
||||
class_weight: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss):
|
||||
if class_weight is None
|
||||
else torch.as_tensor(class_weight, dtype=torch.float32)
|
||||
)
|
||||
if class_weight is None:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
else:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
self.register_buffer("_class_weight", weight_tensor)
|
||||
|
||||
@property
|
||||
def class_weight(self) -> torch.Tensor | None:
|
||||
buffer = self._buffers.get("_class_weight")
|
||||
return buffer if isinstance(buffer, torch.Tensor) else None
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
@@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss):
|
||||
_n, _c, p = logits.size()
|
||||
logits = logits.float()
|
||||
labels = labels.unsqueeze(1)
|
||||
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
|
||||
class_weight = self.class_weight
|
||||
if self.label_smooth:
|
||||
loss = F.cross_entropy(
|
||||
logits * self.scale,
|
||||
|
||||
Reference in New Issue
Block a user