83 lines
3.4 KiB
Python
83 lines
3.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
|
|
from .lidargaitv2_utils import PointNetSetAbstraction, PPPooling, PPPooling_UDP,NetVLAD
|
|
|
|
from ..base_model import BaseModel
|
|
from ..modules import SeparateFCs, SeparateBNNecks
|
|
|
|
|
|
class LidarGaitPlusPlus(BaseModel):
    """Point-cloud gait recognition model.

    Each frame's point cloud is encoded by four PointNetSetAbstraction
    stages, the last stage is aggregated by a configurable pooling head,
    frames are max-pooled over time, and SeparateFCs / SeparateBNNecks heads
    produce the embeddings and classification logits.
    """

    def build_network(self, model_cfg):
        """Instantiate the backbone, the pooling head and the output heads.

        Args:
            model_cfg: model configuration mapping.
                Required keys: ``channel``, ``scale_aware``, ``normalize_dp``,
                ``sampling``, ``pool``, ``SeparateFCs`` (whose
                ``in_channels`` sets the backbone output width) and
                ``SeparateBNNecks``; ``scale`` is additionally required for
                the ``PPP_UDP`` / ``PPP_UAP`` / ``PPP_HAP`` pooling modes.
                Optional keys: ``npoints`` (centroid counts of stages 1-3,
                default ``[512, 256, 128]``), ``nsample`` (neighbours per
                centroid, default 32), ``radius`` (ball-query radii of
                stages 1-3, default ``[0.1, 0.2, 0.4]``) and
                ``num_clusters`` (NetVLAD clusters, default 16).

        Raises:
            ValueError: if ``npoints`` or ``radius`` has fewer than three
                entries, or ``pool`` is not a supported pooling mode.
        """
        C = model_cfg['channel']
        C_out = model_cfg['SeparateFCs']['in_channels']
        scale_aware = model_cfg['scale_aware']
        normalize_dp = model_cfg['normalize_dp']
        sampling = model_cfg['sampling']

        npoints = model_cfg.get('npoints', [512, 256, 128])
        nsample = model_cfg.get('nsample', 32)
        radii = model_cfg.get('radius', [0.1, 0.2, 0.4])
        if len(npoints) < 3 or len(radii) < 3:
            raise ValueError(
                f"'npoints' and 'radius' need at least 3 entries "
                f"(got {len(npoints)} and {len(radii)})"
            )

        # Coordinate width fed to stage 1 and concatenated onto the features
        # of every later stage: 4 when scale_aware, otherwise 3
        # (presumably xyz plus a scale term -- TODO confirm).
        in_channel = 4 if scale_aware else 3

        # Arguments shared by every set-abstraction stage.
        sa_kwargs = dict(
            sampling=sampling,
            scale_aware=scale_aware,
            normalize_dp=normalize_dp,
        )
        self.sa1 = PointNetSetAbstraction(
            npoint=npoints[0], radius=radii[0], nsample=nsample,
            in_channel=in_channel, mlp=[2 * C, 2 * C, 4 * C],
            group_all=False, **sa_kwargs)
        self.sa2 = PointNetSetAbstraction(
            npoint=npoints[1], radius=radii[1], nsample=nsample,
            in_channel=4 * C + in_channel, mlp=[4 * C, 4 * C, 8 * C],
            group_all=False, **sa_kwargs)
        self.sa3 = PointNetSetAbstraction(
            npoint=npoints[2], radius=radii[2], nsample=nsample,
            in_channel=8 * C + in_channel, mlp=[8 * C, 8 * C, 16 * C],
            group_all=False, **sa_kwargs)
        # Final stage groups all remaining points (no sampling / ball query).
        self.sa4 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=16 * C + in_channel, mlp=[16 * C, 16 * C, C_out],
            group_all=True, **sa_kwargs)

        pool = model_cfg['pool']
        if pool == 'VLAD':
            self.pool = NetVLAD(
                num_clusters=model_cfg.get('num_clusters', 16),
                dim=C_out, alpha=1.0)
        elif pool == 'GMaxP':
            # PPPooling_UDP with a single bin.
            self.pool = PPPooling_UDP([1])
        elif pool == 'PPP_UDP':
            self.pool = PPPooling_UDP(model_cfg['scale'])
        elif pool == 'PPP_UAP':
            self.pool = PPPooling(scale_aware=False, bin_num=model_cfg['scale'])
        elif pool == 'PPP_HAP':
            self.pool = PPPooling(scale_aware=True, bin_num=model_cfg['scale'])
        else:
            # Without this, self.pool would silently stay unset and forward()
            # would fail later with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported pool type {pool!r}; expected one of "
                f"'VLAD', 'GMaxP', 'PPP_UDP', 'PPP_UAP', 'PPP_HAP'"
            )

        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])

    def forward(self, inputs):
        """Embed a batch of point-cloud sequences.

        Args:
            inputs: tuple ``(ipts, labs, _, views, seqL)``. ``ipts[0]`` is a
                4-D tensor laid out as ``B T N C`` (presumably batch, frames,
                points per frame, per-point channels -- TODO confirm).
                ``views`` and ``seqL`` are unpacked but not used.

        Returns:
            dict with ``training_feat`` (triplet embeddings and softmax
            logits, each paired with ``labs``), an empty ``visual_summary``
            and ``inference_feat`` holding the embeddings.
        """
        ipts, labs, _, views, seqL = inputs

        xyz = ipts[0]
        B = xyz.shape[0]
        # Fold frames into the batch so each frame is encoded independently;
        # the set-abstraction stages receive channels before points.
        xyz = rearrange(xyz, 'B T N C -> (B T) C N')

        # Stages 1-3: abstract, then max over dim -2 (presumably the
        # neighbourhood of each sampled centroid -- TODO confirm).
        points = None
        for sa in (self.sa1, self.sa2, self.sa3):
            xyz, points = sa(xyz, points)
            points = points.max(dim=-2).values

        # Stage 4 output is not reduced here; the pooling head consumes it
        # together with the stage-3 coordinates.
        _, points = self.sa4(xyz, points)
        x = self.pool(points, xyz)

        # Restore the (B, T) split and max-pool over frames.
        # NOTE(review): seqL is ignored, so any padded frames take part in
        # this max -- confirm the sampler never pads sequences.
        x = rearrange(x, '(B T) feat p -> B T feat p', B=B)
        feat = x.max(dim=1).values

        embed = self.FCs(feat)  # [n, c, p]
        _, logits = self.BNNecks(embed)  # [n, c, p]

        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
            },
            'inference_feat': {
                'embeddings': embed,
            }
        }
        return retval