feat: add scoliosis skeleton experiment tooling

This commit is contained in:
2026-03-10 15:03:53 +08:00
parent 2647398307
commit 44e62ae3ae
6 changed files with 39 additions and 14 deletions
@@ -95,7 +95,7 @@ trainer_cfg:
resume_every_iter: 500
resume_keep: 3
eval_iter: 500
save_iter: 2000
save_iter: 500
save_name: ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu
sync_BN: true
total_iter: 2000
@@ -95,7 +95,7 @@ trainer_cfg:
resume_every_iter: 500
resume_keep: 3
eval_iter: 500
save_iter: 2000
save_iter: 500
save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu
sync_BN: true
total_iter: 2000
@@ -99,7 +99,7 @@ trainer_cfg:
resume_every_iter: 500
resume_keep: 3
eval_iter: 500
save_iter: 2000
save_iter: 500
save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu
sync_BN: true
total_iter: 2000
+17 -5
View File
@@ -9,6 +9,7 @@ import argparse
import numpy as np
from glob import glob
from copy import deepcopy
from collections.abc import Sequence
from typing import Any, Literal
from tqdm import tqdm
import matplotlib.cm as cm
@@ -72,7 +73,8 @@ class GeneratePoseTarget:
scaling=1.,
eps= 1e-3,
img_h=64,
img_w = 64):
img_w = 64,
joint_indices: Sequence[int] | None = None):
self.sigma = sigma
self.use_score = use_score
@@ -90,6 +92,7 @@ class GeneratePoseTarget:
self.scaling = scaling
self.img_h = img_h
self.img_w = img_w
self.joint_indices = tuple(joint_indices) if joint_indices is not None else None
def generate_a_heatmap(self, arr, centers, max_values, point_center):
"""Generate pseudo heatmap for one keypoint in one frame.
@@ -221,9 +224,18 @@ class GeneratePoseTarget:
point_center = kps.mean(1)
if self.with_kp:
num_kp = kps.shape[1]
for i in range(num_kp):
self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)
joint_indices = (
tuple(range(kps.shape[1]))
if self.joint_indices is None
else self.joint_indices
)
for output_index, joint_index in enumerate(joint_indices):
self.generate_a_heatmap(
arr[output_index],
kps[:, joint_index],
max_values[:, joint_index],
point_center,
)
if self.with_limb:
for i, limb in enumerate(self.skeletons):
@@ -261,7 +273,7 @@ class GeneratePoseTarget:
num_frame = kp_shape[1]
num_c = 0
if self.with_kp:
num_c += all_kps.shape[2]
num_c += all_kps.shape[2] if self.joint_indices is None else len(self.joint_indices)
if self.with_limb:
num_c += len(self.skeletons)
ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)
+12
View File
@@ -353,6 +353,18 @@ class BaseModel(MetaModel, nn.Module):
weights_only=False,
)
model_state_dict = checkpoint['model']
current_state_keys = set(self.state_dict().keys())
stale_loss_keys = [
key for key in model_state_dict.keys()
if key.startswith("loss_aggregator.") and key not in current_state_keys
]
for key in stale_loss_keys:
model_state_dict.pop(key)
if stale_loss_keys:
self.msg_mgr.log_warning(
"Ignoring stale loss state from %s: %s"
% (save_name, stale_loss_keys)
)
if not load_ckpt_strict:
self.msg_mgr.log_info("-------- Restored Params List --------")
+7 -6
View File
@@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss):
label_smooth: bool
eps: float
log_accuracy: bool
class_weight: torch.Tensor | None = None
def __init__(
self,
@@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss):
if class_weight is None
else torch.as_tensor(class_weight, dtype=torch.float32)
)
if class_weight is None:
self.register_buffer("class_weight", weight_tensor)
else:
self.register_buffer("class_weight", weight_tensor)
self.register_buffer("_class_weight", weight_tensor)
@property
def class_weight(self) -> torch.Tensor | None:
buffer = self._buffers.get("_class_weight")
return buffer if isinstance(buffer, torch.Tensor) else None
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
"""
@@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss):
_n, _c, p = logits.size()
logits = logits.float()
labels = labels.unsqueeze(1)
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
class_weight = self.class_weight
if self.label_smooth:
loss = F.cross_entropy(
logits * self.scale,