Align DRF skeleton preprocessing with upstream heatmap path
This commit is contained in:
@@ -47,7 +47,9 @@ class BaseModelBody(BaseModel):
|
||||
labs = list2var(labs_batch).long()
|
||||
seqL = np2var(seqL_batch).int() if seqL_batch is not None else None
|
||||
|
||||
body_features = aggregate_body_features(body_seq, seqL)
|
||||
# Preserve a singleton modality axis so DRF can mirror the author stub's
|
||||
# `squeeze(1)` behavior while still accepting the same sequence-level prior.
|
||||
body_features = aggregate_body_features(body_seq, seqL).unsqueeze(1)
|
||||
|
||||
if seqL is not None:
|
||||
seqL_sum = int(seqL.sum().data.cpu().numpy())
|
||||
@@ -80,3 +82,7 @@ def aggregate_body_features(
|
||||
aggregated.append(flattened[start:end].mean(dim=0))
|
||||
start = end
|
||||
return torch.stack(aggregated, dim=0)
|
||||
|
||||
|
||||
# Match the symbol name used by the author-provided DRF stub.
|
||||
BaseModel = BaseModelBody
|
||||
|
||||
@@ -43,7 +43,7 @@ class DRF(BaseModelBody):
|
||||
list[str],
|
||||
list[str],
|
||||
Int[torch.Tensor, "1 batch"] | None,
|
||||
Float[torch.Tensor, "batch pairs metrics"],
|
||||
Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
|
||||
],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
ipts, pids, labels, _, seqL, key_features = inputs
|
||||
@@ -64,6 +64,7 @@ class DRF(BaseModelBody):
|
||||
|
||||
feat = self.HPP(outs)
|
||||
embed_1 = self.FCs(feat)
|
||||
key_features = canonicalize_pav(key_features)
|
||||
embed_1 = self.PGA(embed_1, key_features)
|
||||
|
||||
embed_2, logits = self.BNNecks(embed_1)
|
||||
@@ -123,3 +124,15 @@ LABEL_MAP: dict[str, int] = {
|
||||
"neutral": 1,
|
||||
"positive": 2,
|
||||
}
|
||||
|
||||
|
||||
def canonicalize_pav(
|
||||
pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
|
||||
) -> Float[torch.Tensor, "batch pairs metrics"]:
|
||||
if pav.ndim == 4:
|
||||
if pav.shape[1] != 1:
|
||||
raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
|
||||
return pav.squeeze(1)
|
||||
if pav.ndim != 3:
|
||||
raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
|
||||
return pav
|
||||
|
||||
Reference in New Issue
Block a user