OpenGait/opengait/modeling/models/drf.py

from __future__ import annotations
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any
import torch
import torch.nn as nn
from jaxtyping import Float, Int
from einops import rearrange
from ..base_model_body import BaseModelBody
from ..modules import (
HorizontalPoolingPyramid,
PackSequenceWrapper,
SeparateBNNecks,
SeparateFCs,
SetBlockWrapper,
)


class DRF(BaseModelBody):
"""Dual Representation Framework from arXiv:2509.00872v1."""
LEGACY_STATE_PREFIXES: dict[str, str] = {
"attention_layer.": "PGA.",
}
CANONICAL_LABEL_ORDER: tuple[str, str, str] = ("negative", "neutral", "positive")

    def build_network(self, model_cfg: dict[str, Any]) -> None:
self.Backbone = self.get_backbone(model_cfg["backbone_cfg"])
self.Backbone = SetBlockWrapper(self.Backbone)
self.FCs = SeparateFCs(**model_cfg["SeparateFCs"])
self.BNNecks = SeparateBNNecks(**model_cfg["SeparateBNNecks"])
self.TP = PackSequenceWrapper(torch.max)
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg["bin_num"])
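        # PAV-guided attention: gates the pooled part embeddings with the
        # auxiliary key-feature tensor (num_pairs x num_metrics values per sample).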
self.PGA = PAVGuidedAttention(
in_channels=model_cfg["SeparateFCs"]["out_channels"],
parts_num=model_cfg["SeparateFCs"]["parts_num"],
num_pairs=model_cfg.get("num_pairs", 8),
num_metrics=model_cfg.get("num_metrics", 3),
)
self.label_order = resolve_label_order(model_cfg.get("label_order"))
self.label_map = {label: idx for idx, label in enumerate(self.label_order)}
self.canonical_inference_logits = bool(model_cfg.get("canonical_inference_logits", True))
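        # Gather indices that reorder logits into canonical label order; kept on
        # CPU here and moved to the logits' device when used.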
self.logit_canonical_indices = torch.tensor(
[self.label_map[label] for label in self.CANONICAL_LABEL_ORDER],
dtype=torch.long,
)
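
    # forward() consumes an extended OpenGait batch tuple; the sixth element
    # carries the per-sample PAV key features that feed the PGA module.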
def forward(
self,
inputs: tuple[
list[torch.Tensor],
Int[torch.Tensor, "batch"],
list[str],
list[str],
Int[torch.Tensor, "1 batch"] | None,
Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
],
) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL, key_features = inputs
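        # Map string class labels (negative/neutral/positive) onto integer ids
        # in the configured label order.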
label_ids = torch.as_tensor(
[self.label_map[str(label).lower()] for label in labels],
device=pids.device,
dtype=torch.long,
)
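        # The first input stream holds the frames; accept both (n, s, h, w) and
        # (n, s, c, h, w) layouts and normalize to (n, c, s, h, w) for the backbone.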
heatmaps = ipts[0]
if heatmaps.ndim == 4:
heatmaps = heatmaps.unsqueeze(1)
else:
heatmaps = rearrange(heatmaps, "n s c h w -> n c s h w")
outs = self.Backbone(heatmaps)
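        # Temporal pooling: torch.max over the sequence dim returns (values, indices);
        # keep the values, then build the horizontal pooling pyramid per part.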
outs = self.TP(outs, seqL, options={"dim": 2})[0]
feat = self.HPP(outs)
embed_1 = self.FCs(feat)
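        # Normalize the PAV tensor to (batch, pairs, metrics) and apply the
        # PAV-guided attention gates to the part embeddings.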
key_features = canonicalize_pav(key_features)
embed_1 = self.PGA(embed_1, key_features)
        _, logits = self.BNNecks(embed_1)  # the BN-neck embedding is unused here
inference_logits = self.canonicalize_logits(logits)
return {
"training_feat": {
"triplet": {"embeddings": embed_1, "labels": pids},
"softmax": {"logits": logits, "labels": label_ids},
},
"visual_summary": {
"image/sils": rearrange(heatmaps, "n c s h w -> (n s) c h w"),
},
"inference_feat": {
"embeddings": inference_logits,
},
}

    def canonicalize_logits(
self,
logits: Float[torch.Tensor, "batch classes parts"],
) -> Float[torch.Tensor, "batch classes parts"]:
        if not self.canonical_inference_logits or self.label_order == self.CANONICAL_LABEL_ORDER:
return logits
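        # Gather class logits into canonical (negative, neutral, positive) order.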
indices = self.logit_canonical_indices.to(device=logits.device)
return logits.index_select(dim=1, index=indices)

    @classmethod
def remap_legacy_state_dict(
cls,
state_dict: Mapping[str, torch.Tensor],
) -> OrderedDict[str, torch.Tensor]:
"""Map older author checkpoint names onto the current DRF module tree."""
remapped_state = OrderedDict[str, torch.Tensor]()
for key, value in state_dict.items():
remapped_key = key
for old_prefix, new_prefix in cls.LEGACY_STATE_PREFIXES.items():
if remapped_key.startswith(old_prefix):
remapped_key = new_prefix + remapped_key[len(old_prefix) :]
break
remapped_state[remapped_key] = value
metadata = getattr(state_dict, "_metadata", None)
if metadata is not None:
setattr(remapped_state, "_metadata", metadata)
return remapped_state

    def load_state_dict(
self,
state_dict: Mapping[str, torch.Tensor],
strict: bool = True,
assign: bool = False,
) -> Any:
return super().load_state_dict(
self.remap_legacy_state_dict(state_dict),
strict=strict,
assign=assign,
)


class PAVGuidedAttention(nn.Module):
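    """Gates part embeddings with channel-wise and part-wise attention derived
    from the flattened PAV key features."""
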
channel_att: nn.Sequential
spatial_att: nn.Sequential

    def __init__(
self,
in_channels: int = 256,
parts_num: int = 16,
num_pairs: int = 8,
num_metrics: int = 3,
) -> None:
super().__init__()
pav_dim = num_pairs * num_metrics
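        # Two independent gates: a per-channel gate from a linear projection of the
        # flattened PAV, and a per-part gate from a 1x1 convolution over the PAV axis.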
self.channel_att = nn.Sequential(
nn.Linear(pav_dim, in_channels),
nn.Sigmoid(),
)
self.spatial_att = nn.Sequential(
nn.Conv1d(pav_dim, parts_num, kernel_size=1),
nn.Sigmoid(),
)

    def forward(
self,
embeddings: Float[torch.Tensor, "batch channels parts"],
pav: Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch channels parts"]:
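        # Flatten PAV to (batch, pairs * metrics); the channel gate broadcasts over
        # parts and the part gate broadcasts over channels.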
pav_flat = pav.flatten(1)
channel_att = self.channel_att(pav_flat).unsqueeze(-1)
spatial_att = self.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
return embeddings * channel_att * spatial_att


def resolve_label_order(label_order_cfg: Any) -> tuple[str, str, str]:
if label_order_cfg is None:
return DRF.CANONICAL_LABEL_ORDER
if not isinstance(label_order_cfg, list | tuple):
raise TypeError(
"DRF model_cfg.label_order must be a list/tuple of "
"['negative', 'neutral', 'positive'] in the desired logit order."
)
normalized_order = tuple(str(label).lower() for label in label_order_cfg)
expected = set(DRF.CANONICAL_LABEL_ORDER)
if len(normalized_order) != 3 or set(normalized_order) != expected:
raise ValueError(
"DRF model_cfg.label_order must contain exactly "
"negative/neutral/positive once each; "
f"got {label_order_cfg!r}"
)
return normalized_order


def canonicalize_pav(
pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
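    # Accept either (batch, 1, pairs, metrics) or (batch, pairs, metrics); drop the
    # possible singleton axis before use.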
if pav.ndim == 4:
if pav.shape[1] != 1:
raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
return pav.squeeze(1)
if pav.ndim != 3:
raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
return pav
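

if __name__ == "__main__":
    # Minimal shape-check sketch. The batch size and the constructor arguments
    # below are illustrative assumptions (the module defaults), not values
    # mandated by the paper.
    pga = PAVGuidedAttention(in_channels=256, parts_num=16, num_pairs=8, num_metrics=3)
    embeddings = torch.randn(2, 256, 16)  # (batch, channels, parts)
    pav = canonicalize_pav(torch.randn(2, 1, 8, 3))  # squeeze the singleton axis
    assert pga(embeddings, pav).shape == (2, 256, 16)
    assert resolve_label_order(["Positive", "Neutral", "Negative"]) == (
        "positive",
        "neutral",
        "negative",
    )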