Add tracking to ros wrapper.

This commit is contained in:
Daniel
2025-04-23 14:05:04 +02:00
parent ad4bf3f54e
commit 79788cd7e0

View File

@ -16,6 +16,7 @@ using json = nlohmann::json;
#include "rpt_msgs/msg/poses.hpp"
#include "/RapidPoseTriangulation/rpt/camera.hpp"
#include "/RapidPoseTriangulation/rpt/interface.hpp"
#include "/RapidPoseTriangulation/rpt/tracker.hpp"
// =================================================================================================
@ -36,9 +37,14 @@ static const std::string pose_in_topic = "/poses/{}";
static const std::string cam_info_topic = "/{}/calibration";
static const std::string pose_out_topic = "/poses/humans3d";
static const float min_match_score = 0.92;
static const float min_match_score = 0.94;
static const size_t min_group_size = 1;
static const bool use_tracking = true;
static const float max_movement_speed = 2.0 * 1.5;
static const float cam_fps = 50;
static const float max_track_distance = 0.3 + max_movement_speed / cam_fps;
static const std::array<std::array<float, 3>, 2> roomparams = {{
{4.0, 5.0, 2.2},
{1.0, 0.0, 1.1},
@ -59,9 +65,11 @@ public:
this->joint_names = {};
this->all_poses_set.resize(cam_ids.size(), false);
// Load 3D model
// Load 3D models
tri_model = std::make_unique<Triangulator>(
min_match_score, min_group_size);
pose_tracker = std::make_unique<PoseTracker>(
max_movement_speed, max_track_distance);
// QoS
rclcpp::QoS qos_profile(1);
@ -113,6 +121,7 @@ private:
rclcpp::Publisher<rpt_msgs::msg::Poses>::SharedPtr pose_pub_;
std::unique_ptr<Triangulator> tri_model;
std::unique_ptr<PoseTracker> pose_tracker;
std::vector<Camera> all_cameras;
std::mutex cams_mutex, pose_mutex, model_mutex;
@ -230,11 +239,32 @@ void Rpt3DWrapperNode::call_model()
// Since the prediction is very fast, parallel callback threads only need to wait a short time
cams_mutex.lock();
pose_mutex.lock();
const auto valid_poses = tri_model->triangulate_poses(
const auto poses_3d = tri_model->triangulate_poses(
all_poses, all_cameras, roomparams, joint_names);
double min_ts = *std::min_element(all_timestamps.begin(), all_timestamps.end());
this->all_poses_set = std::vector<bool>(cam_ids.size(), false);
std::vector<std::vector<std::array<float, 4>>> valid_poses;
std::vector<size_t> track_ids;
if (use_tracking)
{
auto pose_tracks = pose_tracker->track_poses(poses_3d, joint_names, min_ts);
std::vector<std::vector<std::array<float, 4>>> poses_3d_refined;
for (size_t j = 0; j < pose_tracks.size(); j++)
{
auto &pose = std::get<1>(pose_tracks[j]);
poses_3d_refined.push_back(pose);
auto &track_id = std::get<0>(pose_tracks[j]);
track_ids.push_back(track_id);
}
valid_poses = std::move(poses_3d_refined);
}
else
{
valid_poses = std::move(poses_3d);
track_ids = {};
}
pose_mutex.unlock();
cams_mutex.unlock();
@ -267,6 +297,7 @@ void Rpt3DWrapperNode::call_model()
}
}
pose_msg.joint_names = joint_names;
jdata["track_ids"] = track_ids;
pose_msg.extra_data = jdata.dump();
pose_pub_->publish(pose_msg);