Files
Pose_to_SMPL_an_230402/display_utils.py
2021-08-05 16:22:06 +08:00

76 lines
2.5 KiB
Python

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
# plt.switch_backend('agg')
def display_model(
model_info,
model_faces=None,
with_joints=False,
kintree_table=None,
ax=None,
batch_idx=0,
show=True,
savepath=None,
only_joint=False):
"""
Displays mesh batch_idx in batch of model_info, model_info as returned by
generate_random_model
"""
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint:
mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
face_color = (141 / 255, 184 / 255, 226 / 255)
edge_color = (50 / 255, 50 / 255, 50 / 255)
mesh.set_edgecolor(edge_color)
mesh.set_facecolor(face_color)
ax.add_collection3d(mesh)
if with_joints:
draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlim(-0.7, 0.7)
ax.set_ylim(-0.7, 0.7)
ax.set_zlim(-0.7, 0.7)
ax.view_init(azim=-90, elev=100)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
if savepath:
# print('Saving figure at {}.'.format(savepath))
plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
if show:
plt.show()
plt.close()
return ax
def draw_skeleton(joints3D, kintree_table, ax=None, with_numbers=True):
if ax is None:
fig = plt.figure(frameon=False)
ax = fig.add_subplot(111, projection='3d')
else:
ax = ax
colors = []
left_right_mid = ['r', 'g', 'b']
kintree_colors = [2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1]
for c in kintree_colors:
colors += left_right_mid[c]
# For each 24 joint
for i in range(1, kintree_table.shape[1]):
j1 = kintree_table[0][i]
j2 = kintree_table[1][i]
ax.plot([joints3D[j1, 0], joints3D[j2, 0]],
[joints3D[j1, 1], joints3D[j2, 1]],
[joints3D[j1, 2], joints3D[j2, 2]],
color=colors[i], linestyle='-', linewidth=2, marker='o', markersize=5)
if with_numbers:
ax.text(joints3D[j2, 0], joints3D[j2, 1], joints3D[j2, 2], j2)
return ax