909 lines
36 KiB
Python
909 lines
36 KiB
Python
import os
|
|
import cv2
|
|
import yaml
|
|
import math
|
|
import torch
|
|
import random
|
|
import pickle
|
|
import argparse
|
|
import numpy as np
|
|
from glob import glob
|
|
from copy import deepcopy
|
|
from collections.abc import Sequence
|
|
from typing import Any, Literal
|
|
from tqdm import tqdm
|
|
import matplotlib.cm as cm
|
|
import torch.distributed as dist
|
|
from torchvision import transforms as T
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from sklearn.impute import KNNImputer, SimpleImputer
|
|
|
|
# Seed torch's and Python's global random number generators with one fixed value.
_SEED = 347
torch.manual_seed(_SEED)
random.seed(_SEED)
|
|
|
|
#########################################################################################################
|
|
# The following classes implement the pose-to-heatmap generation pipeline.
|
|
#########################################################################################################
|
|
|
|
class GeneratePoseTarget:
|
|
"""Generate pseudo heatmaps based on joint coordinates and confidence.
|
|
Required keys are "keypoint", "img_shape", "keypoint_score" (optional),
|
|
added or modified keys are "imgs".
|
|
Args:
|
|
sigma (float): The sigma of the generated gaussian map. Default: 0.6.
|
|
use_score (bool): Use the confidence score of keypoints as the maximum
|
|
of the gaussian maps. Default: True.
|
|
with_kp (bool): Generate pseudo heatmaps for keypoints. Default: True.
|
|
with_limb (bool): Generate pseudo heatmaps for limbs. At least one of
|
|
'with_kp' and 'with_limb' should be True. Default: False.
|
|
skeletons (tuple[tuple]): The definition of human skeletons.
|
|
Default: ((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7), (7, 9),
|
|
(0, 6), (6, 8), (8, 10), (5, 11), (11, 13), (13, 15),
|
|
(6, 12), (12, 14), (14, 16), (11, 12)),
|
|
which is the definition of COCO-17p skeletons.
|
|
double (bool): Output both original heatmaps and flipped heatmaps.
|
|
Default: False.
|
|
left_kp (tuple[int]): Indexes of left keypoints, which is used when
|
|
flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
|
|
which is left keypoints in COCO-17p.
|
|
right_kp (tuple[int]): Indexes of right keypoints, which is used when
|
|
flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
|
|
which is right keypoints in COCO-17p.
|
|
left_limb (tuple[int]): Indexes of left limbs, which is used when
|
|
flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
|
|
which is left limbs of skeletons we defined for COCO-17p.
|
|
right_limb (tuple[int]): Indexes of right limbs, which is used when
|
|
flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
|
|
which is right limbs of skeletons we defined for COCO-17p.
|
|
"""
|
|
|
|
def __init__(self,
|
|
sigma=0.6,
|
|
use_score=True,
|
|
with_kp=True,
|
|
with_limb=False,
|
|
skeletons=((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7),
|
|
(7, 9), (0, 6), (6, 8), (8, 10), (5, 11), (11, 13),
|
|
(13, 15), (6, 12), (12, 14), (14, 16), (11, 12)),
|
|
double=False,
|
|
left_kp=(1, 3, 5, 7, 9, 11, 13, 15),
|
|
right_kp=(2, 4, 6, 8, 10, 12, 14, 16),
|
|
left_limb=(0, 2, 4, 5, 6, 10, 11, 12),
|
|
right_limb=(1, 3, 7, 8, 9, 13, 14, 15),
|
|
scaling=1.,
|
|
eps= 1e-3,
|
|
img_h=64,
|
|
img_w = 64,
|
|
joint_indices: Sequence[int] | None = None):
|
|
|
|
self.sigma = sigma
|
|
self.use_score = use_score
|
|
self.with_kp = with_kp
|
|
self.with_limb = with_limb
|
|
self.double = double
|
|
self.eps = eps
|
|
|
|
assert self.with_kp + self.with_limb == 1, ('One of "with_limb" and "with_kp" should be set as True.')
|
|
self.left_kp = left_kp
|
|
self.right_kp = right_kp
|
|
self.skeletons = skeletons
|
|
self.left_limb = left_limb
|
|
self.right_limb = right_limb
|
|
self.scaling = scaling
|
|
self.img_h = img_h
|
|
self.img_w = img_w
|
|
self.joint_indices = tuple(joint_indices) if joint_indices is not None else None
|
|
|
|
def generate_a_heatmap(self, arr, centers, max_values, point_center):
|
|
"""Generate pseudo heatmap for one keypoint in one frame.
|
|
Args:
|
|
arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
|
|
centers (np.ndarray): The coordinates of corresponding keypoints (of multiple persons). Shape: 1 * 2.
|
|
max_values (np.ndarray): The max values of each keypoint. Shape: (1, ).
|
|
point_center: Shape: (1, 2)
|
|
Returns:
|
|
np.ndarray: The generated pseudo heatmap.
|
|
"""
|
|
|
|
sigma = self.sigma
|
|
img_h, img_w = arr.shape
|
|
|
|
for center, max_value in zip(centers, max_values):
|
|
if max_value < self.eps:
|
|
continue
|
|
|
|
mu_x, mu_y = center[0], center[1]
|
|
|
|
tmp_st_x = int(mu_x - 3 * sigma)
|
|
tmp_ed_x = int(mu_x + 3 * sigma)
|
|
tmp_st_y = int(mu_y - 3 * sigma)
|
|
tmp_ed_y = int(mu_y + 3 * sigma)
|
|
|
|
st_x = max(tmp_st_x, 0)
|
|
ed_x = min(tmp_ed_x + 1, img_w)
|
|
st_y = max(tmp_st_y, 0)
|
|
ed_y = min(tmp_ed_y + 1, img_h)
|
|
x = np.arange(st_x, ed_x, dtype=np.float32)
|
|
y = np.arange(st_y, ed_y, dtype=np.float32)
|
|
|
|
# if the keypoint not in the heatmap coordinate system
|
|
if not (len(x) and len(y)):
|
|
continue
|
|
y = y[:, None]
|
|
|
|
patch = np.exp(-((x - mu_x)**2 + (y - mu_y)**2) / 2 / sigma**2)
|
|
patch = patch * max_value
|
|
|
|
arr[st_y:ed_y, st_x:ed_x] = np.maximum(arr[st_y:ed_y, st_x:ed_x], patch)
|
|
|
|
def generate_a_limb_heatmap(self, arr, starts, ends, start_values, end_values, point_center):
|
|
"""Generate pseudo heatmap for one limb in one frame.
|
|
Args:
|
|
arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
|
|
starts (np.ndarray): The coordinates of one keypoint in the corresponding limbs. Shape: 1 * 2.
|
|
ends (np.ndarray): The coordinates of the other keypoint in the corresponding limbs. Shape: 1 * 2.
|
|
start_values (np.ndarray): The max values of one keypoint in the corresponding limbs. Shape: (1, ).
|
|
end_values (np.ndarray): The max values of the other keypoint in the corresponding limbs. Shape: (1, ).
|
|
Returns:
|
|
np.ndarray: The generated pseudo heatmap.
|
|
"""
|
|
|
|
sigma = self.sigma
|
|
img_h, img_w = arr.shape
|
|
|
|
for start, end, start_value, end_value in zip(starts, ends, start_values, end_values):
|
|
value_coeff = min(start_value, end_value)
|
|
if value_coeff < self.eps:
|
|
continue
|
|
|
|
min_x, max_x = min(start[0], end[0]), max(start[0], end[0])
|
|
min_y, max_y = min(start[1], end[1]), max(start[1], end[1])
|
|
|
|
|
|
|
|
tmp_min_x = int(min_x - 3 * sigma)
|
|
tmp_max_x = int(max_x + 3 * sigma)
|
|
tmp_min_y = int(min_y - 3 * sigma)
|
|
tmp_max_y = int(max_y + 3 * sigma)
|
|
|
|
min_x = max(tmp_min_x, 0)
|
|
max_x = min(tmp_max_x + 1, img_w)
|
|
min_y = max(tmp_min_y, 0)
|
|
max_y = min(tmp_max_y + 1, img_h)
|
|
|
|
x = np.arange(min_x, max_x, dtype=np.float32)
|
|
y = np.arange(min_y, max_y, dtype=np.float32)
|
|
|
|
if not (len(x) and len(y)):
|
|
continue
|
|
|
|
y = y[:, None]
|
|
x_0 = np.zeros_like(x)
|
|
y_0 = np.zeros_like(y)
|
|
|
|
# distance to start keypoints
|
|
d2_start = ((x - start[0])**2 + (y - start[1])**2)
|
|
|
|
# distance to end keypoints
|
|
d2_end = ((x - end[0])**2 + (y - end[1])**2)
|
|
|
|
# the distance between start and end keypoints.
|
|
d2_ab = ((start[0] - end[0])**2 + (start[1] - end[1])**2)
|
|
|
|
if d2_ab < 1:
|
|
self.generate_a_heatmap(arr, start[None], start_value[None], point_center)
|
|
continue
|
|
|
|
coeff = (d2_start - d2_end + d2_ab) / 2. / d2_ab
|
|
|
|
a_dominate = coeff <= 0
|
|
b_dominate = coeff >= 1
|
|
seg_dominate = 1 - a_dominate - b_dominate
|
|
|
|
position = np.stack([x + y_0, y + x_0], axis=-1)
|
|
projection = start + np.stack([coeff, coeff], axis=-1) * (end - start)
|
|
d2_line = position - projection
|
|
d2_line = d2_line[:, :, 0]**2 + d2_line[:, :, 1]**2
|
|
d2_seg = a_dominate * d2_start + b_dominate * d2_end + seg_dominate * d2_line
|
|
|
|
patch = np.exp(-d2_seg / 2. / sigma**2)
|
|
patch = patch * value_coeff
|
|
|
|
arr[min_y:max_y, min_x:max_x] = np.maximum(arr[min_y:max_y, min_x:max_x], patch)
|
|
def generate_heatmap(self, arr, kps, max_values):
|
|
"""Generate pseudo heatmap for all keypoints and limbs in one frame (if
|
|
needed).
|
|
Args:
|
|
arr (np.ndarray): The array to store the generated heatmaps. Shape: V * img_h * img_w.
|
|
kps (np.ndarray): The coordinates of keypoints in this frame. Shape: 1 * V * 2.
|
|
max_values (np.ndarray): The confidence score of each keypoint. Shape: 1 * V.
|
|
Returns:
|
|
np.ndarray: The generated pseudo heatmap.
|
|
"""
|
|
|
|
point_center = kps.mean(1)
|
|
|
|
if self.with_kp:
|
|
joint_indices = (
|
|
tuple(range(kps.shape[1]))
|
|
if self.joint_indices is None
|
|
else self.joint_indices
|
|
)
|
|
for output_index, joint_index in enumerate(joint_indices):
|
|
self.generate_a_heatmap(
|
|
arr[output_index],
|
|
kps[:, joint_index],
|
|
max_values[:, joint_index],
|
|
point_center,
|
|
)
|
|
|
|
if self.with_limb:
|
|
for i, limb in enumerate(self.skeletons):
|
|
start_idx, end_idx = limb
|
|
starts = kps[:, start_idx]
|
|
ends = kps[:, end_idx]
|
|
|
|
start_values = max_values[:, start_idx]
|
|
end_values = max_values[:, end_idx]
|
|
self.generate_a_limb_heatmap(arr[i], starts, ends, start_values, end_values, point_center)
|
|
|
|
def gen_an_aug(self, pose_data):
|
|
"""Generate pseudo heatmaps for all frames.
|
|
Args:
|
|
pose_data (array): [1, T, V, C]
|
|
Returns:
|
|
list[np.ndarray]: The generated pseudo heatmaps.
|
|
"""
|
|
|
|
all_kps = pose_data[..., :2]
|
|
kp_shape = pose_data.shape # [1, T, V, 2]
|
|
|
|
if pose_data.shape[-1] == 3:
|
|
all_kpscores = pose_data[..., -1] # [1, T, V]
|
|
else:
|
|
all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32)
|
|
|
|
|
|
|
|
# scale img_h, img_w and kps
|
|
img_h = int(self.img_h * self.scaling + 0.5)
|
|
img_w = int(self.img_w * self.scaling + 0.5)
|
|
all_kps[..., :2] *= self.scaling
|
|
|
|
num_frame = kp_shape[1]
|
|
num_c = 0
|
|
if self.with_kp:
|
|
num_c += all_kps.shape[2] if self.joint_indices is None else len(self.joint_indices)
|
|
if self.with_limb:
|
|
num_c += len(self.skeletons)
|
|
ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)
|
|
|
|
for i in range(num_frame):
|
|
# 1, V, C
|
|
kps = all_kps[:, i]
|
|
# 1, V
|
|
kpscores = all_kpscores[:, i] if self.use_score else np.ones_like(all_kpscores[:, i])
|
|
|
|
self.generate_heatmap(ret[i], kps, kpscores)
|
|
return ret
|
|
|
|
def __call__(self, pose_data):
|
|
"""
|
|
pose_data: (T, V, C=3/2)
|
|
1: means person number
|
|
"""
|
|
pose_data = pose_data[None,...] # (1, T, V, C=3/2)
|
|
|
|
heatmap = self.gen_an_aug(pose_data)
|
|
|
|
if self.double:
|
|
indices = np.arange(heatmap.shape[1], dtype=np.int64)
|
|
left, right = (self.left_kp, self.right_kp) if self.with_kp else (self.left_limb, self.right_limb)
|
|
for l, r in zip(left, right): # noqa: E741
|
|
indices[l] = r
|
|
indices[r] = l
|
|
heatmap_flip = heatmap[..., ::-1][:, indices]
|
|
heatmap = np.concatenate([heatmap, heatmap_flip])
|
|
return heatmap
|
|
|
|
def __repr__(self):
|
|
repr_str = (f'{self.__class__.__name__}('
|
|
f'sigma={self.sigma}, '
|
|
f'use_score={self.use_score}, '
|
|
f'with_kp={self.with_kp}, '
|
|
f'with_limb={self.with_limb}, '
|
|
f'skeletons={self.skeletons}, '
|
|
f'double={self.double}, '
|
|
f'left_kp={self.left_kp}, '
|
|
f'right_kp={self.right_kp})')
|
|
return repr_str
|
|
|
|
class HeatmapToImage:
    """
    Collapse the per-channel heatmaps of every frame into one grayscale image.
    """

    def __init__(self) -> None:
        # matplotlib's gray colormap turns heatmap values into RGBA colours.
        self.cmap = cm.gray

    def __call__(self, heatmaps):
        """
        heatmaps: (T, C, H, W)
        return images: (T, 1, H, W), the per-pixel mean of the gray RGB rendering
        """
        _, height, width = heatmaps[0].shape
        target_size = (int(width), int(height))  # cv2 expects (width, height)

        frames = []
        for frame in heatmaps:
            # Strongest response over all channels at every pixel.
            peak = np.max(frame, axis=0)
            rgb = (self.cmap(peak)[..., :3] * 255).astype(np.uint8)
            frames.append(cv2.resize(rgb, target_size))

        gray = np.mean(np.array(frames), axis=-1, keepdims=True)  # (T, H, W, 1)
        return np.ascontiguousarray(gray.transpose(0, 3, 1, 2))
|
|
|
|
|
|
class HeatmapReducer:
    """Collapse a stack of joint/limb heatmaps into one grayscale channel."""

    _REDUCTIONS = frozenset({"max", "sum"})

    def __init__(self, reduction: Literal["max", "sum"] = "max") -> None:
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"Unsupported heatmap reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, heatmaps: np.ndarray) -> np.ndarray:
        """
        heatmaps: (T, C, H, W)
        return: (T, 1, H, W); uint8 for "max", float32 for "sum"
        """
        if self.reduction != "max":
            # "sum": channels are added up and scaled to 255 without clipping.
            summed = np.sum(heatmaps, axis=1, keepdims=True)
            return (summed * 255.0).astype(np.float32)

        # "max": strongest channel per pixel, clipped to [0, 1] before scaling.
        peak = np.clip(np.max(heatmaps, axis=1, keepdims=True), 0.0, 1.0)
        return (peak * 255).astype(np.uint8)
|
|
|
|
class CenterAndScaleNormalizer:
    """Center every frame of a pose sequence on the hips and rescale it to a fixed body height."""

    def __init__(
        self,
        pose_format="coco",
        use_conf=True,
        heatmap_image_height=128,
        target_body_height=None,
    ) -> None:
        """
        Parameters:
        - pose_format (str): Specifies the format of the keypoints.
            This parameter determines which joints are treated as the hips.
            The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25,
            indicating the number of keypoints used by the OpenPose model.
        - use_conf (bool): Indicates whether the last channel of the input holds confidence scores.
        - heatmap_image_height (int): Height (in pixels) of the heatmap image the poses are normalized for.
        - target_body_height (float | None): Optional normalized body height. When omitted,
            preserve the historical SkeletonGait scaling heuristic (heatmap_image_height // 1.5).
        """
        self.pose_format = pose_format
        self.use_conf = use_conf
        self.heatmap_image_height = heatmap_image_height
        self.target_body_height = target_body_height

    def __call__(self, data):
        """
        Implements step (a) from Figure 2 in the SkeletonGait paper.
        data: (T, V, C)
            - T: number of frames
            - V: number of joints
            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
        return data: (T, V, C)

        Raises ValueError when ``pose_format`` is neither "coco" nor "openpose-*".
        The input array is not modified.
        """
        if self.use_conf:
            pose_seq = data[..., :-1]
            score = data[..., -1:]
        else:
            # Without a confidence channel every channel is a coordinate;
            # slicing off the last one here would throw away y.
            pose_seq = data
            score = None

        # Hip as the center point
        pose_format = self.pose_format.lower()
        if pose_format == "coco":
            hip_a, hip_b = 11, 12
        elif pose_format.split('-')[0] == "openpose":
            # NOTE(review): 9 and 12 look like the BODY_25 hip joints; confirm
            # they are also intended for 18-keypoint OpenPose input.
            hip_a, hip_b = 9, 12
        else:
            raise ValueError(f"Error value for pose_format: {self.pose_format} in CenterAndScale Class.")
        hip = (pose_seq[:, hip_a] + pose_seq[:, hip_b]) / 2.  # [t, 2]

        # Center-normalization (creates a new array, so the in-place steps below are safe)
        pose_seq = pose_seq - hip[:, np.newaxis, :]

        # Scale-normalization
        y_max = np.max(pose_seq[:, :, 1], axis=-1)  # [t]
        y_min = np.min(pose_seq[:, :, 1], axis=-1)  # [t]
        if self.target_body_height is not None:
            target_body_height = float(self.target_body_height)
        else:
            target_body_height = float(self.heatmap_image_height // 1.5)
        # Guard against a zero-height skeleton (e.g. all joints on one row).
        body_height = np.maximum(y_max - y_min, 1e-6)
        pose_seq *= (target_body_height / body_height)[:, np.newaxis, np.newaxis]  # [t, v, 2]

        # Move the hip center to the middle of the heatmap image.
        pose_seq += self.heatmap_image_height // 2

        if self.use_conf:
            pose_seq = np.concatenate([pose_seq, score], axis=-1)
        return pose_seq
|
|
|
|
class PadKeypoints:
    """
    Pad (impute) keypoint coordinates that are missing, i.e. stored as 0.0.
    """

    def __init__(self, pad_method="knn", use_conf=True) -> None:
        """
        pad_method (str): Specifies the method used to pad the missing values.
                            The supported methods are "knn" and "simple".
        use_conf (bool): Indicates whether the last channel of the input holds confidence scores.

        Raises ValueError for an unsupported ``pad_method``.
        """
        self.use_conf = use_conf
        method = pad_method.lower()
        if method == "knn":
            self.imputer = KNNImputer(missing_values=0.0, n_neighbors=4, weights="distance", add_indicator=False)
        elif method == "simple":
            # add_indicator must stay off: indicator columns are appended to the
            # imputed matrix, which would make it impossible to reshape the
            # result back to (T, V, C) and silently disable the imputation.
            self.imputer = SimpleImputer(missing_values=0.0, strategy='mean', add_indicator=False)
        else:
            raise ValueError(f"Error value for padding method: {pad_method}")

    def __call__(self, raw_data):
        """
        raw_data: (T, V, C)
            - T: number of frames
            - V: number of joints
            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
        return padded_data: (T, V, C)
        """
        num_frames, num_joints, num_channels = raw_data.shape
        if self.use_conf:
            coords = raw_data[..., :-1]
            score = raw_data[..., -1:]
            num_channels -= 1
        else:
            # Without a confidence channel every channel is a coordinate.
            coords = raw_data
            score = None

        # One row per frame so that values are imputed across frames.
        flat = coords.reshape((num_frames, num_joints * num_channels))
        padded_data = self.imputer.fit_transform(flat)
        try:
            padded_data = padded_data.reshape((num_frames, num_joints, num_channels))
        except ValueError:
            # Best effort: if the imputer changed the number of columns
            # (presumably by dropping columns that were missing in every
            # frame), keep the original, un-imputed coordinates.
            padded_data = flat.reshape((num_frames, num_joints, num_channels))

        if self.use_conf:
            padded_data = np.concatenate([padded_data, score], axis=-1)
        return padded_data
|
|
|
|
class COCO18toCOCO17:
    """
    Reorder COCO18 keypoints (as extracted by OpenPose) into the COCO17 layout.
    """

    def __init__(self, transfer_to_coco17=True):
        """
        transfer_to_coco17 (bool): Indicates whether to transfer the keypoints from COCO18 to COCO17 format.
        """
        # Position in this tuple = COCO17 joint index, value = COCO18 source index.
        coco18_source = (
            0,   # nose
            15,  # left_eye
            14,  # right_eye
            17,  # left_ear
            16,  # right_ear
            5,   # left_shoulder
            2,   # right_shoulder
            6,   # left_elbow
            3,   # right_elbow
            7,   # left_wrist
            4,   # right_wrist
            11,  # left_hip
            8,   # right_hip
            12,  # left_knee
            9,   # right_knee
            13,  # left_ankle
            10,  # right_ankle
        )
        self.map_dict = dict(enumerate(coco18_source))
        self.transfer = transfer_to_coco17

    def __call__(self, data):
        """
        data: (T, 18, C)
            - T: number of frames
            - 18: number of joints of COCO18 format
            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
        return data: (T, 17, C); the input itself when the transfer is disabled
        """
        if not self.transfer:
            return data

        num_frames, _, num_channels = data.shape
        source_order = [self.map_dict[joint] for joint in range(17)]
        # np.zeros without a dtype: the converted array is always float64.
        coco17_pkl_data = np.zeros((num_frames, 17, num_channels))
        coco17_pkl_data[:] = data[:, source_order, :]
        return coco17_pkl_data
|
|
|
|
class GatherTransform(object):
    """
    Run the shared preprocessing once and stack the bone and joint heatmaps.
    """

    def __init__(
        self,
        base_transform,
        transform_bone,
        transform_joint,
        align_transform=None,
        limb_gain: float = 1.0,
        joint_gain: float = 1.0,
    ) -> None:
        """
        base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
        transform_bone: GeneratePoseTarget for generate bone heatmap
        transform_joint: GeneratePoseTarget for generate joint heatmap
        align_transform: optional transform applied to the stacked heatmap
        limb_gain: multiplier for channel 0 (bone heatmap)
        joint_gain: multiplier for channel 1 (joint heatmap)
        """
        self.base_transform = base_transform
        self.transform_bone = transform_bone
        self.transform_joint = transform_joint
        self.align_transform = align_transform
        self.limb_gain = limb_gain
        self.joint_gain = joint_gain

    def _apply_channel_gains(self, heatmap: np.ndarray) -> np.ndarray:
        """Scale channel 0/1 by the limb/joint gain, clip to [0, 255], keep the dtype."""
        gains = (self.limb_gain, self.joint_gain)
        if all(gain == 1.0 for gain in gains):
            # Nothing to do: return the very same array, unclipped.
            return heatmap

        source_dtype = heatmap.dtype
        scaled = heatmap.astype(np.float32, copy=True)
        for channel, gain in enumerate(gains):
            scaled[:, channel] *= gain
        return np.clip(scaled, 0.0, 255.0).astype(source_dtype)

    def __call__(self, pose_data):
        shared = self.base_transform(pose_data)
        channels = [
            self.transform_bone(shared),   # [T, 1, H, W]
            self.transform_joint(shared),  # [T, 1, H, W]
        ]
        heatmap = np.concatenate(channels, axis=1)
        if self.align_transform is not None:
            heatmap = self.align_transform(heatmap)
        return self._apply_channel_gains(heatmap)
|
|
|
|
AlignmentScope = Literal["frame", "sequence"]
|
|
AlignmentCropMode = Literal["square_center", "bbox_pad"]
|
|
|
|
|
|
class HeatmapAlignment():
|
|
def __init__(
|
|
self,
|
|
align: bool = True,
|
|
final_img_size: int = 64,
|
|
offset: int = 0,
|
|
heatmap_image_size: int = 128,
|
|
scope: AlignmentScope = "frame",
|
|
crop_mode: AlignmentCropMode = "square_center",
|
|
preserve_aspect_ratio: bool = False,
|
|
) -> None:
|
|
self.align = align
|
|
self.final_img_size = final_img_size
|
|
self.offset = offset
|
|
self.heatmap_image_size = heatmap_image_size
|
|
self.scope = scope
|
|
self.crop_mode = crop_mode
|
|
self.preserve_aspect_ratio = preserve_aspect_ratio
|
|
|
|
def _compute_crop_bounds(
|
|
self,
|
|
heatmap: np.ndarray,
|
|
) -> tuple[int, int, int, int] | None:
|
|
support_map = heatmap.max(axis=0)
|
|
y_sum = support_map.sum(axis=1)
|
|
x_sum = support_map.sum(axis=0)
|
|
nonzero_rows = np.flatnonzero(y_sum != 0)
|
|
nonzero_cols = np.flatnonzero(x_sum != 0)
|
|
if nonzero_rows.size == 0:
|
|
return None
|
|
if nonzero_cols.size == 0:
|
|
return None
|
|
|
|
y_top = max(int(nonzero_rows[0]) - self.offset, 0)
|
|
y_btm = min(int(nonzero_rows[-1]) + self.offset, self.heatmap_image_size - 1)
|
|
|
|
if self.crop_mode == "bbox_pad":
|
|
x_left = max(int(nonzero_cols[0]) - self.offset, 0)
|
|
x_right = min(int(nonzero_cols[-1]) + self.offset + 1, self.heatmap_image_size)
|
|
return y_top, y_btm, x_left, x_right
|
|
|
|
height = y_btm - y_top + 1
|
|
x_center = self.heatmap_image_size // 2
|
|
x_left = max(x_center - (height // 2), 0)
|
|
x_right = min(x_center + (height // 2) + 1, self.heatmap_image_size)
|
|
return y_top, y_btm, x_left, x_right
|
|
|
|
def _resize_and_pad(self, cropped_heatmap: np.ndarray) -> np.ndarray:
|
|
_, src_h, src_w = cropped_heatmap.shape
|
|
if src_h <= 0 or src_w <= 0:
|
|
return np.zeros(
|
|
(cropped_heatmap.shape[0], self.final_img_size, self.final_img_size),
|
|
dtype=np.float32,
|
|
)
|
|
|
|
scale = float(self.final_img_size) / float(max(src_h, src_w))
|
|
resized_h = max(1, int(round(src_h * scale)))
|
|
resized_w = max(1, int(round(src_w * scale)))
|
|
|
|
resized = np.stack([
|
|
cv2.resize(channel, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
|
|
for channel in cropped_heatmap
|
|
], axis=0)
|
|
|
|
canvas = np.zeros(
|
|
(cropped_heatmap.shape[0], self.final_img_size, self.final_img_size),
|
|
dtype=np.float32,
|
|
)
|
|
y_offset = (self.final_img_size - resized_h) // 2
|
|
x_offset = (self.final_img_size - resized_w) // 2
|
|
canvas[:, y_offset:y_offset + resized_h, x_offset:x_offset + resized_w] = resized
|
|
return canvas
|
|
|
|
def _crop_and_resize(
|
|
self,
|
|
heatmap: np.ndarray,
|
|
crop_bounds: tuple[int, int, int, int] | None,
|
|
) -> np.ndarray:
|
|
raw_heatmap = heatmap
|
|
if crop_bounds is not None:
|
|
y_top, y_btm, x_left, x_right = crop_bounds
|
|
raw_heatmap = raw_heatmap[:, y_top:y_btm + 1, x_left:x_right]
|
|
if self.preserve_aspect_ratio:
|
|
return self._resize_and_pad(raw_heatmap)
|
|
|
|
return np.stack([
|
|
cv2.resize(channel, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
|
|
for channel in raw_heatmap
|
|
], axis=0)
|
|
|
|
def center_crop(self, heatmap):
|
|
"""
|
|
Input: [C, heatmap_image_size, heatmap_image_size]
|
|
Output: [C, final_img_size, final_img_size]
|
|
"""
|
|
crop_bounds = self._compute_crop_bounds(heatmap) if self.align else None
|
|
return self._crop_and_resize(heatmap, crop_bounds) # [C, final_img_size, final_img_size]
|
|
|
|
def __call__(self, heatmap_imgs):
|
|
"""
|
|
heatmap_imgs: (T, C, raw_size, raw_size)
|
|
return (T, C, final_img_size, final_img_size)
|
|
"""
|
|
original_dtype = heatmap_imgs.dtype
|
|
heatmap_imgs = heatmap_imgs.astype(np.float32) / 255.0
|
|
if self.align and self.scope == "sequence":
|
|
sequence_crop_bounds = self._compute_crop_bounds(heatmap_imgs.max(axis=0))
|
|
heatmap_imgs = np.array(
|
|
[self._crop_and_resize(heatmap_img, sequence_crop_bounds) for heatmap_img in heatmap_imgs],
|
|
dtype=np.float32,
|
|
)
|
|
else:
|
|
heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs], dtype=np.float32)
|
|
heatmap_imgs = heatmap_imgs * 255.0
|
|
if np.issubdtype(original_dtype, np.integer):
|
|
return np.clip(heatmap_imgs, 0.0, 255.0).astype(original_dtype)
|
|
return heatmap_imgs.astype(original_dtype)
|
|
|
|
def GenerateHeatmapTransform(
    coco18tococo17_args: dict[str, Any],
    padkeypoints_args: dict[str, Any],
    norm_args: dict[str, Any],
    heatmap_generator_args: dict[str, Any],
    align_args: dict[str, Any],
    reduction: Literal["upstream", "max", "sum"] = "upstream",
    sigma_limb: float | None = None,
    sigma_joint: float | None = None,
    channel_gain_limb: float | None = None,
    channel_gain_joint: float | None = None,
):
    """Build the full pose -> two-channel (bone, joint) heatmap transform.

    Args:
        coco18tococo17_args: keyword arguments for ``COCO18toCOCO17``.
        padkeypoints_args: keyword arguments for ``PadKeypoints``.
        norm_args: keyword arguments for ``CenterAndScaleNormalizer``.
        heatmap_generator_args: keyword arguments shared by both
            ``GeneratePoseTarget`` instances; ``with_limb`` / ``with_kp`` are
            overridden per branch and the dict itself is not modified.
        align_args: keyword arguments for ``HeatmapAlignment``.
        reduction: "upstream" uses ``HeatmapToImage``; any other value is
            passed to ``HeatmapReducer``.
        sigma_limb: optional sigma override for the bone branch.
        sigma_joint: optional sigma override for the joint branch.
        channel_gain_limb: optional gain for the bone channel (None means 1.0).
        channel_gain_joint: optional gain for the joint channel (None means 1.0).

    Returns:
        A ``T.Compose`` wrapping one ``GatherTransform``.
    """

    def _make_image_transform():
        # Turns the per-joint/per-limb stack into a single channel per frame.
        if reduction == "upstream":
            return HeatmapToImage()
        return HeatmapReducer(reduction=reduction)

    base_transform = T.Compose([
        COCO18toCOCO17(**coco18tococo17_args),
        PadKeypoints(**padkeypoints_args),
        CenterAndScaleNormalizer(**norm_args),
    ])

    # Independent copies so the per-branch overrides never leak into the caller's config.
    bone_generator_args = deepcopy(heatmap_generator_args)
    joint_generator_args = deepcopy(heatmap_generator_args)

    # Bone (limb) branch
    bone_generator_args.update(with_limb=True, with_kp=False)
    if sigma_limb is not None:
        bone_generator_args["sigma"] = sigma_limb
    bone_image_transform = _make_image_transform()
    transform_bone = T.Compose([
        GeneratePoseTarget(**bone_generator_args),
        bone_image_transform,
    ])

    # Joint (keypoint) branch
    joint_generator_args.update(with_limb=False, with_kp=True)
    if sigma_joint is not None:
        joint_generator_args["sigma"] = sigma_joint
    joint_image_transform = _make_image_transform()
    transform_joint = T.Compose([
        GeneratePoseTarget(**joint_generator_args),
        joint_image_transform,
    ])

    limb_gain = channel_gain_limb if channel_gain_limb is not None else 1.0
    joint_gain = channel_gain_joint if channel_gain_joint is not None else 1.0
    gather = GatherTransform(
        base_transform,
        transform_bone,
        transform_joint,
        HeatmapAlignment(**align_args),
        limb_gain=limb_gain,
        joint_gain=joint_gain,
    )  # [T, 2, H, W]

    return T.Compose([gather])
|
|
|
|
#########################################################################################################
|
|
# The following code handles distributed (DDP) processing of the dataset.
|
|
#########################################################################################################
|
|
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.
    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.

    Args:
        dataset: sized dataset to sample from.
        batch_size: per-replica batch size; each replica receives a multiple of it.
        rank: index of this replica; read from torch.distributed when None.
        num_replicas: number of replicas; read from torch.distributed when None.

    Raises:
        RuntimeError: if rank/num_replicas must be inferred but the
            distributed package is not available.
    """

    def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.batch_size = batch_size
        # Samples per replica, rounded up to a whole number of batches.
        batches_per_replica = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas))
        self.num_samples = batches_per_replica * self.batch_size
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # Add extra samples (repeats of the last index) to make it evenly
        # divisible. An empty dataset has nothing to repeat and needs no padding.
        if indices:
            indices += [indices[-1]] * (self.total_size - len(indices))
        # This replica's contiguous slice.
        first = self.rank * self.num_samples
        return iter(indices[first:first + self.num_samples])

    def __len__(self):
        return self.num_samples
|
|
|
|
|
|
class TransferDataset(Dataset):
    """Dataset whose ``__getitem__`` converts one pose file to heatmaps and saves them.

    Items carry no data (``__getitem__`` returns None); iterating the dataset
    through a DataLoader is only a way to run the conversion in parallel.
    """

    def __init__(self, args, generate_heatemap_cfgs) -> None:
        """
        args: parsed command-line options; reads pose_data_path, save_root,
            dataset_name and ext_name.
        generate_heatemap_cfgs: keyword arguments for GenerateHeatmapTransform;
            the configured sigma is also used in the output directory name.
        """
        super().__init__()
        pose_root = args.pose_data_path
        sigma = generate_heatemap_cfgs['heatmap_generator_args']['sigma']
        self.dataset_name = args.dataset_name
        assert self.dataset_name.lower() in ["sustech1k", "grew", "ccpg", "oumvlp", "ou-mvlp", "gait3d", "casiab", "casiae"], f"Invalid dataset name: {self.dataset_name}"
        self.save_root = os.path.join(args.save_root, f"{self.dataset_name}_sigma_{sigma}_{args.ext_name}")
        os.makedirs(self.save_root, exist_ok=True)

        self.heatmap_transform = GenerateHeatmapTransform(**generate_heatemap_cfgs)

        # Pose files live three directory levels below the root; for SUSTech1K
        # only files whose name starts with "03" are used.
        file_pattern = "*/*/*/03*.pkl" if self.dataset_name.lower() == "sustech1k" else "*/*/*/*.pkl"
        self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, file_pattern)))

    def __len__(self):
        return len(self.all_ps_data_paths)

    def __getitem__(self, index):
        """Convert pose file ``index`` to heatmaps, write them to disk and return None."""
        pose_path = self.all_ps_data_paths[index]
        # NOTE(security): pickle.load can execute arbitrary code; only run this
        # on pose files from a trusted source.
        with open(pose_path, "rb") as f:
            pose_data = pickle.load(f)
        if self.dataset_name.lower() == "grew":
            # Drop the first two columns and view the rest as (T, 17, 3).
            pose_data = pose_data[:, 2:].reshape(-1, 17, 3)

        # Split on the platform separator so the layout is also recovered
        # from paths that glob returns with backslashes.
        tmp_split = os.path.normpath(pose_path).split(os.sep)
        sub_dirs = tmp_split[-4:-1]  # the three directories above the file

        heatmap_img = self.heatmap_transform(pose_data)  # [T, 2, H, W]

        save_path_pkl = os.path.join(self.save_root, 'pkl', *sub_dirs)
        os.makedirs(save_path_pkl, exist_ok=True)

        # save some visualization (first ten sequences only)
        if index < 10:
            save_path_img = os.path.join(self.save_root, 'images', *sub_dirs)
            os.makedirs(save_path_img, exist_ok=True)
            for frame_index in range(heatmap_img.shape[0]):
                cv2.imwrite(os.path.join(save_path_img, f'bone_{frame_index}.jpg'), heatmap_img[frame_index, 0])
                cv2.imwrite(os.path.join(save_path_img, f'pose_{frame_index}.jpg'), heatmap_img[frame_index, 1])

        # Context manager guarantees the output file is flushed and closed.
        with open(os.path.join(save_path_pkl, tmp_split[-1]), 'wb') as f:
            pickle.dump(heatmap_img, f)
        return None
|
|
|
|
def mycollate(_):
    """Collate function that ignores the batch and always yields None."""
    return
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the heatmap generation script.

    Returns:
        argparse.Namespace with pose_data_path, save_root, ext_name,
        dataset_name, heatemap_cfg_path and local_rank.
    """
    parser = argparse.ArgumentParser(description='Utility for generating heatmaps from pose data.')
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ('--pose_data_path', dict(type=str, required=True, help="Path to the root directory containing pose data (.pkl files, ID-level) files.")),
        ('--save_root', dict(type=str, required=True, help="Root directory where generated heatmap .pkl files will be saved (ID-level).")),
        ('--ext_name', dict(type=str, default='', help="Extension name to be appended to the 'save_root' for identification.")),
        ('--dataset_name', dict(type=str, required=True, help="Name of the dataset being preprocessed.")),
        ('--heatemap_cfg_path', dict(type=str, default='configs/skeletongait/pretreatment_heatmap.yaml', help="Path to the heatmap generator configuration file.")),
        ("--local_rank", dict(type=int, default=0, help="Local rank for distributed processing, defaults to 0 for non-distributed setups.")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def replace_variables(data, context=None):
    """Resolve ``${a.b.c}`` placeholders in a (nested) configuration.

    Dicts are updated in place and returned, lists are rebuilt, and a string
    that consists entirely of one ``${...}`` placeholder is replaced by the
    value found by walking ``context`` along the dotted path. Every other
    value is returned unchanged.

    Args:
        data: configuration node (dict, list, str or any scalar).
        context: mapping the placeholders are looked up in; defaults to an
            empty dict.

    Returns:
        The node with all placeholders resolved.

    Raises:
        ValueError: if a placeholder path cannot be followed in ``context``,
            either because a key is missing or because the path runs through
            a value that cannot be indexed by that key.
    """
    if context is None:
        context = {}

    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = replace_variables(value, context)
        return data

    if isinstance(data, list):
        return [replace_variables(item, context) for item in data]

    if isinstance(data, str) and data.startswith('${') and data.endswith('}'):
        var_value = context
        try:
            for part in data[2:-1].split('.'):
                var_value = var_value[part]
        except (KeyError, TypeError, IndexError) as err:
            # TypeError/IndexError: the path continues below a scalar, list or
            # string; report it like a missing key instead of leaking the raw error.
            raise ValueError(f"Variable {data} not found in context") from err
        return var_value

    return data
|
|
|
|
if __name__ == "__main__":
    # Join the process group; rank and world size come from the environment.
    dist.init_process_group("nccl", init_method='env://')
    local_rank = dist.get_rank()
    world_size = dist.get_world_size()

    args = get_args()

    # Load the heatmap generator configuration and resolve its ${...} references
    # against the configuration itself.
    with open(args.heatemap_cfg_path, 'r') as stream:
        raw_cfgs = yaml.safe_load(stream)
    generate_heatemap_cfgs = replace_variables(raw_cfgs, raw_cfgs)

    # Create the dataset
    dataset = TransferDataset(args, generate_heatemap_cfgs)

    # Create the dataloader; every rank works on its own contiguous slice.
    dist_sampler = SequentialDistributedSampler(
        dataset,
        batch_size=1,
        rank=local_rank,
        num_replicas=world_size,
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=1,
        sampler=dist_sampler,
        num_workers=8,
        collate_fn=mycollate,
    )

    # Fetching an item performs the conversion and writes the result to disk,
    # so the loop body has nothing left to do.
    for _ in tqdm(dataloader, total=len(dataloader)):
        pass
|
|
|
|
|