Updated skelda version.

This commit is contained in:
Daniel
2024-10-14 16:10:19 +02:00
parent 87a205c36c
commit 256fd26f3f
4 changed files with 10 additions and 63 deletions

View File

@ -1,46 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
from skelda import utils_view
# ==================================================================================================
def show_poses2d(bodies, images, joint_names, title=""):
num_imgs = len(images)
rowbreak = int(num_imgs / 2.0 + 0.5)
fig, axs = plt.subplots(2, rowbreak, figsize=(30, 20))
fig.suptitle(title, fontsize=20)
if isinstance(bodies, np.ndarray):
bodies = bodies.tolist()
# Draw skeletons into images
for i, image in enumerate(images):
colors = plt.cm.hsv(np.linspace(0, 1, len(bodies[i]), endpoint=False)).tolist()
colors = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
for j, body in enumerate(bodies[i]):
image = utils_view.draw_body_in_image(image, body, joint_names, colors[j])
# Rescale image range for plotting
images = [img / 255.0 for img in images]
if rowbreak == 1:
axs[0].imshow(images[0])
if len(images) == 2:
axs[1].imshow(images[1])
else:
# Optionally delete last empty plot
fig.delaxes(axs[1])
else:
for i in range(rowbreak):
axs[0][i].imshow(images[i])
if i + rowbreak < num_imgs:
axs[1][i].imshow(images[i + rowbreak])
else:
# Optionally delete last empty plot
fig.delaxes(axs[1][i])
return fig

View File

@ -197,10 +197,6 @@ def load_labels(dataset: dict):
if take_interval > 1:
labels = [l for i, l in enumerate(labels) if i % take_interval == 0]
# Filter joints
fj_func = lambda x: utils_pose.filter_joints_3d(x, eval_joints)
labels = list(map(fj_func, labels))
return labels
@ -332,7 +328,7 @@ def main():
print("3D time:", time_3d)
old_index = label["index"]
all_poses.append(np.array(poses3D))
all_poses.append(np.array(poses3D).tolist())
all_ids.append(label["id"])
all_paths.append(label["imgpaths"])
times.append((time_2d, time_3d))

View File

@ -9,9 +9,8 @@ import cv2
import matplotlib
import numpy as np
import draw_utils
import utils_2d_pose
from skelda import utils_pose
from skelda import utils_pose, utils_view
sys.path.append("/SimplePoseTriangulation/swig/")
import spt
@ -324,14 +323,14 @@ def main():
print("2D time:", time.time() - stime)
# print([np.array(p).round(6).tolist() for p in poses_2d])
fig1 = draw_utils.show_poses2d(
poses_2d, np.array(images_2d), joint_names_2d, "2D detections"
fig1 = utils_view.draw_many_images(
sample["imgpaths_color"], [], [], poses_2d, joint_names_2d, "2D detections"
)
fig1.savefig(os.path.join(dirpath, "2d-k.png"), dpi=fig1.dpi)
# draw_utils.utils_view.show_plots()
if len(images_2d) == 1:
draw_utils.utils_view.show_plots()
utils_view.show_plots()
continue
# Get 3D poses
@ -361,15 +360,13 @@ def main():
# print(poses2D)
# print(poses3D.round(3).tolist())
fig2 = draw_utils.utils_view.show_poses3d(
poses3D, joint_names_3d, roomparams, camparams
)
fig3 = draw_utils.show_poses2d(
poses2D, np.array(images_2d), joint_names_3d, "2D reprojections"
fig2 = utils_view.draw_poses3d(poses3D, joint_names_3d, roomparams, camparams)
fig3 = utils_view.draw_many_images(
sample["imgpaths_color"], [], [], poses2D, joint_names_3d, "2D projections"
)
fig2.savefig(os.path.join(dirpath, "3d-p.png"), dpi=fig2.dpi)
fig3.savefig(os.path.join(dirpath, "2d-p.png"), dpi=fig3.dpi)
draw_utils.utils_view.show_plots()
utils_view.show_plots()
# ==================================================================================================