OpenGait release(pre-beta version).

This commit is contained in:
梁峻豪
2021-10-18 14:30:21 +08:00
commit 57ee4a448e
41 changed files with 3772 additions and 0 deletions
+201
View File
@@ -0,0 +1,201 @@
# modified from https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
import os
import cv2
import numpy as np
from warnings import warn
from time import sleep
import argparse
import pickle
from multiprocessing import Pool
from multiprocessing import TimeoutError as MP_TimeoutError
START = "START"
FINISH = "FINISH"
WARNING = "WARNING"
FAIL = "FAIL"
def boolean_string(s):
if s.upper() not in {'FALSE', 'TRUE'}:
raise ValueError('Not a valid boolean string')
return s.upper() == 'TRUE'
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default='', type=str,
help='Root path of raw dataset.')
parser.add_argument('--output_path', default='', type=str,
help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
help='Log file path. Default: ./pretreatment.log')
parser.add_argument('--log', default=False, type=boolean_string,
help='If set as True, all logs will be saved. '
'Otherwise, only warnings and errors will be saved.'
'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
help='How many subprocesses to use for data pretreatment. '
'Default: 1')
parser.add_argument('--img_size', default=64, type=int,
help='image size')
opt = parser.parse_args()
INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num
T_H = opt.img_size
T_W = opt.img_size
def log2str(pid, comment, logs):
str_log = ''
if type(logs) is str:
logs = [logs]
for log in logs:
str_log += "# JOB %d : --%s-- %s\n" % (
pid, comment, log)
return str_log
def log_print(pid, comment, logs):
str_log = log2str(pid, comment, logs)
if comment in [WARNING, FAIL]:
with open(LOG_PATH, 'a') as log_f:
log_f.write(str_log)
if comment in [START, FINISH]:
if pid % 500 != 0:
return
print(str_log, end='')
def cut_img(img, seq_info, frame_name, pid):
# A silhouette contains too little white pixels
# might be not valid for identification.
if img.sum() <= 10000:
message = 'seq:%s, frame:%s, no data, %d.' % (
'-'.join(seq_info), frame_name, img.sum())
warn(message)
log_print(pid, WARNING, message)
return None
# Get the top and bottom point
y = img.sum(axis=1)
y_top = (y != 0).argmax(axis=0)
y_btm = (y != 0).cumsum(axis=0).argmax(axis=0)
img = img[y_top:y_btm + 1, :]
# As the height of a person is larger than the width,
# use the height to calculate resize ratio.
_r = img.shape[1] / img.shape[0]
_t_w = int(T_H * _r)
img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC)
# Get the median of x axis and regard it as the x center of the person.
sum_point = img.sum()
sum_column = img.sum(axis=0).cumsum()
x_center = -1
for i in range(sum_column.size):
if sum_column[i] > sum_point / 2:
x_center = i
break
if x_center < 0:
message = 'seq:%s, frame:%s, no center.' % (
'-'.join(seq_info), frame_name)
warn(message)
log_print(pid, WARNING, message)
return None
h_T_W = int(T_W / 2)
left = x_center - h_T_W
right = x_center + h_T_W
if left <= 0 or right >= img.shape[1]:
left += h_T_W
right += h_T_W
_ = np.zeros((img.shape[0], h_T_W))
img = np.concatenate([_, img, _], axis=1)
img = img[:, left:right]
return img.astype('uint8')
def cut_pickle(seq_info, pid):
seq_name = '-'.join(seq_info)
log_print(pid, START, seq_name)
seq_path = os.path.join(INPUT_PATH, *seq_info)
out_dir = os.path.join(OUTPUT_PATH, *seq_info)
frame_list = os.listdir(seq_path)
frame_list.sort()
count_frame = 0
all_imgs = []
view = seq_info[-1]
for _frame_name in frame_list:
frame_path = os.path.join(seq_path, _frame_name)
img = cv2.imread(frame_path)[:, :, 0]
img = cut_img(img, seq_info, _frame_name, pid)
if img is not None:
# Save the cut img
all_imgs.append(img)
count_frame += 1
all_imgs = np.asarray(all_imgs)
if count_frame > 0:
os.makedirs(out_dir)
all_imgs_pkl = os.path.join(out_dir, '{}.pkl'.format(view))
pickle.dump(all_imgs, open(all_imgs_pkl, 'wb'))
# Warn if the sequence contains less than 5 frames
if count_frame < 5:
message = 'seq:%s, less than 5 valid data.' % (
'-'.join(seq_info))
warn(message)
log_print(pid, WARNING, message)
log_print(pid, FINISH,
'Contain %d valid frames. Saved to %s.'
% (count_frame, out_dir))
pool = Pool(WORKERS)
results = list()
pid = 0
print('Pretreatment Start.\n'
'Input path: %s\n'
'Output path: %s\n'
'Log file: %s\n'
'Worker num: %d' % (
INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))
id_list = os.listdir(INPUT_PATH)
id_list.sort()
# Walk the input path
for _id in id_list:
seq_type = os.listdir(os.path.join(INPUT_PATH, _id))
seq_type.sort()
for _seq_type in seq_type:
view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type))
view.sort()
for _view in view:
seq_info = [_id, _seq_type, _view]
out_dir = os.path.join(OUTPUT_PATH, *seq_info)
# os.makedirs(out_dir)
results.append(
pool.apply_async(
cut_pickle,
args=(seq_info, pid)))
sleep(0.02)
pid += 1
pool.close()
unfinish = 1
while unfinish > 0:
unfinish = 0
for i, res in enumerate(results):
try:
res.get(timeout=0.1)
except Exception as e:
if type(e) == MP_TimeoutError:
unfinish += 1
continue
else:
print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n',
i, type(e))
raise e
pool.join()