Add resumable ScoNet skeleton training diagnostics
This commit is contained in:
@@ -98,6 +98,15 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], replaced)
|
||||
|
||||
|
||||
def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None:
|
||||
value = cfg.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, (int, float)):
|
||||
raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}")
|
||||
return float(value)
|
||||
|
||||
|
||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
||||
return T.Compose([
|
||||
heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
|
||||
@@ -192,6 +201,8 @@ def main() -> None:
|
||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||
align_args=heatmap_cfg["align_args"],
|
||||
reduction=cast(HeatmapReduction, args.heatmap_reduction),
|
||||
sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"),
|
||||
sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"),
|
||||
)
|
||||
|
||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||
|
||||
Reference in New Issue
Block a user