Files
OpenGait/opengait/modeling/models/lidargaitv2_utils.py
T
2025-06-11 14:43:19 +08:00

378 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from einops import rearrange
from torch.autograd import Variable
import numpy as np
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Relies on the expansion
        ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * <s, d>
    so the full table is one batched matmul plus two broadcast additions.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross_term = torch.matmul(src, dst.transpose(1, 2))  # [B, N, M]
    src_sq = (src ** 2).sum(dim=-1).view(B, N, 1)        # broadcast over M
    dst_sq = (dst ** 2).sum(dim=-1).view(B, 1, M)        # broadcast over N
    return -2 * cross_term + src_sq + dst_sq
def index_points(points, idx):
    """
    Gather rows of `points` per batch element using an index tensor.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K], any rank >= 1 with a
             leading batch axis)
    Return:
        new_points: indexed points data, [B, S, C] (idx.shape + [C])
    """
    device = points.device
    B = points.shape[0]
    # Batch ids shaped [B, 1, ..., 1] so they broadcast against idx whatever
    # its rank is; the view fails if idx does not have B leading entries.
    batch_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_ids = torch.arange(B, dtype=torch.long, device=device).view(batch_shape)
    batch_ids = batch_ids.expand_as(idx)
    return points[batch_ids, idx, :]
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling.

    Starts from one randomly drawn point per batch element and repeatedly
    picks the point whose distance to the already selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Squared distance from every point to its closest selected centroid;
    # starts at a huge value so the first update always wins.
    nearest_sq = torch.full((B, N), 1e10).to(device)
    # Seed indices are drawn without a device argument and then moved,
    # exactly as before, so the random stream consumed is unchanged.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_ids = torch.arange(B, dtype=torch.long).to(device)
    for step in range(npoint):
        centroids[:, step] = farthest
        picked = xyz[batch_ids, farthest, :].view(B, 1, C)
        sq_dist = ((xyz - picked) ** 2).sum(dim=-1)
        closer = sq_dist < nearest_sq
        nearest_sq = torch.where(closer, sq_dist.to(nearest_sq.dtype), nearest_sq)
        farthest = nearest_sq.max(dim=-1)[1]
    return centroids
def ball_query(radius, nsample, xyz, new_xyz):
    """
    Collect up to `nsample` neighbour indices inside a ball around each query.

    Neighbours are returned in ascending index order (not by distance). When
    fewer than `nsample` points fall inside the ball, the remaining slots are
    filled with the first neighbour of that query. If a query has no point
    inside the ball at all, its nearest point is used instead, so the result
    never contains the out-of-range sentinel `N`.

    Only the first three channels of `xyz` / `new_xyz` are used for distances.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3] (extra channels are ignored)
        new_xyz: query points, [B, S, 3] (extra channels are ignored)
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    xyz = xyz[:, :, :3]
    new_xyz = new_xyz[:, :, :3]
    sqrdists = square_distance(new_xyz, xyz)  # [B, S, N]

    # Mark points outside the ball with the sentinel N; sorting then pushes
    # them behind every valid index.
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]

    # First valid neighbour per query; fall back to the nearest point when the
    # ball is empty so that downstream indexing stays in range.
    group_first = group_idx[:, :, :1]
    nearest = sqrdists.argmin(dim=-1, keepdim=True)
    group_first = torch.where(group_first == N, nearest, group_first)

    # Replace remaining sentinels with the first neighbour.
    group_idx = torch.where(group_idx == N, group_first.expand_as(group_idx), group_idx)

    # With nsample > N the slice above yields fewer than nsample columns;
    # pad with the first neighbour to honour the documented output shape.
    found = group_idx.shape[-1]
    if found < nsample:
        padding = group_first.expand(B, S, nsample - found)
        group_idx = torch.cat([group_idx, padding], dim=-1)
    return group_idx
def knn_query(k, xyz, new_xyz):
    """
    Indices of the k nearest points for every query point.

    Distances are computed on the first three channels only; neighbours are
    ordered from nearest to farthest.

    Input:
        k: number of nearest neighbors to query
        xyz: all points, [B, N, 3] (extra channels are ignored)
        new_xyz: query points, [B, S, 3] (extra channels are ignored)
    Return:
        group_idx: indices of k-nearest neighbors, [B, S, k]
    """
    # Unpacking doubles as a rank check: both inputs must be 3-D.
    _, _, _ = xyz.shape
    _, _, _ = new_xyz.shape
    coords = xyz[:, :, :3]
    queries = new_xyz[:, :, :3]
    sq_dists = square_distance(queries, coords)  # [B, S, N]
    _, order = torch.sort(sq_dists, dim=-1)
    return order[:, :, :k]
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, sampling='ball', scale_aware=False, normalize_dp=False):
    """
    Sample centroids with FPS and gather a local neighbourhood around each.

    Input:
        npoint: number of centroids to sample
        radius: ball radius (also the divisor when normalize_dp is set)
        nsample: number of neighbours per centroid
        xyz: input points position data, [B, N, C] (first 3 channels are
             used for sampling and neighbour search)
        points: input points data, [B, N, D], or None
        returnfps: also return the grouped positions and the FPS indices
        sampling: 'ball' for radius search, 'knn' for k-nearest neighbours
        scale_aware: keep all C offset channels instead of only the first 3
        normalize_dp: divide the centroid-relative offsets by `radius`
    Return:
        new_xyz: sampled centroid positions, [B, npoint, C]
        new_points: grouped features, [B, npoint, nsample, C'+D] where C' is
                    C if scale_aware else 3 (no D part when points is None)
        grouped_xyz, fps_idx: only when returnfps is True
    Raises:
        ValueError: if `sampling` is neither 'ball' nor 'knn'
    """
    B, _, C = xyz.shape
    fps_idx = farthest_point_sample(xyz[:, :, :3], npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)                    # [B, npoint, C]

    if sampling == 'ball':
        idx = ball_query(radius, nsample, xyz, new_xyz)
    elif sampling == 'knn':
        idx = knn_query(nsample, xyz, new_xyz)
    else:
        raise ValueError("Unsupported sampling type. Use 'ball' or 'knn'.")

    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    # Express every neighbour relative to its centroid.
    offsets = grouped_xyz - new_xyz.view(B, npoint, 1, C)
    if normalize_dp:
        offsets.div_(radius)
    if not scale_aware:
        offsets = offsets[:, :, :, :3]

    if points is None:
        new_points = offsets
    else:
        neighbour_feats = index_points(points, idx)
        new_points = torch.cat([offsets, neighbour_feats], dim=-1)

    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    return new_xyz, new_points
def sample_and_group_all(xyz, points, scale_aware=False):
    """
    Treat the whole cloud as a single group around an all-zero centroid.

    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        scale_aware: keep all C position channels instead of only the first 3
    Return:
        new_xyz: zero-filled centroid, [B, 1, C]
        new_points: grouped data, [B, 1, N, C'+D] where C' is C if
                    scale_aware else 3 (no D part when points is None)
    """
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C, device=xyz.device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if not scale_aware:
        grouped_xyz = grouped_xyz[..., :3]
    if points is None:
        return new_xyz, grouped_xyz
    grouped_points = points.view(B, 1, N, -1)
    return new_xyz, torch.cat([grouped_xyz, grouped_points], dim=-1)
class PointNetSetAbstraction(nn.Module):
    """
    Set-abstraction layer: sample centroids, group neighbours, then run a
    shared per-point MLP (1x1 Conv2d + BatchNorm2d + ReLU) over each group.

    No pooling over the neighbour axis is performed here; the grouped
    feature map is returned as is.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, sampling='ball', scale_aware=False, normalize_dp=False):
        """
        Args:
            npoint: number of centroids to sample
            radius: ball-query radius
            nsample: neighbours per centroid
            in_channel: channel count of the grouped input fed to the MLP
            mlp: output channel count of each MLP layer, in order
            group_all: if True, treat the whole cloud as one group
            sampling: 'ball' or 'knn' neighbour search
            scale_aware: forward all position channels instead of the first 3
            normalize_dp: divide centroid-relative offsets by `radius`
        """
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # Chain the layers: each conv consumes the previous layer's width.
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))
        self.group_all = group_all
        self.scale_aware = scale_aware
        self.normalize_dp = normalize_dp
        self.sampling = sampling

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points: grouped feature map after the MLP,
                        [B, D', nsample, S] (S = 1 and nsample = N when
                        group_all is set)
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points, scale_aware=self.scale_aware)
        else:
            new_xyz, new_points = sample_and_group(
                self.npoint, self.radius, self.nsample, xyz, points,
                sampling=self.sampling, scale_aware=self.scale_aware,
                normalize_dp=self.normalize_dp)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
        return new_xyz.permute(0, 2, 1), new_points
class PPPooling_UDP():
    """
    Hierarchically Clustered Point Pooling.

    Orders the points by their z coordinate, cuts the ordered sequence into
    equally sized strips at several resolutions and summarises every strip
    as mean + max.
    """

    def __init__(self, bin_num=None):
        # Number of strips at each resolution; each must divide the number
        # of entries along the pooled axes for the reshape in __call__.
        self.bin_num = [16, 8, 4, 2, 1] if bin_num is None else bin_num

    def __call__(self, x, xyz):
        """
        x  : [n, c, p, w] features (the z-order index is broadcast over c, w)
        xyz: [n, 3, p] coordinates; channel 2 is used as the sort key
        ret: [n, c, sum(bin_num)]
        """
        n, c = x.size()[:2]
        order = torch.sort(xyz[:, 2, :])[1]                # ascending z, [n, p]
        gather_idx = order[:, None, :, None].expand_as(x)  # same order for all channels
        x_sorted = torch.gather(x, 2, gather_idx)
        features = []
        for parts in self.bin_num:
            strips = x_sorted.view(n, c, parts, -1)
            features.append(strips.mean(dim=-1) + strips.max(dim=-1)[0])
        return torch.cat(features, dim=-1)
class PPPooling():
    """
    Multi-resolution pooling of point features over uniform height bins.

    For every resolution M in `bin_num`, points are assigned to one of M
    equal-width bins along a height coordinate and each bin is summarised as
    max + mean of the features that fall into it (empty bins yield 0). The
    per-resolution results are concatenated along the bin axis.

    Two modes:
      * scale_aware=True  (PPPooling_HAP): channel 3 of `points` is the
        height, binned over the fixed range [0, 2].
      * scale_aware=False (PPPooling_UAP): channel 2 of `points` is the
        height, min-max normalised to [0, 1] independently for every sample.
    """

    def __init__(self, scale_aware=False, bin_num=None):
        # Default: several resolutions, from fine (16 bins) to global (1 bin).
        self.bin_num = bin_num if bin_num is not None else [16, 8, 4, 2, 1]
        self.scale_aware = scale_aware

    def __call__(self, point_clouds, points):
        """
        Args:
            point_clouds: features to pool, [B, C, N, 1]
            points: per-point coordinates, [B, C_p, N]; C_p >= 3, or >= 4 when
                scale_aware is set
        Returns:
            Pooled features, [B, C, sum(bin_num)], same dtype as
            `point_clouds`.
        Raises:
            ValueError: if the inputs do not have the shapes given above.
        """
        if point_clouds.dim() != 4 or point_clouds.shape[-1] != 1:
            raise ValueError(
                "point_clouds must have shape [B, C, N, 1], got %s" % (tuple(point_clouds.shape),))
        if points.dim() != 3:
            raise ValueError(
                "points must have shape [B, C, N], got %s" % (tuple(points.shape),))

        # Bring both inputs to a points-major layout: B x N x C.
        point_clouds = point_clouds.squeeze(-1).permute(0, 2, 1)
        points = points.permute(0, 2, 1)
        B, N, C = point_clouds.shape
        device, dtype = point_clouds.device, point_clouds.dtype

        if self.scale_aware:  # PPPooling_HAP
            z = points[:, :, 3]  # (B, N)
            # Fixed height range, e.g. 0 to 2.
            z_min, z_max = 0.0, 2.0
        else:  # PPPooling_UAP
            z = points[:, :, 2]  # (B, N)
            # Normalise every sample with its OWN min/max. Using only the
            # first sample's range would make the result of a sample depend
            # on what else is in the batch.
            z_low = z.min(dim=1, keepdim=True)[0]   # (B, 1)
            z_high = z.max(dim=1, keepdim=True)[0]  # (B, 1)
            z = (z - z_low) / (z_high - z_low + 1e-6)
            z_min, z_max = 0.0, 1.0
        z = z.contiguous()

        # Source for the per-bin point count. It must share the dtype of the
        # accumulator, otherwise scatter_add rejects non-float32 features.
        ones = torch.ones((B, N, 1), device=device, dtype=dtype)

        all_pooled = []
        for M in self.bin_num:
            # Uniform bin edges over [z_min, z_max]; only the interior edges
            # are passed to bucketize, so values at or beyond the outer edges
            # land in the first / last bin and indices stay within [0, M-1].
            edges = torch.linspace(z_min, z_max, steps=M + 1, device=device).to(z.dtype)
            bin_idx = torch.bucketize(z, edges[1:-1], right=False)  # (B, N)
            # Same bin index for every feature channel: (B, N, C).
            bin_idx_exp = bin_idx.unsqueeze(-1).expand(-1, -1, C)

            # Per-bin max; bins that receive no point keep -inf.
            pooled_max = torch.full((B, M, C), float('-inf'), device=device, dtype=dtype)
            pooled_max = pooled_max.scatter_reduce(
                1, bin_idx_exp, point_clouds, reduce='amax', include_self=True)
            # Per-bin sum and point count, combined into the mean.
            pooled_sum = torch.zeros((B, M, C), device=device, dtype=dtype)
            pooled_sum = pooled_sum.scatter_add(1, bin_idx_exp, point_clouds)
            counts = torch.zeros((B, M, 1), device=device, dtype=dtype)
            counts = counts.scatter_add(1, bin_idx.unsqueeze(-1), ones)
            pooled_mean = pooled_sum / counts.clamp(min=1)

            # Final descriptor is max + mean (concatenation would also work).
            pooled = pooled_max + pooled_mean
            # Empty bins (max still -inf) are set to 0.
            pooled = pooled.masked_fill(pooled == float('-inf'), 0)
            all_pooled.append(pooled)

        # Concatenate all resolutions along the bin axis -> B x C x M_total.
        output = torch.cat(all_pooled, dim=1)
        return output.permute(0, 2, 1).contiguous()
class NetVLAD(nn.Module):
    """NetVLAD layer: soft-assign descriptors to learnable cluster centres
    and aggregate the weighted residuals into one global vector."""

    def __init__(self, num_clusters=64, dim=128, alpha=100.0,
                 normalize_input=True):
        """
        Args:
            num_clusters : int
                The number of clusters
            dim : int
                Dimension of descriptors
            alpha : float
                Parameter of initialization. Larger value is harder assignment.
            normalize_input : bool
                If true, descriptor-wise L2 normalization is applied to input.
        """
        super().__init__()
        self.num_clusters = num_clusters
        self.dim = dim
        self.alpha = alpha
        self.normalize_input = normalize_input
        # A 1x1 convolution produces the assignment logits for every cluster.
        self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
        self._init_params()

    def _init_params(self):
        # Tie the initial assignment logits to the centroids:
        # weight = 2 * alpha * c_k, bias = -alpha * ||c_k||.
        scaled_centroids = 2.0 * self.alpha * self.centroids
        self.conv.weight = nn.Parameter(scaled_centroids.unsqueeze(-1).unsqueeze(-1))
        self.conv.bias = nn.Parameter(-self.alpha * self.centroids.norm(dim=1))

    def forward(self, x, xyz):
        """
        Args:
            x: descriptor map, [N, C, H, W] with C == dim
            xyz: accepted for interface compatibility; not used here
        Returns:
            L2-normalised VLAD vector, [N, num_clusters * C, 1]
        """
        batch, channels = x.shape[:2]

        if self.normalize_input:
            x = F.normalize(x, p=2, dim=1)  # across descriptor dim

        # Soft assignment of every descriptor to every cluster: [N, K, P].
        logits = self.conv(x).view(batch, self.num_clusters, -1)
        soft_assign = F.softmax(logits, dim=1)

        descriptors = x.view(batch, channels, -1)  # [N, C, P]

        # Residual of every descriptor to every centroid, built by
        # broadcasting [N, 1, C, P] against [1, K, C, 1] -> [N, K, C, P].
        residual = descriptors.unsqueeze(1) - self.centroids[None, :, :, None]
        residual = residual * soft_assign.unsqueeze(2)
        vlad = residual.sum(dim=-1)  # [N, K, C]

        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization
        vlad = vlad.view(x.size(0), -1)       # flatten: b num c -> b (num c)
        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize
        return vlad.unsqueeze(-1)