Files
RapidPoseTriangulation/rpt/rgbd_merger.cpp
T
crosstyan ed721729fd feat(rgbd): add RGB-D reconstruction pipeline
Add end-to-end RGB-D reconstruction support across the C++ core and Python API.

- add a native merge_rgbd_views path, view-aware 3D pose containers, and nanobind bindings

- expose Python helpers to sample aligned depth, apply per-joint offsets, lift UVD poses to world space, and run reconstruct_rgbd

- add RGB-D regression tests for merging, manual pipeline parity, symmetric depth sampling windows, and out-of-bounds joints

- bump the project version from 0.1.0 to 0.2.0 for the new feature surface
2026-03-26 13:04:57 +08:00

1328 lines
47 KiB
C++

#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "interface.hpp"
namespace
{
// One skeleton: one {x, y, z, score} entry per joint (index 3 is the score).
using Pose3D = std::vector<std::array<float, 4>>;
// Several skeletons, e.g. all persons reported by one view.
using PoseList3D = std::vector<Pose3D>;

// A fused person hypothesis maintained while views are merged.
struct Track
{
    // Current fused skeleton of this track.
    Pose3D skeleton{};
    // Detections (one skeleton each) that were associated with this track.
    std::vector<Pose3D> last_detections{};
};
/// Solve a (possibly rectangular) minimum-cost assignment problem.
///
/// Uses the O(n^3) Hungarian algorithm with row/column potentials on a square
/// matrix whose extra rows/columns are zero-cost padding.
///
/// @param costs  rows x cols cost matrix; every row must have the same length.
/// @return       one entry per row: the assigned column index, or -1 if the row
///               ended up on a padding column (only possible when rows > cols).
/// @throws std::invalid_argument if the rows of @p costs differ in length.
std::vector<int> solve_assignment_hungarian(const std::vector<std::vector<double>> &costs)
{
    const size_t rows = costs.size();
    const size_t cols = rows == 0 ? 0 : costs[0].size();
    // The algorithm reads a dense rows x cols block; reject ragged input rather
    // than reading past the end of a short row.
    for (const auto &row_costs : costs)
    {
        if (row_costs.size() != cols)
        {
            throw std::invalid_argument("Cost matrix rows must all have the same length.");
        }
    }
    const size_t size = std::max(rows, cols);
    if (size == 0)
    {
        return {};
    }
    const double inf = std::numeric_limits<double>::infinity();

    // 1-based square matrix; entries outside rows x cols stay 0 (dummy padding).
    std::vector<std::vector<double>> matrix(size + 1, std::vector<double>(size + 1, 0.0));
    for (size_t row = 0; row < rows; ++row)
    {
        std::copy(costs[row].begin(), costs[row].end(), matrix[row + 1].begin() + 1);
    }

    std::vector<double> u(size + 1, 0.0);  // row potentials
    std::vector<double> v(size + 1, 0.0);  // column potentials
    std::vector<size_t> p(size + 1, 0);    // p[col]: row matched to col (0 = none)
    std::vector<size_t> way(size + 1, 0);  // way[col]: previous column on the alternating path
    // Scratch buffers: allocated once, reset for every inserted row.
    std::vector<double> minv(size + 1, inf);
    std::vector<char> used(size + 1, 0);

    // Insert the real rows one at a time (rows <= size always holds).
    for (size_t row = 1; row <= rows; ++row)
    {
        p[0] = row;
        size_t col0 = 0;
        std::fill(minv.begin(), minv.end(), inf);
        std::fill(used.begin(), used.end(), 0);
        // Grow the alternating tree until a free column is reached.
        do
        {
            used[col0] = 1;
            const size_t row0 = p[col0];
            double delta = inf;
            size_t col1 = 0;
            for (size_t col = 1; col <= size; ++col)
            {
                if (used[col])
                {
                    continue;
                }
                const double current = matrix[row0][col] - u[row0] - v[col];
                if (current < minv[col])
                {
                    minv[col] = current;
                    way[col] = col0;
                }
                if (minv[col] < delta)
                {
                    delta = minv[col];
                    col1 = col;
                }
            }
            for (size_t col = 0; col <= size; ++col)
            {
                if (used[col])
                {
                    u[p[col]] += delta;
                    v[col] -= delta;
                }
                else
                {
                    minv[col] -= delta;
                }
            }
            col0 = col1;
        } while (p[col0] != 0);
        // Flip the matching along the path back to the virtual column 0.
        do
        {
            const size_t prev = way[col0];
            p[col0] = p[prev];
            col0 = prev;
        } while (col0 != 0);
    }

    // Only real columns produce an assignment; rows left on padding keep -1.
    std::vector<int> assignment(rows, -1);
    for (size_t col = 1; col <= cols; ++col)
    {
        const size_t row = p[col];
        if (row != 0 && row <= rows)
        {
            assignment[row - 1] = static_cast<int>(col - 1);
        }
    }
    return assignment;
}
bool all_visible_joints_close(
const Pose3D &skel1,
const Pose3D &skel2,
float max_distance,
float vis_threshold)
{
const float max_dist_sq = max_distance * max_distance;
bool any_visible = false;
for (size_t joint = 0; joint < skel1.size(); ++joint)
{
const bool visible1 = skel1[joint][3] > vis_threshold;
const bool visible2 = skel2[joint][3] > vis_threshold;
if (!visible1 || !visible2)
{
continue;
}
any_visible = true;
const float dx = skel1[joint][0] - skel2[joint][0];
const float dy = skel1[joint][1] - skel2[joint][1];
const float dz = skel1[joint][2] - skel2[joint][2];
const float distance_sq = dx * dx + dy * dy + dz * dz;
if (distance_sq > max_dist_sq)
{
return false;
}
}
return any_visible;
}
/// Fill the "head" joint of every pose from the ears or, failing that, the nose.
///
/// If both ears score above 0.1, the head is set to the ear midpoint with the
/// smaller ear score. Otherwise, if the nose scores above 0.1, the head copies
/// the nose. Otherwise the head is left unchanged.
/// Does nothing if any of "head", "ear_left", "ear_right", "nose" is missing
/// from @p joint_names. Poses too short to contain those joints are skipped.
void add_extra_joints(PoseList3D &poses, const std::vector<std::string> &joint_names)
{
    const size_t not_found = joint_names.size();
    // Index of the first joint called `name`, or not_found.
    const auto index_of = [&joint_names](const char *name) -> size_t {
        const auto it = std::find(joint_names.begin(), joint_names.end(), name);
        return static_cast<size_t>(std::distance(joint_names.begin(), it));
    };
    const size_t idx_head = index_of("head");
    const size_t idx_ear_left = index_of("ear_left");
    const size_t idx_ear_right = index_of("ear_right");
    const size_t idx_nose = index_of("nose");
    if (idx_head == not_found || idx_ear_left == not_found ||
        idx_ear_right == not_found || idx_nose == not_found)
    {
        return;
    }
    // Minimum pose length that makes every index above valid.
    const size_t required_joints =
        std::max({idx_head, idx_ear_left, idx_ear_right, idx_nose}) + 1;
    constexpr float min_score = 0.1f;

    for (auto &pose : poses)
    {
        if (pose.size() < required_joints)
        {
            continue;
        }
        auto &joint_head = pose[idx_head];
        const auto &joint_ear_left = pose[idx_ear_left];
        const auto &joint_ear_right = pose[idx_ear_right];
        if (joint_ear_left[3] > min_score && joint_ear_right[3] > min_score)
        {
            joint_head[0] = (joint_ear_left[0] + joint_ear_right[0]) * 0.5f;
            joint_head[1] = (joint_ear_left[1] + joint_ear_right[1]) * 0.5f;
            joint_head[2] = (joint_ear_left[2] + joint_ear_right[2]) * 0.5f;
            joint_head[3] = std::min(joint_ear_left[3], joint_ear_right[3]);
            continue;
        }
        const auto &joint_nose = pose[idx_nose];
        if (joint_nose[3] > min_score)
        {
            joint_head = joint_nose;
        }
    }
}
/// Mark implausible poses by lowering every joint score to 0.001.
///
/// Only joints scoring above 0.1 are considered. A pose is marked when any of
/// these holds:
/// - it has fewer than 5 such joints;
/// - their bounding box exceeds 2.3 along any axis;
/// - the box is below 0.3 along every axis;
/// - their mean lies outside the room box;
/// - the box extends more than 0.1 beyond the room box.
/// Coordinates are never modified and poses are never removed.
///
/// @param roomparams  {room_size, room_center}; the room box spans
///                    room_center +/- room_size / 2 per axis.
void filter_poses(PoseList3D &poses, const std::array<std::array<float, 3>, 2> &roomparams)
{
    constexpr float min_score = 0.1f;
    constexpr size_t min_valid_joints = 5;
    constexpr float max_size = 2.3f;
    constexpr float min_size = 0.3f;
    constexpr float wdist = 0.1f;
    constexpr float dropped_score = 0.001f;

    // The room box is the same for every pose, so derive its bounds once.
    const auto &room_size = roomparams[0];
    const auto &room_center = roomparams[1];
    std::array<float, 3> room_upper{};
    std::array<float, 3> room_lower{};
    for (size_t axis = 0; axis < 3; ++axis)
    {
        const float half = room_size[axis] / 2.0f;
        room_upper[axis] = half + room_center[axis];
        room_lower[axis] = -half + room_center[axis];
    }

    const auto should_drop = [&](const Pose3D &pose) {
        size_t num_valid = 0;
        std::array<float, 3> sum = {0.0f, 0.0f, 0.0f};
        std::array<float, 3> lo{};
        std::array<float, 3> hi{};
        lo.fill(std::numeric_limits<float>::max());
        hi.fill(std::numeric_limits<float>::lowest());
        for (const auto &joint : pose)
        {
            if (!(joint[3] > min_score))
            {
                continue;
            }
            ++num_valid;
            for (size_t axis = 0; axis < 3; ++axis)
            {
                lo[axis] = std::min(lo[axis], joint[axis]);
                hi[axis] = std::max(hi[axis], joint[axis]);
                sum[axis] += joint[axis];
            }
        }
        if (num_valid < min_valid_joints)
        {
            return true;
        }

        // Size checks: too large along any axis, or too small along all axes.
        bool too_large = false;
        bool too_small = true;
        for (size_t axis = 0; axis < 3; ++axis)
        {
            const float extent = hi[axis] - lo[axis];
            too_large = too_large || extent > max_size;
            too_small = too_small && extent < min_size;
        }
        if (too_large || too_small)
        {
            return true;
        }

        // Room checks: mean strictly inside, box within a small margin.
        for (size_t axis = 0; axis < 3; ++axis)
        {
            const float mean = sum[axis] / static_cast<float>(num_valid);
            if (mean > room_upper[axis] || mean < room_lower[axis])
            {
                return true;
            }
        }
        for (size_t axis = 0; axis < 3; ++axis)
        {
            if (hi[axis] > room_upper[axis] + wdist || lo[axis] < room_lower[axis] - wdist)
            {
                return true;
            }
        }
        return false;
    };

    for (auto &pose : poses)
    {
        if (!should_drop(pose))
        {
            continue;
        }
        for (auto &joint : pose)
        {
            joint[3] = dropped_score;
        }
    }
}
/// Give joints whose score is exactly 0 a placeholder position with score 0.1.
///
/// Only joints listed in the adjacency table below are handled, either by exact
/// name or through the "foot_*", "face_*" and "hand_*" wildcard families.
/// - Such a joint is moved to the mean of its adjacent joints that score above
///   @p vis_threshold.
/// - If none of them qualify, it is moved to the mean of all joints scoring
///   above @p vis_threshold.
/// - Poses with no joint above @p vis_threshold are left untouched.
/// Joints filled earlier in the same pose are read back live when later joints
/// look at their adjacents.
void add_missing_joints(PoseList3D &poses, const std::vector<std::string> &joint_names, float vis_threshold)
{
    // Constant table: built once instead of on every call.
    static const std::unordered_map<std::string, std::vector<std::string>> adjacents = {
        {"hip_right", {"hip_middle", "hip_left"}},
        {"hip_left", {"hip_middle", "hip_right"}},
        {"knee_right", {"hip_right", "knee_left"}},
        {"knee_left", {"hip_left", "knee_right"}},
        {"ankle_right", {"knee_right", "ankle_left"}},
        {"ankle_left", {"knee_left", "ankle_right"}},
        {"shoulder_right", {"shoulder_middle", "shoulder_left"}},
        {"shoulder_left", {"shoulder_middle", "shoulder_right"}},
        {"elbow_right", {"shoulder_right", "hip_right"}},
        {"elbow_left", {"shoulder_left", "hip_left"}},
        {"wrist_right", {"elbow_right"}},
        {"wrist_left", {"elbow_left"}},
        {"nose", {"shoulder_middle", "shoulder_right", "shoulder_left"}},
        {"head", {"shoulder_middle", "shoulder_right", "shoulder_left"}},
        {"foot_*_left_*", {"ankle_left"}},
        {"foot_*_right_*", {"ankle_right"}},
        {"face_*", {"nose"}},
        {"hand_*_left_*", {"wrist_left"}},
        {"hand_*_right_*", {"wrist_right"}},
    };

    // Key into `adjacents` for a joint name, or "" if the joint is not handled.
    const auto adjacency_key = [](const std::string &name) -> std::string {
        const bool has_left = name.find("_left") != std::string::npos;
        const bool has_right = name.find("_right") != std::string::npos;
        if (name.starts_with("foot_") && has_left)
        {
            return "foot_*_left_*";
        }
        if (name.starts_with("foot_") && has_right)
        {
            return "foot_*_right_*";
        }
        if (name.starts_with("face_"))
        {
            return "face_*";
        }
        if (name.starts_with("hand_") && has_left)
        {
            return "hand_*_left_*";
        }
        if (name.starts_with("hand_") && has_right)
        {
            return "hand_*_right_*";
        }
        if (adjacents.contains(name))
        {
            return name;
        }
        return {};
    };

    // Name -> index; a duplicated name maps to its last occurrence.
    std::unordered_map<std::string, size_t> joint_name_to_idx;
    joint_name_to_idx.reserve(joint_names.size());
    for (size_t idx = 0; idx < joint_names.size(); ++idx)
    {
        joint_name_to_idx[joint_names[idx]] = idx;
    }

    // Resolve, once per call, which joints are handled and their adjacent indices.
    std::vector<char> handled(joint_names.size(), 0);
    std::vector<std::vector<size_t>> adjacent_ids(joint_names.size());
    for (size_t joint = 0; joint < joint_names.size(); ++joint)
    {
        const std::string key = adjacency_key(joint_names[joint]);
        if (key.empty())
        {
            continue;
        }
        handled[joint] = 1;
        for (const std::string &adjacent_name : adjacents.at(key))
        {
            const auto mapped = joint_name_to_idx.find(adjacent_name);
            if (mapped != joint_name_to_idx.end())
            {
                adjacent_ids[joint].push_back(mapped->second);
            }
        }
    }

    for (auto &pose : poses)
    {
        // Fallback position: mean of all joints above the visibility threshold.
        std::array<float, 3> body_center = {0.0f, 0.0f, 0.0f};
        size_t num_visible = 0;
        for (const auto &joint : pose)
        {
            if (joint[3] > vis_threshold)
            {
                body_center[0] += joint[0];
                body_center[1] += joint[1];
                body_center[2] += joint[2];
                ++num_visible;
            }
        }
        if (num_visible == 0)
        {
            continue;
        }
        for (size_t coord = 0; coord < 3; ++coord)
        {
            body_center[coord] /= static_cast<float>(num_visible);
        }

        for (size_t joint = 0; joint < joint_names.size(); ++joint)
        {
            if (!handled[joint])
            {
                continue;
            }
            auto &joint_entry = pose[joint];
            // Only joints with an exact-zero score count as missing.
            if (joint_entry[3] != 0.0f)
            {
                continue;
            }
            std::array<float, 3> position = {0.0f, 0.0f, 0.0f};
            size_t adjacent_count = 0;
            for (const size_t adjacent_index : adjacent_ids[joint])
            {
                const auto &adjacent_joint = pose[adjacent_index];
                if (adjacent_joint[3] <= vis_threshold)
                {
                    continue;
                }
                position[0] += adjacent_joint[0];
                position[1] += adjacent_joint[1];
                position[2] += adjacent_joint[2];
                ++adjacent_count;
            }
            if (adjacent_count > 0)
            {
                for (size_t coord = 0; coord < 3; ++coord)
                {
                    position[coord] /= static_cast<float>(adjacent_count);
                }
            }
            else
            {
                position = body_center;
            }
            joint_entry[0] = position[0];
            joint_entry[1] = position[1];
            joint_entry[2] = position[2];
            joint_entry[3] = 0.1f;
        }
    }
}
/// Snap stray extremity joints onto their body-part centre.
///
/// Joint groups:
/// - head: face_*, nose, eyes, ears;
/// - left/right foot: foot_*, ankle_*;
/// - left/right hand: hand_*, wrist_*.
///
/// A group's centre is the unweighted mean of its joints scoring above
/// min_match_score - 2 * (1 - min_match_score).
///
/// Then, for every grouped joint whose group has a centre, the joint is moved
/// onto the centre with score 0.1 when either:
/// - its score is exactly 0; or
/// - its score is positive and it lies farther than 0.20 (head/hand) or
///   0.25 (foot) from the centre.
///
/// A foot/hand name containing neither "_left" nor "_right" feeds no centre
/// but is checked against the right-hand centre. Only the first
/// joint_names.size() joints of a pose are examined.
void replace_far_joints(
    PoseList3D &poses,
    const std::vector<std::string> &joint_names,
    float min_match_score)
{
    // Slots in the per-pose centre table.
    constexpr size_t slot_head = 0;
    constexpr size_t slot_foot_left = 1;
    constexpr size_t slot_foot_right = 2;
    constexpr size_t slot_hand_left = 3;
    constexpr size_t slot_hand_right = 4;
    constexpr size_t num_slots = 5;
    constexpr size_t no_slot = num_slots;

    constexpr float max_dist_head_sq = 0.20f * 0.20f;
    constexpr float max_dist_foot_sq = 0.25f * 0.25f;
    constexpr float max_dist_hand_sq = 0.20f * 0.20f;

    // Loop-invariant score floor for joints that contribute to a centre.
    const float offset = (1.0f - min_match_score) * 2.0f;
    const float min_score = min_match_score - offset;

    // How a joint takes part: the centre it feeds, the centre it is checked
    // against, and the allowed squared distance from that centre.
    struct JointRole
    {
        size_t accumulate_slot;
        size_t compare_slot;
        float max_dist_sq;
    };
    const auto role_of = [&](const std::string &name) -> JointRole {
        const bool is_left = name.find("_left") != std::string::npos;
        const bool is_right = name.find("_right") != std::string::npos;
        if (name.starts_with("face_") || name == "nose" || name == "eye_left" ||
            name == "eye_right" || name == "ear_left" || name == "ear_right")
        {
            return {slot_head, slot_head, max_dist_head_sq};
        }
        if (name.starts_with("foot_") || name.starts_with("ankle_"))
        {
            const size_t feeds = is_left ? slot_foot_left : (is_right ? slot_foot_right : no_slot);
            return {feeds, is_left ? slot_foot_left : slot_foot_right, max_dist_foot_sq};
        }
        if (name.starts_with("hand_") || name.starts_with("wrist_"))
        {
            const size_t feeds = is_left ? slot_hand_left : (is_right ? slot_hand_right : no_slot);
            return {feeds, is_left ? slot_hand_left : slot_hand_right, max_dist_hand_sq};
        }
        return {no_slot, no_slot, 0.0f};
    };

    // Joint roles depend only on the names, so classify once per call.
    std::vector<JointRole> roles;
    roles.reserve(joint_names.size());
    for (const auto &name : joint_names)
    {
        roles.push_back(role_of(name));
    }

    for (auto &pose : poses)
    {
        const size_t num_joints = std::min(pose.size(), roles.size());

        // Accumulate {sum_x, sum_y, sum_z, count} per group, then average.
        std::array<std::array<float, 4>, num_slots> centers{};
        for (size_t joint = 0; joint < num_joints; ++joint)
        {
            const size_t slot = roles[joint].accumulate_slot;
            if (pose[joint][3] <= min_score || slot == no_slot)
            {
                continue;
            }
            auto &center = centers[slot];
            center[0] += pose[joint][0];
            center[1] += pose[joint][1];
            center[2] += pose[joint][2];
            center[3] += 1.0f;
        }
        for (auto &center : centers)
        {
            if (center[3] > 0.0f)
            {
                center[0] /= center[3];
                center[1] /= center[3];
                center[2] /= center[3];
            }
        }

        // Replace missing (score == 0) or far-away grouped joints.
        for (size_t joint = 0; joint < num_joints; ++joint)
        {
            const JointRole &role = roles[joint];
            if (role.compare_slot == no_slot)
            {
                continue;
            }
            const auto &center = centers[role.compare_slot];
            if (center[3] <= 0.0f)
            {
                continue;
            }
            auto &entry = pose[joint];
            const float dx = entry[0] - center[0];
            const float dy = entry[1] - center[1];
            const float dz = entry[2] - center[2];
            const float dist_sq = dx * dx + dy * dy + dz * dz;
            const bool too_far = entry[3] > 0.0f && dist_sq > role.max_dist_sq;
            if (too_far || entry[3] == 0.0f)
            {
                entry = {center[0], center[1], center[2], 0.1f};
            }
        }
    }
}
/// Incrementally fuses per-view 3D skeletons into person tracks.
///
/// Call sequence:
/// 1. construct;
/// 2. call consume_view() once per view with that view's pose list;
/// 3. call finalize() to obtain the fused skeletons.
/// Each Track holds a fused skeleton plus the detections it was built from.
class RgbdViewMerger
{
public:
    /// @param joint_names_in  joint order shared by every pose handed to consume_view().
    /// @param num_views_in    caps how many detections a track keeps when a new one is
    ///                        appended, and sets the score decay (num_views - 1) / num_views
    ///                        used in update_track_position().
    /// @param max_distance_in gating distance used by outlier dropping, association and
    ///                        track updates.
    RgbdViewMerger(const std::vector<std::string> &joint_names_in, size_t num_views_in, float max_distance_in)
        : joint_names(joint_names_in),
          num_views(static_cast<float>(num_views_in)),
          max_distance(max_distance_in)
    {
        // Neighbouring joints used to build a reference position for each joint.
        neighbor_joints = {
            {"chest", {
                "shoulder_left",
                "shoulder_right",
                "shoulder_middle",
                "hip_left",
                "hip_right",
                "hip_middle",
            }},
            {"head", {
                "nose",
                "eye_left",
                "eye_right",
                "ear_left",
                "ear_right",
                "shoulder_left",
                "shoulder_right",
                "shoulder_middle",
                "neck",
            }},
            {"nose", {
                "eye_left",
                "eye_right",
                "ear_left",
                "ear_right",
                "shoulder_left",
                "shoulder_right",
                "shoulder_middle",
                "neck",
            }},
            {"eye_left", {
                "nose",
                "eye_right",
                "ear_left",
                "ear_right",
                "shoulder_left",
                "shoulder_middle",
            }},
            {"eye_right", {
                "nose",
                "eye_left",
                "ear_left",
                "ear_right",
                "shoulder_right",
                "shoulder_middle",
            }},
            {"ear_left", {
                "nose",
                "ear_right",
                "eye_left",
                "eye_right",
                "shoulder_left",
                "shoulder_middle",
            }},
            {"ear_right", {
                "nose",
                "ear_left",
                "eye_left",
                "eye_right",
                "shoulder_right",
                "shoulder_middle",
            }},
            {"shoulder_left", {
                "nose",
                "eye_left",
                "ear_left",
                "elbow_left",
                "hip_left",
                "shoulder_middle",
                "neck",
            }},
            {"shoulder_right", {
                "nose",
                "eye_right",
                "ear_right",
                "elbow_right",
                "hip_right",
                "shoulder_middle",
                "neck",
            }},
            {"shoulder_middle", {
                "nose",
                "eye_left",
                "eye_right",
                "ear_left",
                "ear_right",
                "elbow_left",
                "elbow_right",
                "hip_middle",
                "shoulder_left",
                "shoulder_right",
                "neck",
            }},
            {"neck", {
                "nose",
                "eye_left",
                "eye_right",
                "shoulder_left",
                "shoulder_right",
                "shoulder_middle",
            }},
            {"elbow_left", {
                "shoulder_left",
                "wrist_left",
                "shoulder_middle",
                "neck",
            }},
            {"elbow_right", {
                "shoulder_right",
                "wrist_right",
                "shoulder_middle",
                "neck",
            }},
            {"wrist_left", {
                "elbow_left",
            }},
            {"wrist_right", {
                "elbow_right",
            }},
            {"hip_left", {
                "hip_right",
                "knee_left",
                "shoulder_left",
                "hip_middle",
            }},
            {"hip_right", {
                "hip_left",
                "knee_right",
                "shoulder_right",
                "hip_middle",
            }},
            {"hip_middle", {
                "hip_left",
                "hip_right",
                "knee_right",
                "knee_left",
                "shoulder_middle",
                "neck",
            }},
            {"knee_left", {
                "hip_left",
                "hip_middle",
                "hip_right",
                "ankle_left",
            }},
            {"knee_right", {
                "hip_right",
                "hip_middle",
                "hip_left",
                "ankle_right",
            }},
            {"ankle_left", {
                "knee_left",
            }},
            {"ankle_right", {
                "knee_right",
            }},
        };
        // Resolve neighbour names to indices into joint_names. Names that do not
        // occur are skipped; for duplicated names std::find yields the first one.
        for (const auto &entry : neighbor_joints)
        {
            std::vector<size_t> ids;
            ids.reserve(entry.second.size());
            for (const auto &joint_name : entry.second)
            {
                const auto it = std::find(joint_names.begin(), joint_names.end(), joint_name);
                if (it != joint_names.end())
                {
                    ids.push_back(static_cast<size_t>(std::distance(joint_names.begin(), it)));
                }
            }
            neighbor_joints_ids[entry.first] = std::move(ids);
        }
    }

    /// Fold one view's poses into the track set.
    ///
    /// Steps:
    /// 1. drop outlier joints;
    /// 2. associate the surviving detections with tracks by centroid distance;
    /// 3. append each matched detection to its track, first discarding the oldest
    ///    stored detections while the track already holds num_views of them;
    /// 4. refresh all tracks;
    /// 5. start a new track for every unmatched detection;
    /// 6. run two rounds of merge + refresh.
    /// An empty input list returns immediately without touching the tracks.
    void consume_view(const PoseList3D &view_poses)
    {
        if (view_poses.empty())
        {
            return;
        }
        const PoseList3D filtered_detections = drop_outlier_joints(view_poses);
        auto [matches, new_detection_indices, outdated_track_indices] = data_association(filtered_detections);
        (void)outdated_track_indices;
        for (const auto &[track_index, detection_index] : matches)
        {
            auto &track = tracks[track_index];
            // Keep at most num_views detections per track (oldest dropped first).
            while (track.last_detections.size() >= static_cast<size_t>(num_views))
            {
                track.last_detections.erase(track.last_detections.begin());
            }
            track.last_detections.push_back(filtered_detections[detection_index]);
        }
        for (auto &track : tracks)
        {
            update_track_position(track);
        }
        for (const size_t detection_index : new_detection_indices)
        {
            Track track;
            track.skeleton = filtered_detections[detection_index];
            track.last_detections.push_back(filtered_detections[detection_index]);
            tracks.push_back(std::move(track));
        }
        merge_tracks(merge_distance);
        for (auto &track : tracks)
        {
            update_track_position(track);
        }
        merge_tracks(merge_distance);
        for (auto &track : tracks)
        {
            update_track_position(track);
        }
    }

    /// Return the skeleton of every track that has at least min_num_kpts joints
    /// scoring above vis_threshold.
    PoseList3D finalize() const
    {
        PoseList3D poses;
        poses.reserve(tracks.size());
        for (const auto &track : tracks)
        {
            size_t visible_joints = 0;
            for (const auto &joint : track.skeleton)
            {
                if (joint[3] > vis_threshold)
                {
                    ++visible_joints;
                }
            }
            if (visible_joints >= min_num_kpts)
            {
                poses.push_back(track.skeleton);
            }
        }
        return poses;
    }

private:
    /// Mean {x, y, z, score} of the neighbours of @p joint_name that score above
    /// vis_threshold. All four values are 0 when the joint has no neighbour entry
    /// or none of its neighbours is visible, so callers test element [3] > 0.
    std::array<float, 4> neighbor_center(const Pose3D &skeleton, const std::string &joint_name) const
    {
        std::array<float, 4> mean = {0.0f, 0.0f, 0.0f, 0.0f};
        const auto neighbor_iter = neighbor_joints_ids.find(joint_name);
        if (neighbor_iter == neighbor_joints_ids.end())
        {
            return mean;
        }
        const auto &neighbor_ids = neighbor_iter->second;
        size_t count = 0;
        for (const size_t idx : neighbor_ids)
        {
            const auto &joint = skeleton[idx];
            if (joint[3] > vis_threshold)
            {
                mean[0] += joint[0];
                mean[1] += joint[1];
                mean[2] += joint[2];
                mean[3] += joint[3];
                ++count;
            }
        }
        if (count > 0)
        {
            const float inv = 1.0f / static_cast<float>(count);
            mean[0] *= inv;
            mean[1] *= inv;
            mean[2] *= inv;
            mean[3] *= inv;
        }
        return mean;
    }

    /// Reference point used for association: the "chest" joint if it exists and
    /// scores above 0, otherwise the mean of chest's visible neighbours.
    /// NOTE(review): this falls back to (0, 0, 0) when none of those neighbours is
    /// visible, which then competes in association like a real position — confirm.
    std::array<float, 3> centroid(const Pose3D &skeleton) const
    {
        const auto it = std::find(joint_names.begin(), joint_names.end(), "chest");
        if (it != joint_names.end())
        {
            const size_t chest_index = static_cast<size_t>(std::distance(joint_names.begin(), it));
            if (skeleton[chest_index][3] > 0.0f)
            {
                return {skeleton[chest_index][0], skeleton[chest_index][1], skeleton[chest_index][2]};
            }
        }
        const auto center = neighbor_center(skeleton, "chest");
        return {center[0], center[1], center[2]};
    }

    /// Return copies of @p detections with implausible joints zeroed (score = 0).
    ///
    /// Passes 1 and 2 zero joints farther than 3.5 x and then 2.5 x max_distance
    /// from the "hip_middle" neighbour centre. That centre is recomputed between
    /// the passes, and each pass runs only when the centre has a positive score.
    ///
    /// Pass 3 zeroes joints farther than max_distance from their own neighbour
    /// centre. It runs in joint order on the pose being modified, so earlier
    /// zeroing changes later centres.
    ///
    /// Poses whose score sum is not positive are left out of the result.
    PoseList3D drop_outlier_joints(const PoseList3D &detections) const
    {
        PoseList3D filtered;
        filtered.reserve(detections.size());
        for (Pose3D pose : detections)
        {
            std::array<float, 4> hip_center = neighbor_center(pose, "hip_middle");
            if (hip_center[3] > 0.0f)
            {
                const float max_dist_sq = max_distance * 3.5f * max_distance * 3.5f;
                for (auto &joint : pose)
                {
                    const float dx = joint[0] - hip_center[0];
                    const float dy = joint[1] - hip_center[1];
                    const float dz = joint[2] - hip_center[2];
                    const float distance_sq = dx * dx + dy * dy + dz * dz;
                    if (distance_sq > max_dist_sq)
                    {
                        joint[3] = 0.0f;
                    }
                }
            }
            hip_center = neighbor_center(pose, "hip_middle");
            if (hip_center[3] > 0.0f)
            {
                const float max_dist_sq = max_distance * 2.5f * max_distance * 2.5f;
                for (auto &joint : pose)
                {
                    const float dx = joint[0] - hip_center[0];
                    const float dy = joint[1] - hip_center[1];
                    const float dz = joint[2] - hip_center[2];
                    const float distance_sq = dx * dx + dy * dy + dz * dz;
                    if (distance_sq > max_dist_sq)
                    {
                        joint[3] = 0.0f;
                    }
                }
            }
            const float max_dist_sq = max_distance * max_distance;
            for (size_t joint_index = 0; joint_index < pose.size(); ++joint_index)
            {
                // Joints without a neighbour entry or without visible neighbours are not checked.
                const auto center = neighbor_center(pose, joint_names[joint_index]);
                if (center[3] <= 0.0f)
                {
                    continue;
                }
                const float dx = pose[joint_index][0] - center[0];
                const float dy = pose[joint_index][1] - center[1];
                const float dz = pose[joint_index][2] - center[2];
                const float distance_sq = dx * dx + dy * dy + dz * dz;
                if (distance_sq > max_dist_sq)
                {
                    pose[joint_index][3] = 0.0f;
                }
            }
            float sum_scores = 0.0f;
            for (const auto &joint : pose)
            {
                sum_scores += joint[3];
            }
            if (sum_scores > 0.0f)
            {
                filtered.push_back(std::move(pose));
            }
        }
        return filtered;
    }

    /// Pair tracks with detections by minimum total centroid distance (Hungarian).
    /// A pairing is kept only when its distance is strictly below max_distance.
    /// @return {matches as (track index, detection index), unmatched detection
    ///          indices, unmatched track indices}. With no tracks, every detection
    ///          is reported as unmatched.
    std::tuple<std::vector<std::pair<size_t, size_t>>, std::vector<size_t>, std::vector<size_t>>
    data_association(const PoseList3D &new_poses) const
    {
        std::vector<std::pair<size_t, size_t>> matches;
        std::vector<size_t> new_detection_indices;
        std::vector<size_t> outdated_track_indices;
        if (tracks.empty())
        {
            new_detection_indices.resize(new_poses.size());
            std::iota(new_detection_indices.begin(), new_detection_indices.end(), 0);
            return {matches, new_detection_indices, outdated_track_indices};
        }
        std::vector<std::array<float, 3>> track_centroids(tracks.size());
        for (size_t track_index = 0; track_index < tracks.size(); ++track_index)
        {
            track_centroids[track_index] = centroid(tracks[track_index].skeleton);
        }
        std::vector<std::array<float, 3>> detection_centroids(new_poses.size());
        for (size_t detection_index = 0; detection_index < new_poses.size(); ++detection_index)
        {
            detection_centroids[detection_index] = centroid(new_poses[detection_index]);
        }
        // Rows are tracks, columns are detections; cost is Euclidean centroid distance.
        std::vector<std::vector<double>> cost_matrix(
            tracks.size(),
            std::vector<double>(new_poses.size(), 0.0));
        for (size_t track_index = 0; track_index < tracks.size(); ++track_index)
        {
            for (size_t detection_index = 0; detection_index < new_poses.size(); ++detection_index)
            {
                const float dx =
                    track_centroids[track_index][0] - detection_centroids[detection_index][0];
                const float dy =
                    track_centroids[track_index][1] - detection_centroids[detection_index][1];
                const float dz =
                    track_centroids[track_index][2] - detection_centroids[detection_index][2];
                cost_matrix[track_index][detection_index] = std::sqrt(dx * dx + dy * dy + dz * dz);
            }
        }
        const std::vector<int> assignment = solve_assignment_hungarian(cost_matrix);
        for (size_t track_index = 0; track_index < assignment.size(); ++track_index)
        {
            const int detection_index = assignment[track_index];
            if (detection_index < 0)
            {
                continue;
            }
            const double cost = cost_matrix[track_index][static_cast<size_t>(detection_index)];
            if (cost < max_distance)
            {
                matches.emplace_back(track_index, static_cast<size_t>(detection_index));
            }
        }
        std::unordered_set<size_t> matched_tracks;
        std::unordered_set<size_t> matched_detections;
        for (const auto &[track_index, detection_index] : matches)
        {
            matched_tracks.insert(track_index);
            matched_detections.insert(detection_index);
        }
        for (size_t detection_index = 0; detection_index < new_poses.size(); ++detection_index)
        {
            if (!matched_detections.contains(detection_index))
            {
                new_detection_indices.push_back(detection_index);
            }
        }
        for (size_t track_index = 0; track_index < tracks.size(); ++track_index)
        {
            if (!matched_tracks.contains(track_index))
            {
                outdated_track_indices.push_back(track_index);
            }
        }
        return {matches, new_detection_indices, outdated_track_indices};
    }

    /// Recompute a track's skeleton from its stored detections.
    ///
    /// Stage 1, per joint: average the detections (score > 0) that lie within
    /// max_distance of the joint's neighbour centre in the current skeleton. If
    /// none qualify, keep the current joint and scale its score by
    /// (num_views - 1) / num_views.
    ///
    /// Stage 2: if more than topk detections lie within max_distance of that
    /// estimate, replace the estimate with the mean of the topk closest ones.
    void update_track_position(Track &track) const
    {
        const Pose3D &current_pose = track.skeleton;
        const size_t num_joints = current_pose.size();
        const size_t num_detections = track.last_detections.size();
        // NOTE(review): a joint without a neighbour entry, or whose neighbours are all
        // invisible, keeps a centre of (0, 0, 0), so its detections are gated against
        // the origin below. drop_outlier_joints() skips the check in that situation
        // instead — confirm which behaviour is intended.
        std::vector<std::array<float, 3>> neighbor_centers(num_joints, {0.0f, 0.0f, 0.0f});
        for (size_t joint_index = 0; joint_index < num_joints; ++joint_index)
        {
            const auto neighbor_iter = neighbor_joints_ids.find(joint_names[joint_index]);
            if (neighbor_iter == neighbor_joints_ids.end())
            {
                continue;
            }
            const auto &neighbor_ids = neighbor_iter->second;
            size_t count = 0;
            for (const size_t neighbor_index : neighbor_ids)
            {
                const auto &neighbor_joint = current_pose[neighbor_index];
                if (neighbor_joint[3] > vis_threshold)
                {
                    neighbor_centers[joint_index][0] += neighbor_joint[0];
                    neighbor_centers[joint_index][1] += neighbor_joint[1];
                    neighbor_centers[joint_index][2] += neighbor_joint[2];
                    ++count;
                }
            }
            if (count > 0)
            {
                const float inv = 1.0f / static_cast<float>(count);
                neighbor_centers[joint_index][0] *= inv;
                neighbor_centers[joint_index][1] *= inv;
                neighbor_centers[joint_index][2] *= inv;
            }
        }
        // Stage 1: gated mean, or decayed copy of the current joint.
        Pose3D new_pose(num_joints, {0.0f, 0.0f, 0.0f, 0.0f});
        const float max_dist_sq = max_distance * max_distance;
        for (size_t joint_index = 0; joint_index < num_joints; ++joint_index)
        {
            std::array<float, 4> mean = {0.0f, 0.0f, 0.0f, 0.0f};
            size_t count = 0;
            for (size_t detection_index = 0; detection_index < num_detections; ++detection_index)
            {
                const auto &joint = track.last_detections[detection_index][joint_index];
                if (joint[3] <= 0.0f)
                {
                    continue;
                }
                const float dx = joint[0] - neighbor_centers[joint_index][0];
                const float dy = joint[1] - neighbor_centers[joint_index][1];
                const float dz = joint[2] - neighbor_centers[joint_index][2];
                const float distance_sq = dx * dx + dy * dy + dz * dz;
                if (distance_sq <= max_dist_sq)
                {
                    mean[0] += joint[0];
                    mean[1] += joint[1];
                    mean[2] += joint[2];
                    mean[3] += joint[3];
                    ++count;
                }
            }
            if (count > 0)
            {
                const float inv = 1.0f / static_cast<float>(count);
                new_pose[joint_index][0] = mean[0] * inv;
                new_pose[joint_index][1] = mean[1] * inv;
                new_pose[joint_index][2] = mean[2] * inv;
                new_pose[joint_index][3] = mean[3] * inv;
            }
            else
            {
                const float factor = (num_views - 1.0f) / num_views;
                new_pose[joint_index] = current_pose[joint_index];
                new_pose[joint_index][3] *= factor;
            }
        }
        // Stage 2: when more than topk detections are close, keep only the topk closest.
        constexpr int topk = 3;
        for (size_t joint_index = 0; joint_index < num_joints; ++joint_index)
        {
            std::vector<std::pair<float, size_t>> valid_detections;
            valid_detections.reserve(num_detections);
            for (size_t detection_index = 0; detection_index < num_detections; ++detection_index)
            {
                const auto &joint = track.last_detections[detection_index][joint_index];
                if (joint[3] <= 0.0f)
                {
                    continue;
                }
                const float dx = joint[0] - new_pose[joint_index][0];
                const float dy = joint[1] - new_pose[joint_index][1];
                const float dz = joint[2] - new_pose[joint_index][2];
                const float distance_sq = dx * dx + dy * dy + dz * dz;
                if (distance_sq <= max_dist_sq)
                {
                    valid_detections.emplace_back(distance_sq, detection_index);
                }
            }
            const size_t num_to_drop = std::max(
                static_cast<int>(valid_detections.size()) - topk,
                0);
            if (num_to_drop == 0)
            {
                continue;
            }
            // Sort by distance (ties broken by detection index) and keep the closest topk.
            std::sort(valid_detections.begin(), valid_detections.end());
            valid_detections.resize(valid_detections.size() - num_to_drop);
            std::array<float, 4> mean = {0.0f, 0.0f, 0.0f, 0.0f};
            for (const auto &[distance_sq, detection_index] : valid_detections)
            {
                (void)distance_sq;
                const auto &joint = track.last_detections[detection_index][joint_index];
                mean[0] += joint[0];
                mean[1] += joint[1];
                mean[2] += joint[2];
                mean[3] += joint[3];
            }
            const float inv = 1.0f / static_cast<float>(valid_detections.size());
            new_pose[joint_index][0] = mean[0] * inv;
            new_pose[joint_index][1] = mean[1] * inv;
            new_pose[joint_index][2] = mean[2] * inv;
            new_pose[joint_index][3] = mean[3] * inv;
        }
        track.skeleton = std::move(new_pose);
    }

    /// Merge tracks whose commonly visible joints all lie within @p distance_threshold.
    ///
    /// Tracks are scanned in order. Each later track that is close to an earlier
    /// kept track is absorbed into it and removed:
    /// - its joints are blended with weights proportional to joint score times
    ///   detection-history size;
    /// - its detection history is appended to the kept track's history.
    void merge_tracks(float distance_threshold)
    {
        const size_t count = tracks.size();
        std::unordered_set<size_t> merged_indices;
        std::vector<Track> merged_tracks;
        // Snapshot of the pre-merge skeletons; closeness tests and blends use these.
        std::vector<Pose3D> skeletons(count);
        for (size_t idx = 0; idx < count; ++idx)
        {
            skeletons[idx] = tracks[idx].skeleton;
        }
        for (size_t first = 0; first < count; ++first)
        {
            if (merged_indices.contains(first))
            {
                continue;
            }
            merged_tracks.push_back(tracks[first]);
            const auto &current_skeleton = skeletons[first];
            for (size_t second = first + 1; second < count; ++second)
            {
                if (merged_indices.contains(second))
                {
                    continue;
                }
                const bool close = all_visible_joints_close(
                    current_skeleton,
                    skeletons[second],
                    distance_threshold,
                    vis_threshold);
                if (!close)
                {
                    continue;
                }
                auto &merged_track = merged_tracks.back();
                const size_t num_last = merged_track.last_detections.size();
                const size_t num_current = tracks[second].last_detections.size();
                // NOTE(review): joint1 is always the pre-merge skeleton of `first`, so if
                // several tracks are absorbed only the last blend remains in the skeleton
                // (all absorbed histories are still kept).
                for (size_t joint = 0; joint < current_skeleton.size(); ++joint)
                {
                    const auto &joint1 = current_skeleton[joint];
                    const auto &joint2 = skeletons[second][joint];
                    const float total_score = joint1[3] + joint2[3];
                    if (total_score <= 0.0f)
                    {
                        continue;
                    }
                    const float inv_score = 1.0f / total_score;
                    const float inv_count = 1.0f / static_cast<float>(num_last + num_current);
                    float w1 = (joint1[3] * inv_score) * (static_cast<float>(num_last) * inv_count);
                    float w2 = (joint2[3] * inv_score) * (static_cast<float>(num_current) * inv_count);
                    const float inv_weight = 1.0f / (w1 + w2);
                    w1 *= inv_weight;
                    w2 *= inv_weight;
                    merged_track.skeleton[joint][0] = joint1[0] * w1 + joint2[0] * w2;
                    merged_track.skeleton[joint][1] = joint1[1] * w1 + joint2[1] * w2;
                    merged_track.skeleton[joint][2] = joint1[2] * w1 + joint2[2] * w2;
                    merged_track.skeleton[joint][3] = joint1[3] * w1 + joint2[3] * w2;
                }
                merged_track.last_detections.insert(
                    merged_track.last_detections.end(),
                    tracks[second].last_detections.begin(),
                    tracks[second].last_detections.end());
                merged_indices.insert(second);
            }
        }
        tracks = std::move(merged_tracks);
    }

    std::vector<std::string> joint_names;  // joint order of every pose handled here
    float num_views = 1.0f;                // view count, kept as float for the decay factor
    float max_distance = 0.5f;             // gating distance; overwritten by the constructor
    const size_t min_num_kpts = 7;         // finalize(): minimum visible joints per track
    const float merge_distance = 0.3f;     // threshold consume_view() passes to merge_tracks()
    const float vis_threshold = 0.1f;      // score above which a joint counts as visible
    std::unordered_map<std::string, std::vector<std::string>> neighbor_joints;  // neighbour names per joint
    std::unordered_map<std::string, std::vector<size_t>> neighbor_joints_ids;   // same, as indices into joint_names
    std::vector<Track> tracks;             // current track set
};
/// Copy the persons of one view out of the padded batch into nested poses.
///
/// Reads person_counts[view_index] persons. Each has num_joints joints of four
/// values, fetched through poses_3d.at(). The caller must ensure view_index is
/// in range.
PoseList3D view_to_nested(const PoseBatch3DByViewView &poses_3d, size_t view_index)
{
    const auto person_count = poses_3d.person_counts[view_index];
    PoseList3D result;
    result.reserve(person_count);
    for (size_t person = 0; person < person_count; ++person)
    {
        Pose3D skeleton(poses_3d.num_joints, {0.0f, 0.0f, 0.0f, 0.0f});
        for (size_t joint = 0; joint < poses_3d.num_joints; ++joint)
        {
            auto &target = skeleton[joint];
            for (size_t coord = 0; coord < target.size(); ++coord)
            {
                target[coord] = poses_3d.at(view_index, person, joint, coord);
            }
        }
        result.push_back(std::move(skeleton));
    }
    return result;
}
} // namespace
/// Fuse per-view RGB-D 3D skeletons into one set of poses.
///
/// @param poses_3d      padded per-view batch; view v contributes person_counts[v] persons.
/// @param config        the size of config.cameras is checked against the view count; it
///                      also supplies joint_names, roomparams and
///                      options.min_match_score for post-processing.
/// @param max_distance  gating distance forwarded to the view merger.
/// @return              fused poses after the four post-processing steps below (head
///                      completion, plausibility filtering, missing-joint filling,
///                      far-joint replacement).
/// @throws std::invalid_argument on null buffers or inconsistent dimensions.
PoseBatch3D merge_rgbd_views(
    const PoseBatch3DByViewView &poses_3d,
    const TriangulationConfig &config,
    float max_distance)
{
    // Throw std::invalid_argument(message) unless `ok` holds.
    const auto require = [](bool ok, const char *message) {
        if (!ok)
        {
            throw std::invalid_argument(message);
        }
    };

    // Checks run in a fixed order; the first failure decides the message.
    require(poses_3d.person_counts != nullptr, "person_counts must not be null.");
    require(poses_3d.num_views != 0, "No 3D views provided.");
    require(
        poses_3d.num_views == config.cameras.size(),
        "Number of cameras and 3D views must be the same.");
    require(
        poses_3d.num_joints == config.joint_names.size(),
        "Number of joint names and 3D poses must be the same.");
    const bool has_payload = poses_3d.max_persons > 0 && poses_3d.num_joints > 0;
    require(!(has_payload && poses_3d.data == nullptr), "poses_3d data must not be null.");
    for (size_t view = 0; view < poses_3d.num_views; ++view)
    {
        require(
            !(poses_3d.person_counts[view] > poses_3d.max_persons),
            "person_counts entries must not exceed the padded person dimension.");
    }

    RgbdViewMerger merger(config.joint_names, poses_3d.num_views, max_distance);
    for (size_t view = 0; view < poses_3d.num_views; ++view)
    {
        merger.consume_view(view_to_nested(poses_3d, view));
    }

    // Post-processing of the fused skeletons.
    constexpr float missing_joint_vis_threshold = 0.1f;
    PoseList3D fused = merger.finalize();
    add_extra_joints(fused, config.joint_names);
    filter_poses(fused, config.roomparams);
    add_missing_joints(fused, config.joint_names, missing_joint_vis_threshold);
    replace_far_joints(fused, config.joint_names, config.options.min_match_score);
    return PoseBatch3D::from_nested(fused);
}