Some mixed updates.

This commit is contained in:
Daniel
2024-12-18 16:28:13 +01:00
parent 07426fac2f
commit f8984f9408
6 changed files with 11 additions and 16 deletions

View File

@ -8,8 +8,9 @@ import matplotlib
import numpy as np
import tqdm
# import utils_2d_pose
import utils_2d_pose_ort as utils_2d_pose
import test_triangulate
import utils_2d_pose
from skelda import evals
sys.path.append("/RapidPoseTriangulation/swig/")

View File

@ -9,7 +9,8 @@ import cv2
import matplotlib
import numpy as np
import utils_2d_pose
# import utils_2d_pose
import utils_2d_pose_ort as utils_2d_pose
from skelda import utils_pose, utils_view
sys.path.append("/RapidPoseTriangulation/swig/")

View File

@ -97,9 +97,7 @@ class BaseModel(ABC):
ishape = list(self.input_shapes[i])
if "batch_size" in ishape:
max_batch_size = 10
ishape[0] = np.random.choice(
list(range(1, max_batch_size + 1))
)
ishape[0] = np.random.choice(list(range(1, max_batch_size + 1)))
tensor = np.random.random(ishape)
tensor = tensor * 255
else:
@ -498,17 +496,12 @@ def get_2d_pose(model, imgs, num_joints=17):
new_poses = []
for i in range(len(imgs)):
img = imgs[i]
poses = []
dets = model.predict(img)
for pose in dets:
pose = np.asarray(pose)
poses.append(pose)
if len(poses) == 0:
poses.append(np.zeros([num_joints, 3]))
poses = np.array(poses)
if len(dets) == 0:
poses = np.zeros([1, num_joints, 3], dtype=float)
else:
poses = np.asarray(dets, dtype=float)
new_poses.append(poses)
return new_poses