Align DRF skeleton preprocessing with upstream heatmap path

This commit is contained in:
2026-03-08 14:50:35 +08:00
parent bbb41e8dd9
commit 295d951206
10 changed files with 174 additions and 21 deletions
+15 -4
View File
@@ -8,6 +8,7 @@ import pickle
import argparse
import numpy as np
from glob import glob
from typing import Literal
from tqdm import tqdm
import matplotlib.cm as cm
import torch.distributed as dist
@@ -328,7 +329,7 @@ class HeatmapToImage:
class HeatmapReducer:
"""Reduce stacked joint/limb heatmaps to a single grayscale channel."""
def __init__(self, reduction: str = "max") -> None:
def __init__(self, reduction: Literal["max", "sum"] = "max") -> None:
if reduction not in {"max", "sum"}:
raise ValueError(f"Unsupported heatmap reduction: {reduction}")
self.reduction = reduction
@@ -574,7 +575,7 @@ def GenerateHeatmapTransform(
norm_args,
heatmap_generator_args,
align_args,
reduction="max",
reduction: Literal["upstream", "max", "sum"] = "upstream",
):
base_transform = T.Compose([
@@ -585,17 +586,27 @@ def GenerateHeatmapTransform(
heatmap_generator_args["with_limb"] = True
heatmap_generator_args["with_kp"] = False
bone_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction),
bone_image_transform,
HeatmapAlignment(**align_args)
])
heatmap_generator_args["with_limb"] = False
heatmap_generator_args["with_kp"] = True
joint_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction),
joint_image_transform,
HeatmapAlignment(**align_args)
])