feat: add scoliosis skeleton experiment tooling
This commit is contained in:
+1
-1
@@ -95,7 +95,7 @@ trainer_cfg:
|
||||
resume_every_iter: 500
|
||||
resume_keep: 3
|
||||
eval_iter: 500
|
||||
save_iter: 2000
|
||||
save_iter: 500
|
||||
save_name: ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu
|
||||
sync_BN: true
|
||||
total_iter: 2000
|
||||
|
||||
+1
-1
@@ -95,7 +95,7 @@ trainer_cfg:
|
||||
resume_every_iter: 500
|
||||
resume_keep: 3
|
||||
eval_iter: 500
|
||||
save_iter: 2000
|
||||
save_iter: 500
|
||||
save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu
|
||||
sync_BN: true
|
||||
total_iter: 2000
|
||||
|
||||
+1
-1
@@ -99,7 +99,7 @@ trainer_cfg:
|
||||
resume_every_iter: 500
|
||||
resume_keep: 3
|
||||
eval_iter: 500
|
||||
save_iter: 2000
|
||||
save_iter: 500
|
||||
save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu
|
||||
sync_BN: true
|
||||
total_iter: 2000
|
||||
|
||||
@@ -9,6 +9,7 @@ import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from copy import deepcopy
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from tqdm import tqdm
|
||||
import matplotlib.cm as cm
|
||||
@@ -72,7 +73,8 @@ class GeneratePoseTarget:
|
||||
scaling=1.,
|
||||
eps= 1e-3,
|
||||
img_h=64,
|
||||
img_w = 64):
|
||||
img_w = 64,
|
||||
joint_indices: Sequence[int] | None = None):
|
||||
|
||||
self.sigma = sigma
|
||||
self.use_score = use_score
|
||||
@@ -90,6 +92,7 @@ class GeneratePoseTarget:
|
||||
self.scaling = scaling
|
||||
self.img_h = img_h
|
||||
self.img_w = img_w
|
||||
self.joint_indices = tuple(joint_indices) if joint_indices is not None else None
|
||||
|
||||
def generate_a_heatmap(self, arr, centers, max_values, point_center):
|
||||
"""Generate pseudo heatmap for one keypoint in one frame.
|
||||
@@ -221,9 +224,18 @@ class GeneratePoseTarget:
|
||||
point_center = kps.mean(1)
|
||||
|
||||
if self.with_kp:
|
||||
num_kp = kps.shape[1]
|
||||
for i in range(num_kp):
|
||||
self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)
|
||||
joint_indices = (
|
||||
tuple(range(kps.shape[1]))
|
||||
if self.joint_indices is None
|
||||
else self.joint_indices
|
||||
)
|
||||
for output_index, joint_index in enumerate(joint_indices):
|
||||
self.generate_a_heatmap(
|
||||
arr[output_index],
|
||||
kps[:, joint_index],
|
||||
max_values[:, joint_index],
|
||||
point_center,
|
||||
)
|
||||
|
||||
if self.with_limb:
|
||||
for i, limb in enumerate(self.skeletons):
|
||||
@@ -261,7 +273,7 @@ class GeneratePoseTarget:
|
||||
num_frame = kp_shape[1]
|
||||
num_c = 0
|
||||
if self.with_kp:
|
||||
num_c += all_kps.shape[2]
|
||||
num_c += all_kps.shape[2] if self.joint_indices is None else len(self.joint_indices)
|
||||
if self.with_limb:
|
||||
num_c += len(self.skeletons)
|
||||
ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)
|
||||
|
||||
@@ -353,6 +353,18 @@ class BaseModel(MetaModel, nn.Module):
|
||||
weights_only=False,
|
||||
)
|
||||
model_state_dict = checkpoint['model']
|
||||
current_state_keys = set(self.state_dict().keys())
|
||||
stale_loss_keys = [
|
||||
key for key in model_state_dict.keys()
|
||||
if key.startswith("loss_aggregator.") and key not in current_state_keys
|
||||
]
|
||||
for key in stale_loss_keys:
|
||||
model_state_dict.pop(key)
|
||||
if stale_loss_keys:
|
||||
self.msg_mgr.log_warning(
|
||||
"Ignoring stale loss state from %s: %s"
|
||||
% (save_name, stale_loss_keys)
|
||||
)
|
||||
|
||||
if not load_ckpt_strict:
|
||||
self.msg_mgr.log_info("-------- Restored Params List --------")
|
||||
|
||||
@@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss):
|
||||
label_smooth: bool
|
||||
eps: float
|
||||
log_accuracy: bool
|
||||
class_weight: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss):
|
||||
if class_weight is None
|
||||
else torch.as_tensor(class_weight, dtype=torch.float32)
|
||||
)
|
||||
if class_weight is None:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
else:
|
||||
self.register_buffer("class_weight", weight_tensor)
|
||||
self.register_buffer("_class_weight", weight_tensor)
|
||||
|
||||
@property
|
||||
def class_weight(self) -> torch.Tensor | None:
|
||||
buffer = self._buffers.get("_class_weight")
|
||||
return buffer if isinstance(buffer, torch.Tensor) else None
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]:
|
||||
"""
|
||||
@@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss):
|
||||
_n, _c, p = logits.size()
|
||||
logits = logits.float()
|
||||
labels = labels.unsqueeze(1)
|
||||
class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None
|
||||
class_weight = self.class_weight
|
||||
if self.label_smooth:
|
||||
loss = F.cross_entropy(
|
||||
logits * self.scale,
|
||||
|
||||
Reference in New Issue
Block a user