Add resumable ScoNet skeleton training diagnostics
This commit is contained in:
@@ -8,7 +8,8 @@ import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Literal
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from tqdm import tqdm
|
||||
import matplotlib.cm as cm
|
||||
import torch.distributed as dist
|
||||
@@ -516,7 +517,7 @@ class GatherTransform(object):
|
||||
"""
|
||||
Gather the different transforms.
|
||||
"""
|
||||
def __init__(self, base_transform, transform_bone, transform_joint):
|
||||
def __init__(self, base_transform, transform_bone, transform_joint, align_transform=None):
|
||||
|
||||
"""
|
||||
base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
|
||||
@@ -526,12 +527,15 @@ class GatherTransform(object):
|
||||
self.base_transform = base_transform
|
||||
self.transform_bone = transform_bone
|
||||
self.transform_joint = transform_joint
|
||||
self.align_transform = align_transform
|
||||
|
||||
def __call__(self, pose_data):
|
||||
x = self.base_transform(pose_data)
|
||||
heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
|
||||
heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
|
||||
heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
|
||||
if self.align_transform is not None:
|
||||
heatmap = self.align_transform(heatmap)
|
||||
return heatmap
|
||||
|
||||
class HeatmapAlignment():
|
||||
@@ -543,23 +547,32 @@ class HeatmapAlignment():
|
||||
|
||||
def center_crop(self, heatmap):
|
||||
"""
|
||||
Input: [1, heatmap_image_size, heatmap_image_size]
|
||||
Output: [1, final_img_size, final_img_size]
|
||||
Input: [C, heatmap_image_size, heatmap_image_size]
|
||||
Output: [C, final_img_size, final_img_size]
|
||||
"""
|
||||
raw_heatmap = heatmap[0]
|
||||
if self.align:
|
||||
y_sum = raw_heatmap.sum(axis=1)
|
||||
y_top = (y_sum != 0).argmax(axis=0)
|
||||
y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
|
||||
height = y_btm - y_top + 1
|
||||
raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
|
||||
raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
||||
return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
|
||||
raw_heatmap = heatmap
|
||||
if self.align:
|
||||
support_map = raw_heatmap.max(axis=0)
|
||||
y_sum = support_map.sum(axis=1)
|
||||
nonzero_rows = np.flatnonzero(y_sum != 0)
|
||||
if nonzero_rows.size != 0:
|
||||
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
|
||||
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
|
||||
height = y_btm - y_top + 1
|
||||
x_center = self.heatmap_image_size // 2
|
||||
x_left = max(x_center - (height // 2), 0)
|
||||
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
|
||||
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
|
||||
resized = np.stack([
|
||||
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
||||
for channel in raw_heatmap
|
||||
], axis=0)
|
||||
return resized # [C, final_img_size, final_img_size]
|
||||
|
||||
def __call__(self, heatmap_imgs):
|
||||
"""
|
||||
heatmap_imgs: (T, 1, raw_size, raw_size)
|
||||
return (T, 1, final_img_size, final_img_size)
|
||||
heatmap_imgs: (T, C, raw_size, raw_size)
|
||||
return (T, C, final_img_size, final_img_size)
|
||||
"""
|
||||
original_dtype = heatmap_imgs.dtype
|
||||
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
||||
@@ -570,12 +583,14 @@ class HeatmapAlignment():
|
||||
return heatmap_imgs.astype(original_dtype)
|
||||
|
||||
def GenerateHeatmapTransform(
|
||||
coco18tococo17_args,
|
||||
padkeypoints_args,
|
||||
norm_args,
|
||||
heatmap_generator_args,
|
||||
align_args,
|
||||
coco18tococo17_args: dict[str, Any],
|
||||
padkeypoints_args: dict[str, Any],
|
||||
norm_args: dict[str, Any],
|
||||
heatmap_generator_args: dict[str, Any],
|
||||
align_args: dict[str, Any],
|
||||
reduction: Literal["upstream", "max", "sum"] = "upstream",
|
||||
sigma_limb: float | None = None,
|
||||
sigma_joint: float | None = None,
|
||||
):
|
||||
|
||||
base_transform = T.Compose([
|
||||
@@ -584,34 +599,44 @@ def GenerateHeatmapTransform(
|
||||
CenterAndScaleNormalizer(**norm_args),
|
||||
])
|
||||
|
||||
heatmap_generator_args["with_limb"] = True
|
||||
heatmap_generator_args["with_kp"] = False
|
||||
bone_generator_args = deepcopy(heatmap_generator_args)
|
||||
joint_generator_args = deepcopy(heatmap_generator_args)
|
||||
|
||||
bone_generator_args["with_limb"] = True
|
||||
bone_generator_args["with_kp"] = False
|
||||
if sigma_limb is not None:
|
||||
bone_generator_args["sigma"] = sigma_limb
|
||||
bone_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_bone = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
GeneratePoseTarget(**bone_generator_args),
|
||||
bone_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
heatmap_generator_args["with_limb"] = False
|
||||
heatmap_generator_args["with_kp"] = True
|
||||
joint_generator_args["with_limb"] = False
|
||||
joint_generator_args["with_kp"] = True
|
||||
if sigma_joint is not None:
|
||||
joint_generator_args["sigma"] = sigma_joint
|
||||
joint_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_joint = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
GeneratePoseTarget(**joint_generator_args),
|
||||
joint_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
transform = T.Compose([
|
||||
GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
|
||||
GatherTransform(
|
||||
base_transform,
|
||||
transform_bone,
|
||||
transform_joint,
|
||||
HeatmapAlignment(**align_args),
|
||||
) # [T, 2, H, W]
|
||||
])
|
||||
|
||||
return transform
|
||||
|
||||
Reference in New Issue
Block a user