// File: RapidPoseTriangulation/extras/ros/rpt2d_wrapper_cpp/src/rpt2d_wrapper.cpp
// Last modified: 2025-02-20 18:51:48 +01:00
// Size: 238 lines, 7.9 KiB (C++)

#include <array>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// ROS2
#include <rclcpp/rclcpp.hpp>
#include <sensor_msgs/msg/image.hpp>
// OpenCV / cv_bridge
#include <cv_bridge/cv_bridge.h>
#include <opencv2/opencv.hpp>
// JSON library
#include "/RapidPoseTriangulation/extras/include/nlohmann/json.hpp"
using json = nlohmann::json;
#include "rpt_msgs/msg/poses.hpp"
#include "/RapidPoseTriangulation/scripts/utils_2d_pose.hpp"
#include "/RapidPoseTriangulation/scripts/utils_pipeline.hpp"
// =================================================================================================
// Topic name templates; the "{}" placeholder is substituted with the camera id.
static const std::string img_input_topic("/{}/pylon_ros2_camera_node/image_raw");
static const std::string pose_out_topic("/poses/{}");

// Settings forwarded to the pose predictor's constructor.
static const float min_bbox_score = 0.4;
static const float min_bbox_area = 0.1 * 0.1;
static const bool batch_poses = true;

// Whole-body keypoint groups, passed to the utils_pipeline helpers.
static const std::map<std::string, bool> whole_body{
    {"foots", true},
    {"face", true},
    {"hands", true},
};
// =================================================================================================
// =================================================================================================
/// ROS2 node that subscribes to one camera's image topic, runs the 2D pose
/// predictor on each frame and publishes the result as rpt_msgs::msg::Poses.
class Rpt2DWrapperNode : public rclcpp::Node
{
public:
    /// @param cam_id Camera identifier; appended to the node name and
    ///               substituted for "{}" in the input/output topic templates.
    Rpt2DWrapperNode(const std::string &cam_id)
        : Node("rpt2d_wrapper_" + cam_id)
    {
        // Substitute the camera id for the "{}" placeholder of a topic template
        const auto fill_cam_id = [&cam_id](std::string topic_template)
        {
            topic_template.replace(topic_template.find("{}"), 2, cam_id);
            return topic_template;
        };
        const std::string img_topic = fill_cam_id(img_input_topic);
        const std::string pose_topic = fill_cam_id(pose_out_topic);

        // QoS: reliable delivery, history limited to the most recent message
        rclcpp::QoS qos_profile(1);
        qos_profile.reliable().keep_last(1);

        // Subscriber for incoming camera images
        image_sub_ = this->create_subscription<sensor_msgs::msg::Image>(
            img_topic, qos_profile,
            [this](const sensor_msgs::msg::Image::SharedPtr msg)
            { this->callback_images(msg); });

        // Publisher for the detected poses
        pose_pub_ = this->create_publisher<rpt_msgs::msg::Poses>(pose_topic, qos_profile);

        // Pose estimation model
        const bool use_wb = utils_pipeline::use_whole_body(whole_body);
        this->kpt_model = std::make_unique<utils_2d_pose::PosePredictor>(
            use_wb, min_bbox_score, min_bbox_area, batch_poses);

        RCLCPP_INFO(this->get_logger(), "Finished initialization of pose estimator.");
    }

private:
    rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_;
    rclcpp::Publisher<rpt_msgs::msg::Poses>::SharedPtr pose_pub_;

    // Set while a frame is being processed; frames arriving meanwhile are skipped
    std::atomic<bool> is_busy{false};

    // Pose model pointer
    std::unique_ptr<utils_2d_pose::PosePredictor> kpt_model;

    // Joint names matching the configured whole-body groups
    const std::vector<std::string> joint_names_2d = utils_pipeline::get_joint_names(whole_body);

    /// Image subscription callback: converts, predicts and publishes.
    void callback_images(const sensor_msgs::msg::Image::SharedPtr msg);

    /// Runs the predictor on one image; returns one keypoint list per person,
    /// each keypoint being three floats.
    std::vector<std::vector<std::array<float, 3>>> call_model(const cv::Mat &image);
};
// =================================================================================================
/// Image subscription callback.
///
/// Converts the incoming image to the single-channel layout passed to
/// call_model(), runs the pose model, and publishes the resulting poses
/// together with timing information (JSON in `extra_data`).
/// Frames that arrive while a previous frame is still being processed are
/// dropped with a warning.
///
/// @param msg Incoming image; supported encodings are "mono8", "bayer_rggb8"
///            and "rgb8". Any other encoding is logged as an error and skipped.
void Rpt2DWrapperNode::callback_images(const sensor_msgs::msg::Image::SharedPtr msg)
{
    // Claim the busy flag in a single atomic step so that two concurrent
    // callbacks can never both pass the check.
    if (this->is_busy.exchange(true))
    {
        RCLCPP_WARN(this->get_logger(), "Skipping frame, still processing...");
        return;
    }

    // Release the busy flag on every exit path (normal return, early return
    // after a conversion error, or an exception); otherwise a single bad frame
    // would make the node skip all following frames.
    struct BusyRelease
    {
        explicit BusyRelease(std::atomic<bool> &flag) : flag_(flag) {}
        ~BusyRelease() { flag_ = false; }
        BusyRelease(const BusyRelease &) = delete;
        BusyRelease &operator=(const BusyRelease &) = delete;
        std::atomic<bool> &flag_;
    };
    const BusyRelease busy_release(this->is_busy);

    // The timestamps below are compared with the wall-clock stamp in the
    // message header, so use system_clock: high_resolution_clock is allowed to
    // be a steady clock whose epoch is unrelated to wall-clock time.
    using Clock = std::chrono::system_clock;
    const auto ts_image = Clock::now();

    // Load or convert image to Bayer format
    cv::Mat bayer_image;
    try
    {
        if (msg->encoding == "mono8" || msg->encoding == "bayer_rggb8")
        {
            // Already single-channel: use the message data as-is
            cv_bridge::CvImageConstPtr cv_ptr = cv_bridge::toCvShare(msg, msg->encoding);
            bayer_image = cv_ptr->image;
        }
        else if (msg->encoding == "rgb8")
        {
            cv_bridge::CvImageConstPtr cv_ptr = cv_bridge::toCvShare(msg, "rgb8");
            cv::Mat color_image = cv_ptr->image;
            bayer_image = utils_pipeline::rgb2bayer(color_image);
        }
        else
        {
            throw std::runtime_error("Unknown image encoding: " + msg->encoding);
        }
    }
    catch (const std::exception &e)
    {
        RCLCPP_ERROR(this->get_logger(), "cv_bridge exception: %s", e.what());
        return;
    }

    // Call model
    const auto valid_poses = this->call_model(bayer_image);

    // Calculate timings (all values in seconds)
    const double time_stamp = msg->header.stamp.sec + msg->header.stamp.nanosec / 1.0e9;
    const double ts_image_sec = std::chrono::duration<double>(ts_image.time_since_epoch()).count();
    const auto ts_pose = Clock::now();
    const double ts_pose_sec = std::chrono::duration<double>(ts_pose.time_since_epoch()).count();
    const double z_trigger_image = ts_image_sec - time_stamp;
    const double z_trigger_pose = ts_pose_sec - time_stamp;
    const double z_image_pose = ts_pose_sec - ts_image_sec;

    json jdata;
    jdata["timestamps"] = {
        {"trigger", time_stamp},
        {"image", ts_image_sec},
        {"pose2d", ts_pose_sec},
        {"z-trigger-image", z_trigger_image},
        {"z-image-pose2d", z_image_pose},
        {"z-trigger-pose2d", z_trigger_pose}};

    // Publish message: poses are flattened in [person][joint][value] order
    auto pose_msg = rpt_msgs::msg::Poses();
    pose_msg.header = msg->header;

    const std::size_t num_persons = valid_poses.size();
    const std::size_t num_joints = joint_names_2d.size();
    const std::size_t num_values = 3;
    const std::vector<int32_t> bshape = {
        static_cast<int32_t>(num_persons),
        static_cast<int32_t>(num_joints),
        static_cast<int32_t>(num_values)};
    pose_msg.bodies_shape = bshape;

    pose_msg.bodies_flat.reserve(num_persons * num_joints * num_values);
    for (std::size_t i = 0; i < num_persons; i++)
    {
        for (std::size_t j = 0; j < num_joints; j++)
        {
            for (std::size_t k = 0; k < num_values; k++)
            {
                pose_msg.bodies_flat.push_back(valid_poses[i][j][k]);
            }
        }
    }
    pose_msg.joint_names = joint_names_2d;
    pose_msg.extra_data = jdata.dump();
    pose_pub_->publish(pose_msg);

    // Print info
    const double elapsed_time = std::chrono::duration<double>(Clock::now() - ts_image).count();
    std::cout << "Detected persons: " << valid_poses.size()
              << " - Prediction time: " << elapsed_time << "s" << std::endl;
}
// =================================================================================================
/// Runs the 2D pose model on a single image.
///
/// @param image Image passed to utils_pipeline::bayer2rgb() before prediction.
/// @return One entry per detected person whose keypoint confidences sum to
///         more than zero; each keypoint is three floats (the third is the
///         value summed as confidence), rounded to 3 decimal places.
///         Empty if the pipeline returns no result.
std::vector<std::vector<std::array<float, 3>>> Rpt2DWrapperNode::call_model(const cv::Mat &image)
{
    // Create image vector (the predictor takes a batch of images)
    cv::Mat rgb_image = utils_pipeline::bayer2rgb(image);
    std::vector<cv::Mat> images_2d = {rgb_image};

    // Predict 2D poses
    auto poses_2d_all = kpt_model->predict(images_2d);
    auto poses_2d_upd = utils_pipeline::update_keypoints(
        poses_2d_all, joint_names_2d, whole_body);

    std::vector<std::vector<std::array<float, 3>>> valid_poses;

    // One image was passed in, so one result entry is expected; guard anyway
    // because indexing an empty result would be undefined behaviour.
    if (poses_2d_upd.empty())
    {
        return valid_poses;
    }
    auto &poses_2d = poses_2d_upd[0];

    // Drop persons with no joints (sum of confidences not above zero)
    for (auto &person : poses_2d)
    {
        float sum_conf = 0.0;
        for (const auto &kp : person)
        {
            sum_conf += kp[2];
        }
        if (sum_conf > 0.0)
        {
            // poses_2d_upd is a local that is discarded afterwards: move, don't copy
            valid_poses.push_back(std::move(person));
        }
    }

    // Round poses to 3 decimal places
    for (auto &person : valid_poses)
    {
        for (auto &kp : person)
        {
            for (auto &value : kp)
            {
                value = static_cast<float>(std::round(value * 1000.0) / 1000.0);
            }
        }
    }
    return valid_poses;
}
// =================================================================================================
// =================================================================================================
int main(int argc, char **argv)
{
rclcpp::init(argc, argv);
const std::string cam_id = std::getenv("CAMID");
auto node = std::make_shared<Rpt2DWrapperNode>(cam_id);
rclcpp::spin(node);
rclcpp::shutdown();
return 0;
}