feat: add drf author checkpoint compatibility bundle
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -21,6 +23,11 @@ from ..modules import (
|
||||
class DRF(BaseModelBody):
|
||||
"""Dual Representation Framework from arXiv:2509.00872v1."""
|
||||
|
||||
LEGACY_STATE_PREFIXES: dict[str, str] = {
|
||||
"attention_layer.": "PGA.",
|
||||
}
|
||||
CANONICAL_LABEL_ORDER: tuple[str, str, str] = ("negative", "neutral", "positive")
|
||||
|
||||
def build_network(self, model_cfg: dict[str, Any]) -> None:
|
||||
self.Backbone = self.get_backbone(model_cfg["backbone_cfg"])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
@@ -34,6 +41,13 @@ class DRF(BaseModelBody):
|
||||
num_pairs=model_cfg.get("num_pairs", 8),
|
||||
num_metrics=model_cfg.get("num_metrics", 3),
|
||||
)
|
||||
self.label_order = resolve_label_order(model_cfg.get("label_order"))
|
||||
self.label_map = {label: idx for idx, label in enumerate(self.label_order)}
|
||||
self.canonical_inference_logits = bool(model_cfg.get("canonical_inference_logits", True))
|
||||
self.logit_canonical_indices = torch.tensor(
|
||||
[self.label_map[label] for label in self.CANONICAL_LABEL_ORDER],
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -48,7 +62,7 @@ class DRF(BaseModelBody):
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
ipts, pids, labels, _, seqL, key_features = inputs
|
||||
label_ids = torch.as_tensor(
|
||||
[LABEL_MAP[str(label).lower()] for label in labels],
|
||||
[self.label_map[str(label).lower()] for label in labels],
|
||||
device=pids.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
@@ -69,6 +83,7 @@ class DRF(BaseModelBody):
|
||||
|
||||
embed_2, logits = self.BNNecks(embed_1)
|
||||
del embed_2
|
||||
inference_logits = self.canonicalize_logits(logits)
|
||||
|
||||
return {
|
||||
"training_feat": {
|
||||
@@ -79,10 +94,52 @@ class DRF(BaseModelBody):
|
||||
"image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
|
||||
},
|
||||
"inference_feat": {
|
||||
"embeddings": logits,
|
||||
"embeddings": inference_logits,
|
||||
},
|
||||
}
|
||||
|
||||
def canonicalize_logits(
    self,
    logits: Float[torch.Tensor, "batch classes parts"],
) -> Float[torch.Tensor, "batch classes parts"]:
    """Reorder the class axis of *logits* into canonical label order.

    Returns *logits* untouched when the configured label order already is
    canonical, or when canonicalization was disabled via
    ``canonical_inference_logits``. Otherwise the class dimension (dim=1)
    is gathered with the index tensor precomputed in ``build_network``.
    """
    already_canonical = tuple(self.label_order) == self.CANONICAL_LABEL_ORDER
    if already_canonical or not self.canonical_inference_logits:
        return logits
    # Index tensor lives wherever it was built; move it next to the logits.
    gather_idx = self.logit_canonical_indices.to(device=logits.device)
    return logits.index_select(dim=1, index=gather_idx)
|
||||
|
||||
@classmethod
def remap_legacy_state_dict(
    cls,
    state_dict: Mapping[str, torch.Tensor],
) -> OrderedDict[str, torch.Tensor]:
    """Map older author checkpoint names onto the current DRF module tree."""

    def _translate(name: str) -> str:
        # First matching legacy prefix wins; unmatched names pass through.
        for legacy, current in cls.LEGACY_STATE_PREFIXES.items():
            if name.startswith(legacy):
                return current + name[len(legacy) :]
        return name

    translated = OrderedDict(
        (_translate(name), tensor) for name, tensor in state_dict.items()
    )

    # torch state dicts carry versioning info on a `_metadata` attribute;
    # preserve it so version-aware loading keeps working.
    legacy_metadata = getattr(state_dict, "_metadata", None)
    if legacy_metadata is not None:
        translated._metadata = legacy_metadata
    return translated
|
||||
|
||||
def load_state_dict(
    self,
    state_dict: Mapping[str, torch.Tensor],
    strict: bool = True,
    assign: bool = False,
) -> Any:
    """Load *state_dict*, transparently remapping legacy checkpoint keys.

    Thin wrapper over ``nn.Module.load_state_dict`` that first passes the
    incoming dict through :meth:`remap_legacy_state_dict`, so older author
    checkpoints load without manual key surgery.
    """
    remapped = self.remap_legacy_state_dict(state_dict)
    return super().load_state_dict(remapped, strict=strict, assign=assign)
|
||||
|
||||
|
||||
class PAVGuidedAttention(nn.Module):
|
||||
channel_att: nn.Sequential
|
||||
@@ -119,11 +176,25 @@ class PAVGuidedAttention(nn.Module):
|
||||
return embeddings * channel_att * spatial_att
|
||||
|
||||
|
||||
# Default sentiment-label -> logit-index mapping, kept at module level for
# callers that need the canonical ordering without a DRF instance.
LABEL_MAP: dict[str, int] = dict(
    zip(("negative", "neutral", "positive"), range(3))
)
|
||||
def resolve_label_order(label_order_cfg: Any) -> tuple[str, str, str]:
    """Validate a configured label order, defaulting to the canonical one.

    Args:
        label_order_cfg: ``None`` (use canonical order) or a list/tuple that
            is a permutation of negative/neutral/positive (any case).

    Raises:
        TypeError: when the config value is not a list or tuple.
        ValueError: when it is not a permutation of the canonical labels.
    """
    if label_order_cfg is None:
        return DRF.CANONICAL_LABEL_ORDER

    if not isinstance(label_order_cfg, (list, tuple)):
        raise TypeError(
            "DRF model_cfg.label_order must be a list/tuple of "
            "['negative', 'neutral', 'positive'] in the desired logit order."
        )

    lowered = tuple(str(entry).lower() for entry in label_order_cfg)
    canonical_labels = set(DRF.CANONICAL_LABEL_ORDER)
    is_permutation = len(lowered) == 3 and set(lowered) == canonical_labels
    if not is_permutation:
        raise ValueError(
            "DRF model_cfg.label_order must contain exactly "
            "negative/neutral/positive once each; "
            f"got {label_order_cfg!r}"
        )
    return lowered
|
||||
|
||||
|
||||
def canonicalize_pav(
|
||||
|
||||
Reference in New Issue
Block a user