import torch

from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
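

# Baseline silhouette-based gait model: a frame-level CNN backbone, temporal
# max pooling over each sequence, a Horizontal Pooling Pyramid (HPP), and
# part-wise FC + BNNeck heads producing triplet embeddings and softmax logits.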
class Baseline(BaseModel):

    def build_network(self, model_cfg):
        # Frame-level backbone, wrapped so it runs on every frame in the set
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        self.Backbone = SetBlockWrapper(self.Backbone)
        # Part-wise heads: one FC per horizontal strip, then BNNeck + classifier
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        # Temporal Pooling (max over frames) and Horizontal Pooling Pyramid
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
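        # Note: SetBlockWrapper folds the temporal dimension into the batch
        # ([n, c, s, h, w] -> [n*s, c, h, w]) so the 2D backbone runs frame by
        # frame; PackSequenceWrapper applies the pooling function (torch.max)
        # over each sequence, respecting the per-sample lengths in seqL.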

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]
        if len(sils.size()) == 4:
            # Add a channel dimension: [n, s, h, w] -> [n, 1, s, h, w]
            sils = sils.unsqueeze(1)

        del ipts
        outs = self.Backbone(sils)  # [n, c, s, h, w]

        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        feat = self.HPP(outs)  # [n, c, p]

        embed_1 = self.FCs(feat)  # [n, c, p]
        embed_2, logits = self.BNNecks(embed_1)  # [n, c, p]
        # embed_2 (the post-BN feature from the BNNeck) is returned but unused
        # here; retrieval uses the pre-BN embedding embed_1.
        embed = embed_1

        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval
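

# For reference only: a minimal sketch of the horizontal-pyramid pooling idea
# behind HorizontalPoolingPyramid. This hypothetical helper is an illustration,
# not the module imported from ..modules, and the default bin sizes below are
# assumed; the real ones come from model_cfg['bin_num']. Each scale splits the
# feature map into b horizontal strips and reduces every strip by max + mean
# pooling, then the parts from all scales are concatenated.
def hpp_sketch(x, bin_num=(16, 8, 4, 2, 1)):
    """x: [n, c, h, w] -> [n, c, p] with p = sum(bin_num)."""
    n, c = x.size()[:2]
    feats = []
    for b in bin_num:
        z = x.view(n, c, b, -1)                  # b strips, each flattened
        feats.append(z.max(-1)[0] + z.mean(-1))  # max + mean per strip: [n, c, b]
    return torch.cat(feats, dim=-1)              # [n, c, sum(bin_num)]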