from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float, Int

from ..base_model_body import BaseModelBody
from ..modules import (
    HorizontalPoolingPyramid,
    PackSequenceWrapper,
    SeparateBNNecks,
    SeparateFCs,
    SetBlockWrapper,
)


class DRF(BaseModelBody):
    """Dual Representation Framework from arXiv:2509.00872v1."""

    def build_network(self, model_cfg: dict[str, Any]) -> None:
        self.Backbone = self.get_backbone(model_cfg["backbone_cfg"])
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.FCs = SeparateFCs(**model_cfg["SeparateFCs"])
        self.BNNecks = SeparateBNNecks(**model_cfg["SeparateBNNecks"])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg["bin_num"])
        self.PGA = PAVGuidedAttention(
            in_channels=model_cfg["SeparateFCs"]["out_channels"],
            parts_num=model_cfg["SeparateFCs"]["parts_num"],
            num_pairs=model_cfg.get("num_pairs", 8),
            num_metrics=model_cfg.get("num_metrics", 3),
        )

    def forward(
        self,
        inputs: tuple[
            list[torch.Tensor],
            Int[torch.Tensor, "batch"],
            list[str],
            list[str],
            Int[torch.Tensor, "1 batch"] | None,
            Float[torch.Tensor, "batch _ pairs metrics"]
            | Float[torch.Tensor, "batch pairs metrics"],
        ],
    ) -> dict[str, dict[str, Any]]:
        ipts, pids, labels, _, seqL, key_features = inputs

        # Map string labels to class indices for the softmax head.
        label_ids = torch.as_tensor(
            [LABEL_MAP[str(label).lower()] for label in labels],
            device=pids.device,
            dtype=torch.long,
        )

        heatmaps = ipts[0]
        if heatmaps.ndim == 4:
            # (n, s, h, w) -> add a singleton channel axis: (n, 1, s, h, w).
            heatmaps = heatmaps.unsqueeze(1)
        else:
            # (n, s, c, h, w) -> channels-first for the backbone: (n, c, s, h, w).
            heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")

        outs = self.Backbone(heatmaps)
        # Temporal pooling: max over the sequence axis (dim=2).
        outs = self.TP(outs, seqL, options={"dim": 2})[0]
        # Horizontal pyramid pooling -> part-level features (n, c, p).
        feat = self.HPP(outs)
        embed_1 = self.FCs(feat)

        # Modulate the part embeddings with the PAV key features.
        key_features = canonicalize_pav(key_features)
        embed_1 = self.PGA(embed_1, key_features)

        _, logits = self.BNNecks(embed_1)

        return {
            "training_feat": {
                "triplet": {"embeddings": embed_1, "labels": pids},
                "softmax": {"logits": logits, "labels": label_ids},
            },
            "visual_summary": {
                "image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
            },
            "inference_feat": {
                # Note: the classification logits, not embed_1, are exposed
                # at inference time.
                "embeddings": logits,
            },
        }


class PAVGuidedAttention(nn.Module):
    """Channel and part-wise attention conditioned on flattened PAV features."""

    channel_att: nn.Sequential
    spatial_att: nn.Sequential

    def __init__(
        self,
        in_channels: int = 256,
        parts_num: int = 16,
        num_pairs: int = 8,
        num_metrics: int = 3,
    ) -> None:
        super().__init__()
        pav_dim = num_pairs * num_metrics
        self.channel_att = nn.Sequential(
            nn.Linear(pav_dim, in_channels),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv1d(pav_dim, parts_num, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        embeddings: Float[torch.Tensor, "batch channels parts"],
        pav: Float[torch.Tensor, "batch pairs metrics"],
    ) -> Float[torch.Tensor, "batch channels parts"]:
        pav_flat = pav.flatten(1)  # (batch, pairs * metrics)
        # Channel gate: (batch, channels, 1), broadcast over parts.
        channel_att = self.channel_att(pav_flat).unsqueeze(-1)
        # Part gate: Conv1d over a length-1 axis, then transpose -> (batch, 1, parts).
        spatial_att = self.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        return embeddings * channel_att * spatial_att


LABEL_MAP: dict[str, int] = {
    "negative": 0,
    "neutral": 1,
    "positive": 2,
}


def canonicalize_pav(
    pav: Float[torch.Tensor, "batch _ pairs metrics"]
    | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Squeeze an optional singleton axis so PAV is always (batch, pairs, metrics)."""
    if pav.ndim == 4:
        if pav.shape[1] != 1:
            raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
        return pav.squeeze(1)
    if pav.ndim != 3:
        raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
    return pav
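
# ---------------------------------------------------------------------------
# Minimal shape sanity-check: an illustrative sketch, not part of the model.
# It exercises canonicalize_pav and PAVGuidedAttention with dummy tensors,
# assuming the default sizes used above (in_channels=256, parts_num=16,
# num_pairs=8, num_metrics=3). Because of the relative imports at the top,
# this only runs when invoked as a module inside its package
# (e.g. `python -m <package>.drf`; the exact module path is an assumption),
# not via `python drf.py`.
if __name__ == "__main__":
    batch = 4
    pga = PAVGuidedAttention(in_channels=256, parts_num=16, num_pairs=8, num_metrics=3)
    embeddings = torch.randn(batch, 256, 16)  # (batch, channels, parts)
    pav = torch.randn(batch, 1, 8, 3)         # PAV with the optional singleton axis
    out = pga(embeddings, canonicalize_pav(pav))
    # The attention gates only rescale, so the embedding shape is preserved.
    assert out.shape == embeddings.shape, tuple(out.shape)
    print("PAVGuidedAttention output shape:", tuple(out.shape))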