feat: add drf author checkpoint compatibility bundle

This commit is contained in:
2026-03-14 17:12:27 +08:00
parent d4e2a59ad2
commit 5f98844aff
18 changed files with 1144 additions and 8 deletions
+78 -7
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any
import torch
@@ -21,6 +23,11 @@ from ..modules import (
class DRF(BaseModelBody):
"""Dual Representation Framework from arXiv:2509.00872v1."""
LEGACY_STATE_PREFIXES: dict[str, str] = {
"attention_layer.": "PGA.",
}
CANONICAL_LABEL_ORDER: tuple[str, str, str] = ("negative", "neutral", "positive")
def build_network(self, model_cfg: dict[str, Any]) -> None:
self.Backbone = self.get_backbone(model_cfg["backbone_cfg"])
self.Backbone = SetBlockWrapper(self.Backbone)
@@ -34,6 +41,13 @@ class DRF(BaseModelBody):
num_pairs=model_cfg.get("num_pairs", 8),
num_metrics=model_cfg.get("num_metrics", 3),
)
self.label_order = resolve_label_order(model_cfg.get("label_order"))
self.label_map = {label: idx for idx, label in enumerate(self.label_order)}
self.canonical_inference_logits = bool(model_cfg.get("canonical_inference_logits", True))
self.logit_canonical_indices = torch.tensor(
[self.label_map[label] for label in self.CANONICAL_LABEL_ORDER],
dtype=torch.long,
)
def forward(
self,
@@ -48,7 +62,7 @@ class DRF(BaseModelBody):
) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL, key_features = inputs
label_ids = torch.as_tensor(
[LABEL_MAP[str(label).lower()] for label in labels],
[self.label_map[str(label).lower()] for label in labels],
device=pids.device,
dtype=torch.long,
)
@@ -69,6 +83,7 @@ class DRF(BaseModelBody):
embed_2, logits = self.BNNecks(embed_1)
del embed_2
inference_logits = self.canonicalize_logits(logits)
return {
"training_feat": {
@@ -79,10 +94,52 @@ class DRF(BaseModelBody):
"image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
},
"inference_feat": {
"embeddings": logits,
"embeddings": inference_logits,
},
}
def canonicalize_logits(
    self,
    logits: Float[torch.Tensor, "batch classes parts"],
) -> Float[torch.Tensor, "batch classes parts"]:
    """Reorder the class axis of ``logits`` into canonical label order.

    Returns the tensor untouched when canonical inference output is
    disabled, or when the configured label order already matches
    ``CANONICAL_LABEL_ORDER``; otherwise gathers dim 1 through the
    precomputed ``logit_canonical_indices`` permutation.
    """
    already_canonical = tuple(self.label_order) == self.CANONICAL_LABEL_ORDER
    if already_canonical or not self.canonical_inference_logits:
        return logits
    # Indices were built on CPU at construction time; move them lazily
    # to wherever the logits live.
    order = self.logit_canonical_indices.to(device=logits.device)
    return logits.index_select(dim=1, index=order)
@classmethod
def remap_legacy_state_dict(
    cls,
    state_dict: Mapping[str, torch.Tensor],
) -> OrderedDict[str, torch.Tensor]:
    """Map older author checkpoint names onto the current DRF module tree.

    Each key is rewritten through at most one matching entry of
    ``LEGACY_STATE_PREFIXES`` (first match wins); tensor values are
    untouched. The ``_metadata`` attribute, when present on the input,
    is carried over so that version-aware loading still works.
    """
    def _translate(key: str) -> str:
        # First matching legacy prefix wins; otherwise keep the key as-is.
        for legacy_prefix, current_prefix in cls.LEGACY_STATE_PREFIXES.items():
            if key.startswith(legacy_prefix):
                return current_prefix + key[len(legacy_prefix):]
        return key

    converted: OrderedDict[str, torch.Tensor] = OrderedDict(
        (_translate(key), tensor) for key, tensor in state_dict.items()
    )
    meta = getattr(state_dict, "_metadata", None)
    if meta is not None:
        converted._metadata = meta
    return converted
def load_state_dict(
    self,
    state_dict: Mapping[str, torch.Tensor],
    strict: bool = True,
    assign: bool = False,
) -> Any:
    """Load weights, remapping legacy checkpoint key names first.

    Thin wrapper over the parent ``load_state_dict``: the incoming
    mapping is passed through :meth:`remap_legacy_state_dict` so that
    older author checkpoints resolve against the current module tree.
    """
    compatible = self.remap_legacy_state_dict(state_dict)
    return super().load_state_dict(compatible, strict=strict, assign=assign)
class PAVGuidedAttention(nn.Module):
channel_att: nn.Sequential
@@ -119,11 +176,25 @@ class PAVGuidedAttention(nn.Module):
return embeddings * channel_att * spatial_att
# Canonical sentiment-label -> logit-index mapping, in the order given by
# DRF.CANONICAL_LABEL_ORDER. Kept as a module-level default; instances may
# carry their own per-config mapping.
LABEL_MAP: dict[str, int] = {
    "negative": 0,
    "neutral": 1,
    "positive": 2,
}
def resolve_label_order(label_order_cfg: Any) -> tuple[str, str, str]:
    """Validate a configured label order, defaulting to the canonical one.

    ``None`` yields ``DRF.CANONICAL_LABEL_ORDER``. Anything other than a
    list/tuple raises ``TypeError``; a sequence that is not a permutation
    of negative/neutral/positive (case-insensitive) raises ``ValueError``.
    """
    if label_order_cfg is None:
        return DRF.CANONICAL_LABEL_ORDER
    if not isinstance(label_order_cfg, list | tuple):
        raise TypeError(
            "DRF model_cfg.label_order must be a list/tuple of "
            "['negative', 'neutral', 'positive'] in the desired logit order."
        )
    lowered = tuple(str(entry).lower() for entry in label_order_cfg)
    # Multiset equality: exactly three entries, one of each canonical label.
    if sorted(lowered) != sorted(DRF.CANONICAL_LABEL_ORDER):
        raise ValueError(
            "DRF model_cfg.label_order must contain exactly "
            "negative/neutral/positive once each; "
            f"got {label_order_cfg!r}"
        )
    return lowered
def canonicalize_pav(