Files
2025-06-11 14:43:19 +08:00

83 lines
3.4 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .lidargaitv2_utils import PointNetSetAbstraction, PPPooling, PPPooling_UDP,NetVLAD
from ..base_model import BaseModel
from ..modules import SeparateFCs, SeparateBNNecks
class LidarGaitPlusPlus(BaseModel):
    """Gait model over per-frame point sets built from set-abstraction stages.

    Every frame of every sequence goes through three sampled
    ``PointNetSetAbstraction`` stages and one ``group_all`` stage, is pooled
    by a configurable pooling head, max-reduced over time, and finally
    projected by ``SeparateFCs`` and ``SeparateBNNecks``.
    """

    def build_network(self, model_cfg):
        """Create the backbone, pooling head and output heads from ``model_cfg``.

        Args:
            model_cfg: configuration dict. Keys read here:
                ``channel`` (base width ``C``); ``SeparateFCs`` (its
                ``in_channels`` is also the backbone output width ``C_out``);
                ``SeparateBNNecks``; ``scale_aware``; ``normalize_dp``;
                ``sampling``; ``pool`` (one of ``'VLAD'``, ``'GMaxP'``,
                ``'PPP_UDP'``, ``'PPP_UAP'``, ``'PPP_HAP'``); ``scale``
                (only for the ``PPP_*`` heads); and the optional
                ``npoints`` (default ``[512, 256, 128]``), ``nsample``
                (default ``32``) and ``radii`` (default ``[0.1, 0.2, 0.4]``).

        Raises:
            ValueError: if ``model_cfg['pool']`` is not a known pooling head.
        """
        C = model_cfg['channel']
        C_out = model_cfg['SeparateFCs']['in_channels']
        scale_aware = model_cfg['scale_aware']
        normalize_dp = model_cfg['normalize_dp']
        sampling = model_cfg['sampling']
        npoints = model_cfg.get('npoints', [512, 256, 128])
        nsample = model_cfg.get('nsample', 32)
        radii = model_cfg.get('radii', [0.1, 0.2, 0.4])
        # One extra input channel when scale_aware is set (presumably a scale
        # cue appended to xyz -- TODO confirm in PointNetSetAbstraction).
        in_channel = 4 if scale_aware else 3
        # Options shared by every set-abstraction stage.
        sa_opts = dict(sampling=sampling, scale_aware=scale_aware,
                       normalize_dp=normalize_dp)

        # NOTE: construction order (sa1..sa4, pool, BNNecks, FCs) is kept as-is
        # so parameter ordering and seeded initialisation stay unchanged.
        self.sa1 = PointNetSetAbstraction(
            npoint=npoints[0], radius=radii[0], nsample=nsample,
            in_channel=in_channel, mlp=[2 * C, 2 * C, 4 * C],
            group_all=False, **sa_opts)
        # Later stages see the previous stage's features plus the coordinates.
        self.sa2 = PointNetSetAbstraction(
            npoint=npoints[1], radius=radii[1], nsample=nsample,
            in_channel=4 * C + in_channel, mlp=[4 * C, 4 * C, 8 * C],
            group_all=False, **sa_opts)
        self.sa3 = PointNetSetAbstraction(
            npoint=npoints[2], radius=radii[2], nsample=nsample,
            in_channel=8 * C + in_channel, mlp=[8 * C, 8 * C, 16 * C],
            group_all=False, **sa_opts)
        self.sa4 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=16 * C + in_channel, mlp=[16 * C, 16 * C, C_out],
            group_all=True, **sa_opts)

        pool_type = model_cfg['pool']
        if pool_type == 'VLAD':
            self.pool = NetVLAD(num_clusters=16, dim=C_out, alpha=1.0)
        elif pool_type == 'GMaxP':
            self.pool = PPPooling_UDP([1])
        elif pool_type == 'PPP_UDP':
            self.pool = PPPooling_UDP(model_cfg['scale'])
        elif pool_type == 'PPP_UAP':
            self.pool = PPPooling(scale_aware=False, bin_num=model_cfg['scale'])
        elif pool_type == 'PPP_HAP':
            self.pool = PPPooling(scale_aware=True, bin_num=model_cfg['scale'])
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(
                f"Unknown pool type {pool_type!r}; expected one of "
                "'VLAD', 'GMaxP', 'PPP_UDP', 'PPP_UAP', 'PPP_HAP'.")

        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])

    def forward(self, inputs):
        """Embed a batch of point-set sequences.

        Args:
            inputs: 5-tuple whose first two items are ``ipts`` and ``labs``;
                the remaining three are unpacked but unused. ``ipts[0]`` is
                4-D and read as ``(B, T, N, C)`` -- presumably B sequences of
                T frames with N points of C coordinates (TODO confirm).

        Returns:
            dict with ``training_feat`` (``triplet`` embeddings and
            ``softmax`` logits, each paired with ``labs``), an empty
            ``visual_summary``, and ``inference_feat`` holding the
            pre-BNNeck embeddings.
        """
        ipts, labs, _, _, _ = inputs
        xyz = ipts[0]
        B, T, _, _ = xyz.shape
        # Fold time into the batch so each frame is an independent point set.
        xyz = rearrange(xyz, 'B T N C -> (B T) C N')

        # Stages 1-3: set abstraction, then max over dim -2 (presumably the
        # per-centroid neighbourhood axis -- TODO confirm).
        points = None
        for stage in (self.sa1, self.sa2, self.sa3):
            xyz, points = stage(xyz, points)
            points = points.max(dim=-2).values
        # Stage 4 groups all remaining points; its features are pooled
        # together with the stage-3 coordinates.
        _, points = self.sa4(xyz, points)
        x = self.pool(points, xyz)

        # Unfold time and max-pool over all T frames.
        # NOTE(review): seqL is ignored, so any padded frames take part in the
        # max -- confirm inputs are fixed-length.
        x = rearrange(x, '(B T) feat p -> B T feat p', B=B, T=T)
        feat = x.max(dim=1).values  # [B, feat, p]
        embed = self.FCs(feat)
        _, logits = self.BNNecks(embed)

        return {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs},
            },
            'visual_summary': {},
            'inference_feat': {
                'embeddings': embed,
            },
        }