from __future__ import annotations

from collections import OrderedDict
from collections.abc import Mapping
from typing import Any

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float, Int

from ..base_model_body import BaseModelBody
from ..modules import (
    HorizontalPoolingPyramid,
    PackSequenceWrapper,
    SeparateBNNecks,
    SeparateFCs,
    SetBlockWrapper,
)


class DRF(BaseModelBody):
    """Dual Representation Framework from arXiv:2509.00872v1."""

    # Parameter-name prefixes used by older author checkpoints, mapped onto
    # the current module tree (see remap_legacy_state_dict below).
    LEGACY_STATE_PREFIXES: dict[str, str] = {
        "attention_layer.": "PGA.",
    }
    CANONICAL_LABEL_ORDER: tuple[str, str, str] = ("negative", "neutral", "positive")

    def build_network(self, model_cfg: dict[str, Any]) -> None:
        self.Backbone = self.get_backbone(model_cfg["backbone_cfg"])
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.FCs = SeparateFCs(**model_cfg["SeparateFCs"])
        self.BNNecks = SeparateBNNecks(**model_cfg["SeparateBNNecks"])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg["bin_num"])
        self.PGA = PAVGuidedAttention(
            in_channels=model_cfg["SeparateFCs"]["out_channels"],
            parts_num=model_cfg["SeparateFCs"]["parts_num"],
            num_pairs=model_cfg.get("num_pairs", 8),
            num_metrics=model_cfg.get("num_metrics", 3),
        )
        self.label_order = resolve_label_order(model_cfg.get("label_order"))
        self.label_map = {label: idx for idx, label in enumerate(self.label_order)}
        self.canonical_inference_logits = bool(
            model_cfg.get("canonical_inference_logits", True)
        )
        self.logit_canonical_indices = torch.tensor(
            [self.label_map[label] for label in self.CANONICAL_LABEL_ORDER],
            dtype=torch.long,
        )

    def forward(
        self,
        inputs: tuple[
            list[torch.Tensor],
            Int[torch.Tensor, "batch"],
            list[str],
            list[str],
            Int[torch.Tensor, "1 batch"] | None,
            Float[torch.Tensor, "batch _ pairs metrics"]
            | Float[torch.Tensor, "batch pairs metrics"],
        ],
    ) -> dict[str, dict[str, Any]]:
        ipts, pids, labels, _, seqL, key_features = inputs
        label_ids = torch.as_tensor(
            [self.label_map[str(label).lower()] for label in labels],
            device=pids.device,
            dtype=torch.long,
        )

        heatmaps = ipts[0]
        if heatmaps.ndim == 4:
            # [n, s, h, w] -> [n, 1, s, h, w]: add a singleton channel axis.
            heatmaps = heatmaps.unsqueeze(1)
        else:
            heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")

        outs = self.Backbone(heatmaps)
        # Temporal max-pooling over the sequence dim, then horizontal pyramid pooling.
        outs = self.TP(outs, seqL, options={"dim": 2})[0]
        feat = self.HPP(outs)
        embed_1 = self.FCs(feat)

        key_features = canonicalize_pav(key_features)
        embed_1 = self.PGA(embed_1, key_features)
        # SeparateBNNecks returns (bn_features, logits); only the logits are used here.
        _, logits = self.BNNecks(embed_1)
        inference_logits = self.canonicalize_logits(logits)

        return {
            "training_feat": {
                "triplet": {"embeddings": embed_1, "labels": pids},
                "softmax": {"logits": logits, "labels": label_ids},
            },
            "visual_summary": {
                "image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
            },
            "inference_feat": {
                "embeddings": inference_logits,
            },
        }

    def canonicalize_logits(
        self,
        logits: Float[torch.Tensor, "batch classes parts"],
    ) -> Float[torch.Tensor, "batch classes parts"]:
        if (
            not self.canonical_inference_logits
            or tuple(self.label_order) == self.CANONICAL_LABEL_ORDER
        ):
            return logits
        indices = self.logit_canonical_indices.to(device=logits.device)
        return logits.index_select(dim=1, index=indices)

    @classmethod
    def remap_legacy_state_dict(
        cls,
        state_dict: Mapping[str, torch.Tensor],
    ) -> OrderedDict[str, torch.Tensor]:
        """Map older author checkpoint names onto the current DRF module tree."""
        remapped_state = OrderedDict[str, torch.Tensor]()
        for key, value in state_dict.items():
            remapped_key = key
            for old_prefix, new_prefix in cls.LEGACY_STATE_PREFIXES.items():
                if remapped_key.startswith(old_prefix):
                    remapped_key = new_prefix + remapped_key[len(old_prefix) :]
                    break
            remapped_state[remapped_key] = value
        # Preserve version metadata so strict loading keeps working.
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            setattr(remapped_state, "_metadata", metadata)
        return remapped_state

    def load_state_dict(
        self,
        state_dict: Mapping[str, torch.Tensor],
        strict: bool = True,
        assign: bool = False,
    ) -> Any:
        return super().load_state_dict(
            self.remap_legacy_state_dict(state_dict),
            strict=strict,
            assign=assign,
        )


class PAVGuidedAttention(nn.Module):
    channel_att: nn.Sequential
    spatial_att: nn.Sequential

    def __init__(
        self,
        in_channels: int = 256,
        parts_num: int = 16,
        num_pairs: int = 8,
        num_metrics: int = 3,
    ) -> None:
        super().__init__()
        pav_dim = num_pairs * num_metrics
        # Channel gate: flattened PAV vector -> per-channel weights in (0, 1).
        self.channel_att = nn.Sequential(
            nn.Linear(pav_dim, in_channels),
            nn.Sigmoid(),
        )
        # Spatial gate: flattened PAV vector -> per-part weights in (0, 1).
        self.spatial_att = nn.Sequential(
            nn.Conv1d(pav_dim, parts_num, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        embeddings: Float[torch.Tensor, "batch channels parts"],
        pav: Float[torch.Tensor, "batch pairs metrics"],
    ) -> Float[torch.Tensor, "batch channels parts"]:
        pav_flat = pav.flatten(1)  # [n, pairs * metrics]
        channel_att = self.channel_att(pav_flat).unsqueeze(-1)  # [n, c, 1]
        spatial_att = self.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)  # [n, 1, p]
        # Both gates broadcast against [n, c, p].
        return embeddings * channel_att * spatial_att


def resolve_label_order(label_order_cfg: Any) -> tuple[str, str, str]:
    if label_order_cfg is None:
        return DRF.CANONICAL_LABEL_ORDER
    if not isinstance(label_order_cfg, list | tuple):
        raise TypeError(
            "DRF model_cfg.label_order must be a list/tuple of "
            "['negative', 'neutral', 'positive'] in the desired logit order."
        )
    normalized_order = tuple(str(label).lower() for label in label_order_cfg)
    expected = set(DRF.CANONICAL_LABEL_ORDER)
    if len(normalized_order) != 3 or set(normalized_order) != expected:
        raise ValueError(
            "DRF model_cfg.label_order must contain exactly "
            "negative/neutral/positive once each; "
            f"got {label_order_cfg!r}"
        )
    return normalized_order


def canonicalize_pav(
    pav: Float[torch.Tensor, "batch _ pairs metrics"]
    | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
    if pav.ndim == 4:
        if pav.shape[1] != 1:
            raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
        return pav.squeeze(1)
    if pav.ndim != 3:
        raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
    return pav
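

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch for the standalone pieces of this module. This is
# an illustration, not part of the published model: DRF itself needs the full
# package plumbing (backbone config, BaseModelBody), so only PAVGuidedAttention
# and the helper functions are exercised here, with made-up batch sizes and
# shapes. Run via `python -m <package>.drf` so the relative imports resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n, c, p, pairs, metrics = 4, 256, 16, 8, 3

    # PAV-guided attention: the two gates only rescale, so the output shape
    # must match the input embedding shape [n, c, p].
    pga = PAVGuidedAttention(
        in_channels=c, parts_num=p, num_pairs=pairs, num_metrics=metrics
    )
    embeddings = torch.randn(n, c, p)
    pav = torch.randn(n, pairs, metrics)
    assert pga(embeddings, pav).shape == (n, c, p)

    # canonicalize_pav: a singleton axis at dim 1 is squeezed away.
    assert canonicalize_pav(torch.randn(n, 1, pairs, metrics)).shape == (n, pairs, metrics)

    # resolve_label_order: any casing/permutation of the three labels is
    # accepted and normalized to lowercase.
    assert resolve_label_order(["Positive", "neutral", "negative"]) == (
        "positive",
        "neutral",
        "negative",
    )

    # remap_legacy_state_dict: legacy "attention_layer." keys move to "PGA.".
    legacy = OrderedDict({"attention_layer.channel_att.0.weight": torch.zeros(1)})
    assert "PGA.channel_att.0.weight" in DRF.remap_legacy_state_dict(legacy)

    print("DRF module smoke test passed.")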