diff --git a/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp b/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp index e902232..b7e7638 100644 --- a/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp +++ b/extras/ros/rpt3d_wrapper_cpp/src/rpt3d_wrapper.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -52,10 +53,10 @@ public: Rpt3DWrapperNode() : Node("rpt3d_wrapper") { - this->all_cams_set = false; - this->cameras.resize(cam_ids.size()); + this->all_cameras.resize(cam_ids.size()); this->all_poses.resize(cam_ids.size()); this->all_timestamps.resize(cam_ids.size()); + this->joint_names = {}; this->all_poses_set.resize(cam_ids.size(), false); // Load 3D model @@ -67,6 +68,11 @@ public: qos_profile.reliable(); qos_profile.keep_last(1); + // Parallel executable callbacks + auto my_callback_group = create_callback_group(rclcpp::CallbackGroupType::Reentrant); + rclcpp::SubscriptionOptions options; + options.callback_group = my_callback_group; + // Setup subscribers for (size_t i = 0; i < cam_ids.size(); i++) { @@ -81,7 +87,8 @@ public: [this, i](const rpt_msgs::msg::Poses::SharedPtr msg) { this->callback_poses(msg, i); - }); + }, + options); sub_pose_list_.push_back(sub_pose); auto sub_cam = this->create_subscription( @@ -89,7 +96,8 @@ public: [this, i](const std_msgs::msg::String::SharedPtr msg) { this->callback_cam_info(msg, i); - }); + }, + options); sub_cam_list_.push_back(sub_cam); } @@ -105,15 +113,17 @@ private: rclcpp::Publisher::SharedPtr pose_pub_; std::unique_ptr tri_model; - std::vector cameras; - std::atomic all_cams_set; + std::vector all_cameras; + std::mutex cams_mutex, pose_mutex, model_mutex; std::vector>>> all_poses; std::vector all_timestamps; + std::vector joint_names; std::vector all_poses_set; void callback_poses(const rpt_msgs::msg::Poses::SharedPtr msg, size_t cam_idx); void callback_cam_info(const std_msgs::msg::String::SharedPtr msg, size_t cam_idx); + void call_model(); }; // ================================================================================================= @@ -131,34 +141,21 @@ void Rpt3DWrapperNode::callback_cam_info(const std_msgs::msg::String::SharedPtr camera.width = cam["width"].get(); camera.height = cam["height"].get(); camera.type = cam["type"].get(); - cameras[cam_idx] = camera; - // Check if all cameras are set - bool all_set = true; - for (size_t i = 0; i < cam_ids.size(); i++) - { - if (cameras[i].name.empty()) - { - all_set = false; - break; - } - } - this->all_cams_set = all_set; + cams_mutex.lock(); + all_cameras[cam_idx] = camera; + cams_mutex.unlock(); } // ================================================================================================= void Rpt3DWrapperNode::callback_poses(const rpt_msgs::msg::Poses::SharedPtr msg, size_t cam_idx) { - if (!this->all_cams_set) - { - RCLCPP_WARN(this->get_logger(), "Skipping frame, still waiting for cameras..."); - return; - } - auto ts_image = std::chrono::high_resolution_clock::now(); - std::vector>> poses; - std::vector &joint_names_2d = msg->joint_names; + if (this->joint_names.empty()) + { + this->joint_names = msg->joint_names; + } // Unflatten poses size_t idx = 0; @@ -179,25 +176,69 @@ void Rpt3DWrapperNode::callback_poses(const rpt_msgs::msg::Poses::SharedPtr msg, poses.push_back(std::move(body)); } + // If no pose was detected, create an empty placeholder + if (poses.size() == 0) + { + std::vector> body(joint_names.size(), {0, 0, 0}); + poses.push_back(std::move(body)); + } + + pose_mutex.lock(); this->all_poses[cam_idx] = std::move(poses); this->all_poses_set[cam_idx] = true; this->all_timestamps[cam_idx] = msg->header.stamp.sec + msg->header.stamp.nanosec / 1.0e9; + pose_mutex.unlock(); + + // Trigger model callback + model_mutex.lock(); + call_model(); + model_mutex.unlock(); +} + +// ================================================================================================= + +void Rpt3DWrapperNode::call_model() +{ + auto ts_msg = std::chrono::high_resolution_clock::now(); + + // Check if all cameras are set + cams_mutex.lock(); + for (size_t i = 0; i < cam_ids.size(); i++) + { + if (all_cameras[i].name.empty()) + { + RCLCPP_WARN(this->get_logger(), "Skipping frame, still waiting for cameras..."); + cams_mutex.unlock(); + return; + } + } + cams_mutex.unlock(); // If not all poses are set, return and wait for the others + pose_mutex.lock(); for (size_t i = 0; i < cam_ids.size(); i++) { if (!this->all_poses_set[i]) { + pose_mutex.unlock(); return; } } + pose_mutex.unlock(); - // Call model + // Call model, and meanwhile lock updates of the inputs + // Since the prediction is very fast, parallel callback threads only need to wait a short time + cams_mutex.lock(); + pose_mutex.lock(); const auto valid_poses = tri_model->triangulate_poses( - all_poses, cameras, roomparams, joint_names_2d); + all_poses, all_cameras, roomparams, joint_names); + + double min_ts = *std::min_element(all_timestamps.begin(), all_timestamps.end()); + this->all_poses_set = std::vector(cam_ids.size(), false); + pose_mutex.unlock(); + cams_mutex.unlock(); // Calculate timings - double min_ts = *std::min_element(all_timestamps.begin(), all_timestamps.end()); auto ts_pose = std::chrono::high_resolution_clock::now(); double ts_pose_sec = std::chrono::duration(ts_pose.time_since_epoch()).count(); double z_trigger_pose3d = ts_pose_sec - min_ts; @@ -209,11 +250,11 @@ void Rpt3DWrapperNode::callback_poses(const rpt_msgs::msg::Poses::SharedPtr msg, // Publish message auto pose_msg = rpt_msgs::msg::Poses(); - pose_msg.header = msg->header; pose_msg.header.stamp.sec = static_cast(min_ts); pose_msg.header.stamp.nanosec = (min_ts - pose_msg.header.stamp.sec) * 1.0e9; - std::vector pshape = {(int)valid_poses.size(), (int)joint_names_2d.size(), 4}; - pose_msg.bodies_shape = bshape; + pose_msg.header.frame_id = "world"; + std::vector pshape = {(int)valid_poses.size(), (int)joint_names.size(), 4}; + pose_msg.bodies_shape = pshape; pose_msg.bodies_flat.reserve(pshape[0] * pshape[1] * pshape[2]); for (int32_t i = 0; i < pshape[0]; i++) { @@ -225,18 +266,16 @@ void Rpt3DWrapperNode::callback_poses(const rpt_msgs::msg::Poses::SharedPtr msg, } } } - pose_msg.joint_names = joint_names_2d; + pose_msg.joint_names = joint_names; pose_msg.extra_data = jdata.dump(); pose_pub_->publish(pose_msg); // Print info double elapsed_time = std::chrono::duration( - std::chrono::high_resolution_clock::now() - ts_image) + std::chrono::high_resolution_clock::now() - ts_msg) .count(); std::cout << "Detected persons: " << valid_poses.size() << " - Prediction time: " << elapsed_time << "s" << std::endl; - - this->all_poses_set = std::vector(cam_ids.size(), false); } // ================================================================================================= @@ -246,7 +285,9 @@ int main(int argc, char **argv) { rclcpp::init(argc, argv); auto node = std::make_shared(); - rclcpp::spin(node); + rclcpp::executors::MultiThreadedExecutor exec; + exec.add_node(node); + exec.spin(); rclcpp::shutdown(); return 0; }