from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float, Int

from ..base_model import BaseModel
from ..modules import (
    HorizontalPoolingPyramid,
    PackSequenceWrapper,
    SeparateBNNecks,
    SeparateFCs,
    SetBlockWrapper,
)


class DRF(BaseModel):
    """Dual Representation Framework from arXiv:2509.00872v1.

    The model consumes two inputs per sample: a heatmap sequence that is
    passed through the backbone, and a second per-frame feature sequence
    (called ``pav`` in this file) that is averaged over time and used by
    :class:`PAVGuidedAttention` to gate the part embeddings.
    """

    def build_network(self, model_cfg: dict[str, Any]) -> None:
        """Create all sub-modules from ``model_cfg``.

        Keys read from ``model_cfg``: ``backbone_cfg``, ``SeparateFCs``
        (also its ``out_channels`` and ``parts_num`` entries),
        ``SeparateBNNecks``, ``bin_num`` and, optionally, ``num_pairs``
        (default 8) and ``num_metrics`` (default 3).
        """
        # NOTE: attribute names and construction order are left exactly as
        # they were so checkpoint keys and seeded initialisation stay stable.
        self.Backbone = SetBlockWrapper(self.get_backbone(model_cfg["backbone_cfg"]))

        fc_cfg = model_cfg["SeparateFCs"]
        self.FCs = SeparateFCs(**fc_cfg)
        self.BNNecks = SeparateBNNecks(**model_cfg["SeparateBNNecks"])

        # Temporal pooling: reduce with torch.max (the axis is given at call time).
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg["bin_num"])

        # The attention gate is sized to match what SeparateFCs is configured with.
        self.PGA = PAVGuidedAttention(
            in_channels=fc_cfg["out_channels"],
            parts_num=fc_cfg["parts_num"],
            num_pairs=model_cfg.get("num_pairs", 8),
            num_metrics=model_cfg.get("num_metrics", 3),
        )

    def forward(
        self,
        inputs: tuple[
            list[torch.Tensor],
            Int[torch.Tensor, "batch"],
            list[str],
            list[str],
            Int[torch.Tensor, "1 batch"] | None,
        ],
    ) -> dict[str, dict[str, Any]]:
        """Run the network on one batch.

        Args:
            inputs: A 5-tuple ``(ipts, pids, labels, _, seqL)``. ``ipts[0]``
                holds the heatmaps and ``ipts[1]`` the per-frame feature
                sequence; ``labels`` are strings that must be keys of
                ``LABEL_MAP`` after lower-casing; the fourth element is
                ignored; ``seqL`` holds per-sequence lengths or is ``None``.

        Returns:
            A dict with three entries: ``training_feat`` (triplet embeddings
            with ``pids`` and softmax logits with the mapped class ids),
            ``visual_summary`` (the heatmaps flattened to one image per
            frame) and ``inference_feat`` (the logits).

        Raises:
            KeyError: If a label is not present in ``LABEL_MAP``.
        """
        ipts, pids, labels, _, seqL = inputs

        # Translate the string labels into class indices on the same device as pids.
        class_indices = [LABEL_MAP[str(label).lower()] for label in labels]
        label_ids = torch.as_tensor(
            class_indices,
            device=pids.device,
            dtype=torch.long,
        )

        # 4-D input gets a new axis at position 1; anything else is expected
        # to be "n s c h w" and is reordered so that axis 2 is the sequence.
        heatmaps = ipts[0]
        heatmaps = (
            heatmaps.unsqueeze(1)
            if heatmaps.ndim == 4
            else rearrange(heatmaps, "n s c h w -> n c s h w")
        )

        # One time-averaged feature block per sequence.
        pav = aggregate_sequence_features(ipts[1], seqL)

        backbone_out = self.Backbone(heatmaps)
        pooled = self.TP(backbone_out, seqL, options={"dim": 2})[0]
        part_feat = self.HPP(pooled)

        embed_1 = self.PGA(self.FCs(part_feat), pav)

        # Only the logits are used; the first output of BNNecks is dropped right away.
        _bn_feat, logits = self.BNNecks(embed_1)
        del _bn_feat

        training_feat = {
            "triplet": {"embeddings": embed_1, "labels": pids},
            "softmax": {"logits": logits, "labels": label_ids},
        }
        visual_summary = {
            "image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
        }
        inference_feat = {
            "embeddings": logits,
        }
        return {
            "training_feat": training_feat,
            "visual_summary": visual_summary,
            "inference_feat": inference_feat,
        }


class PAVGuidedAttention(nn.Module):
    """Gate part embeddings with sigmoid weights derived from ``pav``.

    Two gates are computed from the flattened ``pav`` tensor: one value per
    channel and one value per part. Both lie in (0, 1) and are multiplied
    onto the embeddings via broadcasting.
    """

    channel_att: nn.Sequential
    spatial_att: nn.Sequential

    def __init__(
        self,
        in_channels: int = 256,
        parts_num: int = 16,
        num_pairs: int = 8,
        num_metrics: int = 3,
    ) -> None:
        """Build both gating branches.

        Args:
            in_channels: Number of outputs of the channel gate.
            parts_num: Number of outputs of the part gate.
            num_pairs: Together with ``num_metrics`` fixes the flattened
                input width (``num_pairs * num_metrics``).
            num_metrics: See ``num_pairs``.
        """
        super().__init__()
        flat_width = num_pairs * num_metrics

        # Channel branch first, part branch second (creation order is kept).
        self.channel_att = nn.Sequential(
            nn.Linear(flat_width, in_channels),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv1d(flat_width, parts_num, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        embeddings: Float[torch.Tensor, "batch channels parts"],
        pav: Float[torch.Tensor, "batch pairs metrics"],
    ) -> Float[torch.Tensor, "batch channels parts"]:
        """Return ``embeddings`` scaled by the channel gate and the part gate."""
        flat = torch.flatten(pav, start_dim=1)

        # (batch, channels) -> (batch, channels, 1)
        per_channel = self.channel_att(flat)[..., None]

        # The 1x1 convolution sees a length-1 signal: (batch, width, 1) ->
        # (batch, parts, 1); swapping the last two axes gives (batch, 1, parts).
        per_part = self.spatial_att(flat[..., None]).permute(0, 2, 1)

        return embeddings * per_channel * per_part


def aggregate_sequence_features(
    sequence_features: Float[torch.Tensor, "batch seq pairs metrics"],
    seqL: Int[torch.Tensor, "1 batch"] | None,
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Average per-frame features over time, one result per sequence.

    Args:
        sequence_features: Per-frame features. Without ``seqL`` the time axis
            is axis 1. With ``seqL`` a leading axis of size 1 is squeezed away
            and axis 0 of the result is treated as the concatenation of all
            sequences' frames.
        seqL: ``None``, or a tensor whose first row lists the number of
            frames belonging to each consecutive sequence.

    Returns:
        The stacked per-sequence means.
    """
    if seqL is None:
        return sequence_features.mean(dim=1)

    frames = sequence_features.squeeze(0)

    # Walk through the concatenated frames, consuming one sequence at a time.
    per_sequence_means = []
    offset = 0
    for count in seqL[0].tolist():
        stop = offset + int(count)
        per_sequence_means.append(frames[offset:stop].mean(dim=0))
        offset = stop

    return torch.stack(per_sequence_means, dim=0)


# Class index assigned to each (lower-cased) label string.
LABEL_MAP: dict[str, int] = {
    "negative": 0,
    "neutral": 1,
    "positive": 2,
}