diff --git a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu.yaml b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu.yaml index d4d118a..0a6c14f 100644 --- a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu.yaml +++ b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu.yaml @@ -95,7 +95,7 @@ trainer_cfg: resume_every_iter: 500 resume_keep: 3 eval_iter: 500 - save_iter: 2000 + save_iter: 500 save_name: ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu sync_BN: true total_iter: 2000 diff --git a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_nocut_adamw_proxy_1gpu.yaml b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_nocut_adamw_proxy_1gpu.yaml index 54b899b..5283f4d 100644 --- a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_nocut_adamw_proxy_1gpu.yaml +++ b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_nocut_adamw_proxy_1gpu.yaml @@ -95,7 +95,7 @@ trainer_cfg: resume_every_iter: 500 resume_keep: 3 eval_iter: 500 - save_iter: 2000 + save_iter: 500 save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu sync_BN: true total_iter: 2000 diff --git a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu.yaml b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu.yaml index 4991553..ed8d0c5 100644 --- a/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu.yaml +++ b/configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu.yaml @@ -99,7 +99,7 @@ trainer_cfg: resume_every_iter: 500 resume_keep: 3 eval_iter: 500 - save_iter: 2000 + save_iter: 500 save_name: ScoNet_skeleton_118_sigma15_joint8_sharedalign_weightedce_proxy_1gpu sync_BN: true total_iter: 2000 diff --git a/datasets/pretreatment_heatmap.py b/datasets/pretreatment_heatmap.py index 4a51d2c..bb26ef6 100644 --- a/datasets/pretreatment_heatmap.py +++ b/datasets/pretreatment_heatmap.py @@ -9,6 +9,7 @@ import argparse import numpy as np from glob import glob from copy import deepcopy +from collections.abc import Sequence from typing import Any, Literal from tqdm import tqdm import matplotlib.cm as cm @@ -72,7 +73,8 @@ class GeneratePoseTarget: scaling=1., eps= 1e-3, img_h=64, - img_w = 64): + img_w = 64, + joint_indices: Sequence[int] | None = None): self.sigma = sigma self.use_score = use_score @@ -90,6 +92,7 @@ class GeneratePoseTarget: self.scaling = scaling self.img_h = img_h self.img_w = img_w + self.joint_indices = tuple(joint_indices) if joint_indices is not None else None def generate_a_heatmap(self, arr, centers, max_values, point_center): """Generate pseudo heatmap for one keypoint in one frame. @@ -221,9 +224,18 @@ class GeneratePoseTarget: point_center = kps.mean(1) if self.with_kp: - num_kp = kps.shape[1] - for i in range(num_kp): - self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center) + joint_indices = ( + tuple(range(kps.shape[1])) + if self.joint_indices is None + else self.joint_indices + ) + for output_index, joint_index in enumerate(joint_indices): + self.generate_a_heatmap( + arr[output_index], + kps[:, joint_index], + max_values[:, joint_index], + point_center, + ) if self.with_limb: for i, limb in enumerate(self.skeletons): @@ -261,7 +273,7 @@ class GeneratePoseTarget: num_frame = kp_shape[1] num_c = 0 if self.with_kp: - num_c += all_kps.shape[2] + num_c += all_kps.shape[2] if self.joint_indices is None else len(self.joint_indices) if self.with_limb: num_c += len(self.skeletons) ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32) diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index a34b2c3..28d7ac9 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -353,6 +353,18 @@ class BaseModel(MetaModel, nn.Module): weights_only=False, ) model_state_dict = checkpoint['model'] + current_state_keys = set(self.state_dict().keys()) + stale_loss_keys = [ + key for key in model_state_dict.keys() + if key.startswith("loss_aggregator.") and key not in current_state_keys + ] + for key in stale_loss_keys: + model_state_dict.pop(key) + if stale_loss_keys: + self.msg_mgr.log_warning( + "Ignoring stale loss state from %s: %s" + % (save_name, stale_loss_keys) + ) if not load_ckpt_strict: self.msg_mgr.log_info("-------- Restored Params List --------") diff --git a/opengait/modeling/losses/ce.py b/opengait/modeling/losses/ce.py index fce602b..04b8548 100644 --- a/opengait/modeling/losses/ce.py +++ b/opengait/modeling/losses/ce.py @@ -15,7 +15,6 @@ class CrossEntropyLoss(BaseLoss): label_smooth: bool eps: float log_accuracy: bool - class_weight: torch.Tensor | None = None def __init__( self, @@ -36,10 +35,12 @@ class CrossEntropyLoss(BaseLoss): if class_weight is None else torch.as_tensor(class_weight, dtype=torch.float32) ) - if class_weight is None: - self.register_buffer("class_weight", weight_tensor) - else: - self.register_buffer("class_weight", weight_tensor) + self.register_buffer("_class_weight", weight_tensor) + + @property + def class_weight(self) -> torch.Tensor | None: + buffer = self._buffers.get("_class_weight") + return buffer if isinstance(buffer, torch.Tensor) else None def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, Odict]: """ @@ -49,7 +50,7 @@ class CrossEntropyLoss(BaseLoss): _n, _c, p = logits.size() logits = logits.float() labels = labels.unsqueeze(1) - class_weight = self.class_weight if isinstance(self.class_weight, torch.Tensor) else None + class_weight = self.class_weight if self.label_smooth: loss = F.cross_entropy( logits * self.scale,