Align DRF skeleton preprocessing with upstream heatmap path
This commit is contained in:
@@ -43,7 +43,7 @@ class DRF(BaseModelBody):
|
||||
list[str],
|
||||
list[str],
|
||||
Int[torch.Tensor, "1 batch"] | None,
|
||||
Float[torch.Tensor, "batch pairs metrics"],
|
||||
Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
|
||||
],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
ipts, pids, labels, _, seqL, key_features = inputs
|
||||
@@ -64,6 +64,7 @@ class DRF(BaseModelBody):
|
||||
|
||||
feat = self.HPP(outs)
|
||||
embed_1 = self.FCs(feat)
|
||||
key_features = canonicalize_pav(key_features)
|
||||
embed_1 = self.PGA(embed_1, key_features)
|
||||
|
||||
embed_2, logits = self.BNNecks(embed_1)
|
||||
@@ -123,3 +124,15 @@ LABEL_MAP: dict[str, int] = {
|
||||
"neutral": 1,
|
||||
"positive": 2,
|
||||
}
|
||||
|
||||
|
||||
def canonicalize_pav(
|
||||
pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
|
||||
) -> Float[torch.Tensor, "batch pairs metrics"]:
|
||||
if pav.ndim == 4:
|
||||
if pav.shape[1] != 1:
|
||||
raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
|
||||
return pav.squeeze(1)
|
||||
if pav.ndim != 3:
|
||||
raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
|
||||
return pav
|
||||
|
||||
Reference in New Issue
Block a user