Align DRF skeleton preprocessing with upstream heatmap path

This commit is contained in:
2026-03-08 14:50:35 +08:00
parent bbb41e8dd9
commit 295d951206
10 changed files with 174 additions and 21 deletions
+14 -1
View File
@@ -43,7 +43,7 @@ class DRF(BaseModelBody):
list[str],
list[str],
Int[torch.Tensor, "1 batch"] | None,
Float[torch.Tensor, "batch pairs metrics"],
Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
],
) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL, key_features = inputs
@@ -64,6 +64,7 @@ class DRF(BaseModelBody):
feat = self.HPP(outs)
embed_1 = self.FCs(feat)
key_features = canonicalize_pav(key_features)
embed_1 = self.PGA(embed_1, key_features)
embed_2, logits = self.BNNecks(embed_1)
@@ -123,3 +124,15 @@ LABEL_MAP: dict[str, int] = {
"neutral": 1,
"positive": 2,
}
def canonicalize_pav(
pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
if pav.ndim == 4:
if pav.shape[1] != 1:
raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
return pav.squeeze(1)
if pav.ndim != 3:
raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
return pav