from xml.parsers.expat import model from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d.art3d import Poly3DCollection # plt.switch_backend('agg') def display_model( model_info, model_faces=None, with_joints=False, kintree_table=None, ax=None, batch_idx=0, show=True, savepath=None, only_joint=False): """ Displays mesh batch_idx in batch of model_info, model_info as returned by generate_random_model """ if ax is None: fig = plt.figure() ax = fig.add_subplot(111, projection='3d') verts = model_info['verts'][batch_idx] joints = model_info['joints'][batch_idx] if model_faces is None: ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2) elif not only_joint: mesh = Poly3DCollection(verts[model_faces], alpha=0.2) face_color = (141 / 255, 184 / 255, 226 / 255) edge_color = (50 / 255, 50 / 255, 50 / 255) mesh.set_edgecolor(edge_color) mesh.set_facecolor(face_color) ax.add_collection3d(mesh) if with_joints: draw_skeleton(joints, kintree_table=kintree_table, ax=ax) ax.set_xlabel('X') ax.set_ylabel('Y') ax.set_zlabel('Z') ax.set_xlim(-0.7, 0.7) ax.set_ylim(-0.7, 0.7) ax.set_zlim(-0.7, 0.7) ax.view_init(azim=-90, elev=100) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) if savepath: # print('Saving figure at {}.'.format(savepath)) plt.savefig(savepath, bbox_inches='tight', pad_inches=0) if show: plt.show() plt.close() return ax def draw_skeleton(joints3D, kintree_table, ax=None, with_numbers=True): if ax is None: fig = plt.figure(frameon=False) ax = fig.add_subplot(111, projection='3d') else: ax = ax colors = [] left_right_mid = ['r', 'g', 'b'] kintree_colors = [2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1] for c in kintree_colors: colors += left_right_mid[c] # For each 24 joint for i in range(1, kintree_table.shape[1]): j1 = kintree_table[0][i] j2 = kintree_table[1][i] ax.plot([joints3D[j1, 0], joints3D[j2, 0]], [joints3D[j1, 1], joints3D[j2, 1]], [joints3D[j1, 2], joints3D[j2, 2]], color=colors[i], linestyle='-', linewidth=2, marker='o', markersize=5) if with_numbers: ax.text(joints3D[j2, 0], joints3D[j2, 1], joints3D[j2, 2], j2) return ax