diff --git a/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp b/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp index b6314e4..9a17172 100644 --- a/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp +++ b/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp @@ -16,6 +16,7 @@ using json = nlohmann::json; #include "rpt_msgs/msg/poses.hpp" #include "/RapidPoseTriangulation/rpt/camera.hpp" #include "/RapidPoseTriangulation/rpt/interface.hpp" +#include "/RapidPoseTriangulation/rpt/tracker.hpp" // ================================================================================================= @@ -36,9 +37,14 @@ static const std::string pose_in_topic = "/poses/{}"; static const std::string cam_info_topic = "/{}/calibration"; static const std::string pose_out_topic = "/poses/humans3d"; -static const float min_match_score = 0.92; +static const float min_match_score = 0.94; static const size_t min_group_size = 1; +static const bool use_tracking = true; +static const float max_movement_speed = 2.0 * 1.5; +static const float cam_fps = 50; +static const float max_track_distance = 0.3 + max_movement_speed / cam_fps; + static const std::array, 2> roomparams = {{ {4.0, 5.0, 2.2}, {1.0, 0.0, 1.1}, @@ -59,9 +65,11 @@ public: this->joint_names = {}; this->all_poses_set.resize(cam_ids.size(), false); - // Load 3D model + // Load 3D models tri_model = std::make_unique( min_match_score, min_group_size); + pose_tracker = std::make_unique( + max_movement_speed, max_track_distance); // QoS rclcpp::QoS qos_profile(1); @@ -113,6 +121,7 @@ private: rclcpp::Publisher::SharedPtr pose_pub_; std::unique_ptr tri_model; + std::unique_ptr pose_tracker; std::vector all_cameras; std::mutex cams_mutex, pose_mutex, model_mutex; @@ -230,11 +239,32 @@ void Rpt3DWrapperNode::call_model() // Since the prediction is very fast, parallel callback threads only need to wait a short time cams_mutex.lock(); pose_mutex.lock(); - const auto valid_poses = tri_model->triangulate_poses( + const auto poses_3d = tri_model->triangulate_poses( all_poses, all_cameras, roomparams, joint_names); double min_ts = *std::min_element(all_timestamps.begin(), all_timestamps.end()); this->all_poses_set = std::vector(cam_ids.size(), false); + + std::vector>> valid_poses; + std::vector track_ids; + if (use_tracking) + { + auto pose_tracks = pose_tracker->track_poses(poses_3d, joint_names, min_ts); + std::vector>> poses_3d_refined; + for (size_t j = 0; j < pose_tracks.size(); j++) + { + auto &pose = std::get<1>(pose_tracks[j]); + poses_3d_refined.push_back(pose); + auto &track_id = std::get<0>(pose_tracks[j]); + track_ids.push_back(track_id); + } + valid_poses = std::move(poses_3d_refined); + } + else + { + valid_poses = std::move(poses_3d); + track_ids = {}; + } pose_mutex.unlock(); cams_mutex.unlock(); @@ -267,6 +297,7 @@ void Rpt3DWrapperNode::call_model() } } pose_msg.joint_names = joint_names; + jdata["track_ids"] = track_ids; pose_msg.extra_data = jdata.dump(); pose_pub_->publish(pose_msg);