195 lines
5.1 KiB
Python
195 lines
5.1 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
import numpy as np
|
|
import rclpy
|
|
from cv_bridge import CvBridge
|
|
from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy
|
|
from sensor_msgs.msg import Image
|
|
from std_msgs.msg import String
|
|
|
|
# Directory containing this file, with a trailing "/" so relative paths can be
# appended by plain string concatenation.
filepath = "{}/".format(os.path.dirname(os.path.realpath(__file__)))

# Make the helper modules in the scripts/ folder three levels above this file
# importable (test_triangulate and utils_2d_pose are imported from there).
sys.path.append("".join((filepath, "../../../scripts/")))
|
|
import test_triangulate
|
|
import utils_2d_pose
|
|
|
|
# ==================================================================================================
|
|
|
|
# Converter between ROS Image messages and OpenCV images (used in callback_images).
bridge = CvBridge()

# ROS node and pose publisher; both are created in main().
node = publisher_pose = None
|
|
|
|
# Camera whose images this node processes.
cam_id = "camera01"

# Input: raw image stream of that camera's pylon_ros2_camera_node.
img_input_topic = f"/{cam_id}/pylon_ros2_camera_node/image_raw"
# Output: JSON-encoded 2D poses, one topic per camera.
pose_out_topic = f"/poses/{cam_id}"
|
|
|
|
# Latest frame handed from the image callback to the prediction thread.
# None / 0.0 means "no unprocessed frame".
last_input_image, last_input_time = None, 0.0

# 2D keypoint model (loaded in main()) and the joint names that are passed to
# test_triangulate.update_keypoints and included in the published JSON.
kpt_model = None
joint_names_2d = test_triangulate.joint_names_2d

# Guards last_input_image / last_input_time, which are written by the ROS
# callback and read/cleared by the prediction thread.
lock = threading.Lock()

# Set to True to make the prediction loop exit.
stop_flag = False
|
|
|
|
# Model config: forwarded to utils_2d_pose.load_model() in main() when the
# whole-body model is not selected.
# NOTE(review): semantics inferred from the names (minimum detection score,
# minimum relative box area, batched pose inference) — confirm in utils_2d_pose.
min_bbox_score, min_bbox_area, batch_poses = 0.4, 0.1 * 0.1, False
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
|
|
def callback_images(image_data):
    """Convert an incoming ROS Image message and store it as the latest frame.

    bayer_rggb8 and mono8 images are stored as-is. rgb8 images are converted
    with test_triangulate.rgb2bayer. The prediction thread later runs
    test_triangulate.bayer2rgb on every stored frame. Any previously stored,
    unprocessed frame is overwritten.

    Args:
        image_data: sensor_msgs.msg.Image received on img_input_topic.

    Raises:
        ValueError: if the encoding is not bayer_rggb8, mono8 or rgb8.
            NOTE(review): an exception raised in a subscription callback
            propagates out of rclpy.spin() in main().
    """
    global last_input_image, last_input_time

    encoding = image_data.encoding
    if encoding in ("bayer_rggb8", "mono8"):
        # Already in the stored single-channel format: convert without changes.
        bayer_image = bridge.imgmsg_to_cv2(image_data, encoding)
    elif encoding == "rgb8":
        color_image = bridge.imgmsg_to_cv2(image_data, "rgb8")
        bayer_image = test_triangulate.rgb2bayer(color_image)
    else:
        # Format the message explicitly; passing two args to ValueError would
        # produce a tuple-looking message.
        raise ValueError(f"Unknown image encoding: {encoding}")

    # Header stamp as floating-point seconds.
    stamp = image_data.header.stamp
    time_stamp = stamp.sec + stamp.nanosec / 1.0e9

    # Update image and stamp together so the reader never sees a mixed pair.
    with lock:
        last_input_image = bayer_image
        last_input_time = time_stamp
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
|
|
def callback_model():
    """Run 2D pose estimation on the pending frame, if any, and publish it.

    Processing steps:
    - Take the frame stored by callback_images and clear the slot, so the
      same frame is not processed twice.
    - Convert it with test_triangulate.bayer2rgb and predict with
      utils_2d_pose.get_2d_pose / test_triangulate.update_keypoints.
    - Drop persons whose channel-2 keypoint values sum to zero.
    - Round to 3 decimals and publish the result as JSON via publish().

    When no frame is pending it sleeps 0.1 ms and returns.
    """
    global last_input_image, last_input_time

    ptime = time.time()

    # Take the pending frame under the lock. "Pending" is decided by the image
    # itself, not by last_input_time == 0.0, so a frame whose header stamp is
    # zero is still processed instead of being ignored forever.
    with lock:
        img = last_input_image
        ts = last_input_time
        if img is not None:
            last_input_image = None
            last_input_time = 0.0
    if img is None:
        time.sleep(0.0001)
        return

    # Predict 2D poses. The helpers work on a list of images; one is passed.
    images_2d = [test_triangulate.bayer2rgb(img)]
    poses_2d = utils_2d_pose.get_2d_pose(kpt_model, images_2d)
    poses_2d = test_triangulate.update_keypoints(poses_2d, joint_names_2d)
    poses_2d = np.asarray(poses_2d[0])

    # Drop persons with no joints, i.e. whose channel-2 values (presumably
    # keypoint confidences) sum to zero. Skip when nothing was detected: an
    # empty list becomes an array of shape (0,), on which [..., 2] raises.
    if poses_2d.size > 0:
        mask = np.sum(poses_2d[..., 2], axis=1) > 0
        poses_2d = poses_2d[mask]

    # Round poses; tolist() yields nested lists of plain floats for json.dumps.
    poses2D = poses_2d.round(3).tolist()

    # Publish poses
    ts_pose = time.time()
    poses = {
        "bodies2D": poses2D,
        "joints": joint_names_2d,
        "num_persons": len(poses2D),
        "timestamps": {
            "image": ts,
            "pose": ts_pose,
            "z-images-pose": ts_pose - ts,
        },
    }
    publish(poses)

    msg = "Detected persons: {} - Prediction time: {:.4f}s"
    print(msg.format(poses["num_persons"], time.time() - ptime))
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
|
|
def callback_wrapper():
    """Prediction-thread loop: call callback_model() until stop_flag is set.

    Sleeps 1 ms after each call so an idle loop does not spin at full speed.
    """
    while True:
        if stop_flag:
            break
        callback_model()
        time.sleep(0.001)
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
|
|
def publish(data):
    """Serialize *data* to JSON and publish it on the pose topic.

    Args:
        data: JSON-serializable object (callback_model passes a dict).
    """
    payload = json.dumps(data)
    publisher_pose.publish(String(data=payload))
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
|
|
def main():
    """Start the ROS node, load the 2D pose model and run until shutdown.

    Subscribes to img_input_topic and publishes JSON poses on pose_out_topic.
    The main thread spins the node while callback_wrapper runs in a
    background thread. On exit, whether from Ctrl-C or any exception raised
    by spin, the prediction thread is stopped and joined before the node and
    rclpy are shut down. Otherwise the non-daemon thread would keep looping
    and prevent the process from exiting.
    """
    global node, publisher_pose, kpt_model, stop_flag

    # Start node
    rclpy.init(args=sys.argv)
    node = rclpy.create_node("rpt2D_wrapper")

    # Reliable delivery, keeping only the newest message (depth 1).
    qos_profile = QoSProfile(
        reliability=QoSReliabilityPolicy.RELIABLE,
        history=QoSHistoryPolicy.KEEP_LAST,
        depth=1,
    )

    # Create subscribers
    _ = node.create_subscription(
        Image,
        img_input_topic,
        callback_images,
        qos_profile,
    )

    # Create publishers
    publisher_pose = node.create_publisher(String, pose_out_topic, qos_profile)

    # Load 2D pose model: the whole-body model if any whole_body entry is set.
    whole_body = test_triangulate.whole_body
    if any(whole_body[k] for k in whole_body):
        kpt_model = utils_2d_pose.load_wb_model()
    else:
        kpt_model = utils_2d_pose.load_model(min_bbox_score, min_bbox_area, batch_poses)

    node.get_logger().info("Finished initialization of pose estimator")

    # Start prediction thread
    p1 = threading.Thread(target=callback_wrapper)
    p1.start()

    # Run ros update loop; always stop the worker thread and clean up.
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        stop_flag = True
        p1.join()
        node.destroy_node()
        # The context may already be shut down (e.g. by rclpy's SIGINT
        # handling); calling shutdown twice raises.
        if rclpy.ok():
            rclpy.shutdown()
|
|
|
|
|
|
# ==================================================================================================
|
|
|
|
# Start the node only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|