SkeletonGait

This commit is contained in:
Jingzhe Ma
2024-03-07 21:10:22 +08:00
parent 662ffb25b9
commit 29578ccfd7
12 changed files with 59 additions and 29 deletions
+8 -8
View File
@@ -6,8 +6,8 @@ from glob import glob
def get_args():
parser = argparse.ArgumentParser(description='Symlink silouette data and pose data into the same folder for SkeletonGait++ training.')
parser.add_argument('--heatmapt_data_path', type=str, required=True, help="path of heatmap data, must be the absolute path.")
parser.add_argument('--silouette_data_path', type=str, required=True, help="path of silouette data, must be the absolute path.")
parser.add_argument('--heatmap_data_path', type=str, required=True, help="path of heatmap data, must be the absolute path.")
parser.add_argument('--silhouette_data_path', type=str, required=True, help="path of silouette data, must be the absolute path.")
parser.add_argument('--dataset_pkl_ext_name', type=str, default='.pkl', help="The extent name for .pkl files of silouettes data.")
parser.add_argument('--output_path', type=str, required=True, help="path of output data")
opt = parser.parse_args()
@@ -15,24 +15,24 @@ def get_args():
def main():
opt = get_args()
heatmap_data_path = opt.heatmapt_data_path
silouette_data_path = opt.silouette_data_path
heatmap_data_path = opt.heatmap_data_path
silhouette_data_path = opt.silhouette_data_path
if not os.path.exists(heatmap_data_path):
print(f"heatmap data path {heatmap_data_path} does not exist.")
sys.exit(1)
if not os.path.exists(silouette_data_path):
print(f"silouette data path {silouette_data_path} does not exist.")
if not os.path.exists(silhouette_data_path):
print(f"silouette data path {silhouette_data_path} does not exist.")
sys.exit(1)
all_heatmap_files = sorted(glob(os.path.join(heatmap_data_path, "*/*/*/*.pkl")))
all_silouette_files = sorted(glob(os.path.join(silouette_data_path, f"*/*/*/*{opt.dataset_pkl_ext_name}")))
all_silouette_files = sorted(glob(os.path.join(silhouette_data_path, f"*/*/*/*{opt.dataset_pkl_ext_name}")))
# print(len(all_heatmap_files), len(all_silouette_files))
# assert len(all_heatmap_files) == len(all_silouette_files), "The number of heatmap files and silouette files are not equal."
if len(all_heatmap_files) >= len(all_silouette_files):
for heatmap_file in tqdm(all_heatmap_files):
tmp_list = heatmap_file.split('/')
sil_folder = os.path.join(silouette_data_path, *tmp_list[-4:-1])
sil_folder = os.path.join(silhouette_data_path, *tmp_list[-4:-1])
if not os.path.exists(sil_folder):
print(f"silouette folder {sil_folder} does not exist.")
continue
+2 -2
View File
@@ -647,8 +647,8 @@ class TransferDataset(Dataset):
os.makedirs(save_path_img, exist_ok=True)
# save_heatemapimg_index = random.choice(list(range(heatmap_img.shape[0])))
for save_heatemapimg_index in range(heatmap_img.shape[0]):
cv2.imwrite(os.path.join(save_path_img, f'pose_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 0])
cv2.imwrite(os.path.join(save_path_img, f'bone_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 1])
cv2.imwrite(os.path.join(save_path_img, f'bone_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 0])
cv2.imwrite(os.path.join(save_path_img, f'pose_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 1])
pickle.dump(heatmap_img, open(os.path.join(save_path_pkl, tmp_split[-1]), 'wb'))
return None