Merge close 2D poses.

This commit is contained in:
Daniel
2025-01-27 14:20:52 +01:00
parent d886d1db7c
commit f9127e9a40

View File

@ -917,6 +917,10 @@ namespace utils_2d_pose
std::unique_ptr<RTMDet> det_model; std::unique_ptr<RTMDet> det_model;
std::unique_ptr<RTMPose> pose_model; std::unique_ptr<RTMPose> pose_model;
bool batch_poses; bool batch_poses;
void merge_close_poses(
std::vector<std::vector<std::array<float, 3>>> &poses,
std::array<size_t, 2> image_size);
}; };
// ============================================================================================= // =============================================================================================
@ -925,12 +929,12 @@ namespace utils_2d_pose
{ {
auto bboxes = det_model->call(image); auto bboxes = det_model->call(image);
std::vector<std::vector<std::array<float, 3>>> keypoints; std::vector<std::vector<std::array<float, 3>>> poses;
if (this->batch_poses) if (this->batch_poses)
{ {
if (!bboxes.empty()) if (!bboxes.empty())
{ {
keypoints = std::move(pose_model->call(image, bboxes)); poses = std::move(pose_model->call(image, bboxes));
} }
} }
else else
@ -938,10 +942,98 @@ namespace utils_2d_pose
for (size_t i = 0; i < bboxes.size(); i++) for (size_t i = 0; i < bboxes.size(); i++)
{ {
auto kpts = pose_model->call(image, {bboxes[i]}); auto kpts = pose_model->call(image, {bboxes[i]});
keypoints.push_back(std::move(kpts[0])); poses.push_back(std::move(kpts[0]));
} }
} }
return keypoints;
// Sometimes the detection model predicts multiple boxes with different shapes for the same
// person. They then result in strongly overlapping poses, which are merged here.
merge_close_poses(poses, {(size_t)image.cols, (size_t)image.rows});
return poses;
}
// =============================================================================================
// Merges poses that strongly overlap, i.e. that most likely describe the same person.
//
// Two poses count as duplicates when at least `min_close_joints` of the reference joints
// (nose, shoulders, elbows, hips) are closer than 1% of the smaller image dimension. Only
// joints with a positive score in both poses take part in that comparison.
//
// A duplicate is folded into the first already accepted pose it matches: for every joint the
// entry with the higher score is kept. Poses that match none of the accepted poses are kept
// unchanged, in their original order.
//
// poses:      list of poses, each a list of joints stored as [x, y, score]; modified in place.
// image_size: two image dimensions in pixels; only the smaller one is used.
void TopDown::merge_close_poses(
    std::vector<std::vector<std::array<float, 3>>> &poses,
    std::array<size_t, 2> image_size)
{
    // Reference joint ids in COCO order:
    // nose, left/right shoulder, left/right elbow, left/right hip
    static constexpr std::array<size_t, 7> ref_joints = {0, 5, 6, 7, 8, 11, 12};

    // Minimum number of close reference joints required to merge two poses
    static constexpr size_t min_close_joints = 5;

    if (poses.size() < 2)
    {
        return;
    }

    // Calculate pixel threshold based on image size, squared once to avoid square roots
    const size_t min_dim = std::min(image_size[0], image_size[1]);
    const float pixel_threshold = 0.01f * min_dim;
    const float threshold_sq = pixel_threshold * pixel_threshold;

    // Merge poses if enough joints are close
    std::vector<std::vector<std::array<float, 3>>> merged_poses;
    merged_poses.reserve(poses.size());

    for (auto &opose : poses)
    {
        bool merged = false;

        for (auto &mpose : merged_poses)
        {
            size_t close_count = 0;
            for (const size_t joint_id : ref_joints)
            {
                // A joint that is missing in either pose can not be compared
                if (joint_id >= opose.size() || joint_id >= mpose.size())
                {
                    continue;
                }

                const std::array<float, 3> &oj = opose[joint_id];
                const std::array<float, 3> &mj = mpose[joint_id];

                // Joints with a non-positive score are treated as not detected
                if (oj[2] > 0.0f && mj[2] > 0.0f)
                {
                    const float dx = oj[0] - mj[0];
                    const float dy = oj[1] - mj[1];
                    if (dx * dx + dy * dy <= threshold_sq)
                    {
                        ++close_count;
                    }
                }
            }

            if (close_count >= min_close_joints)
            {
                // Merge new pose into the existing one, keeping the more confident joints.
                // Only the joints present in both poses are compared.
                const size_t num_joints = std::min(mpose.size(), opose.size());
                for (size_t j = 0; j < num_joints; ++j)
                {
                    if (opose[j][2] > mpose[j][2])
                    {
                        mpose[j] = opose[j];
                    }
                }
                merged = true;
                break;
            }
        }

        if (!merged)
        {
            // Not mergeable, add as new pose
            merged_poses.push_back(std::move(opose));
        }
    }

    // Replace original poses with merged ones
    poses = std::move(merged_poses);
}
// ============================================================================================= // =============================================================================================