Add resumable ScoNet skeleton training diagnostics

This commit is contained in:
2026-03-09 15:57:13 +08:00
parent 4e0b0a18dc
commit 36aef46a0d
15 changed files with 1226 additions and 44 deletions
+54 -29
View File
@@ -8,7 +8,8 @@ import pickle
import argparse
import numpy as np
from glob import glob
from typing import Literal
from copy import deepcopy
from typing import Any, Literal
from tqdm import tqdm
import matplotlib.cm as cm
import torch.distributed as dist
@@ -516,7 +517,7 @@ class GatherTransform(object):
"""
Gather the different transforms.
"""
def __init__(self, base_transform, transform_bone, transform_joint):
def __init__(self, base_transform, transform_bone, transform_joint, align_transform=None):
"""
base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
@@ -526,12 +527,15 @@ class GatherTransform(object):
self.base_transform = base_transform
self.transform_bone = transform_bone
self.transform_joint = transform_joint
self.align_transform = align_transform
def __call__(self, pose_data):
x = self.base_transform(pose_data)
heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
if self.align_transform is not None:
heatmap = self.align_transform(heatmap)
return heatmap
class HeatmapAlignment():
@@ -543,23 +547,32 @@ class HeatmapAlignment():
def center_crop(self, heatmap):
"""
Input: [1, heatmap_image_size, heatmap_image_size]
Output: [1, final_img_size, final_img_size]
Input: [C, heatmap_image_size, heatmap_image_size]
Output: [C, final_img_size, final_img_size]
"""
raw_heatmap = heatmap[0]
if self.align:
y_sum = raw_heatmap.sum(axis=1)
y_top = (y_sum != 0).argmax(axis=0)
y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
height = y_btm - y_top + 1
raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
raw_heatmap = heatmap
if self.align:
support_map = raw_heatmap.max(axis=0)
y_sum = support_map.sum(axis=1)
nonzero_rows = np.flatnonzero(y_sum != 0)
if nonzero_rows.size != 0:
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
height = y_btm - y_top + 1
x_center = self.heatmap_image_size // 2
x_left = max(x_center - (height // 2), 0)
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
resized = np.stack([
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
for channel in raw_heatmap
], axis=0)
return resized # [C, final_img_size, final_img_size]
def __call__(self, heatmap_imgs):
"""
heatmap_imgs: (T, 1, raw_size, raw_size)
return (T, 1, final_img_size, final_img_size)
heatmap_imgs: (T, C, raw_size, raw_size)
return (T, C, final_img_size, final_img_size)
"""
original_dtype = heatmap_imgs.dtype
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
@@ -570,12 +583,14 @@ class HeatmapAlignment():
return heatmap_imgs.astype(original_dtype)
def GenerateHeatmapTransform(
coco18tococo17_args,
padkeypoints_args,
norm_args,
heatmap_generator_args,
align_args,
coco18tococo17_args: dict[str, Any],
padkeypoints_args: dict[str, Any],
norm_args: dict[str, Any],
heatmap_generator_args: dict[str, Any],
align_args: dict[str, Any],
reduction: Literal["upstream", "max", "sum"] = "upstream",
sigma_limb: float | None = None,
sigma_joint: float | None = None,
):
base_transform = T.Compose([
@@ -584,34 +599,44 @@ def GenerateHeatmapTransform(
CenterAndScaleNormalizer(**norm_args),
])
heatmap_generator_args["with_limb"] = True
heatmap_generator_args["with_kp"] = False
bone_generator_args = deepcopy(heatmap_generator_args)
joint_generator_args = deepcopy(heatmap_generator_args)
bone_generator_args["with_limb"] = True
bone_generator_args["with_kp"] = False
if sigma_limb is not None:
bone_generator_args["sigma"] = sigma_limb
bone_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
GeneratePoseTarget(**bone_generator_args),
bone_image_transform,
HeatmapAlignment(**align_args)
])
heatmap_generator_args["with_limb"] = False
heatmap_generator_args["with_kp"] = True
joint_generator_args["with_limb"] = False
joint_generator_args["with_kp"] = True
if sigma_joint is not None:
joint_generator_args["sigma"] = sigma_joint
joint_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
GeneratePoseTarget(**joint_generator_args),
joint_image_transform,
HeatmapAlignment(**align_args)
])
transform = T.Compose([
GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
GatherTransform(
base_transform,
transform_bone,
transform_joint,
HeatmapAlignment(**align_args),
) # [T, 2, H, W]
])
return transform