import os

import cv2
import numpy as np

from easypose import model
from easypose.model import detection
from easypose.model import pose
from .download import get_url, get_model_path, download
from .consts import AvailablePoseModels, AvailableDetModels
from .common import Person, region_of_interest, restore_keypoints


def get_pose_model(pose_model_path, pose_model_decoder, device, warmup):
    if pose_model_decoder == 'Dark':
        pose_model = pose.Heatmap(pose_model_path, dark=True, device=device, warmup=warmup)
    else:
        pose_model = getattr(pose, pose_model_decoder)(pose_model_path, device=device, warmup=warmup)
    return pose_model


def get_det_model(det_model_path, model_type, conf_thre, iou_thre, device, warmup):
    det_model = getattr(detection, model_type)(det_model_path, conf_thre, iou_thre, device, warmup)
    return det_model
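

# A minimal sketch of how the factory helpers above are meant to be used. 'Dark'
# is the one decoder the code special-cases (pose.Heatmap with dark=True); any
# other decoder string is resolved via getattr to a class of the same name in
# easypose.model.pose. 'RTMDet' matches the fallback used in TopDown below; the
# .onnx paths are placeholders, not files shipped with the library.
def _example_factories(device='CUDA', warmup=0):
    pose_model = get_pose_model('pose.onnx', 'Dark', device, warmup)           # -> pose.Heatmap(..., dark=True)
    det_model = get_det_model('det.onnx', 'RTMDet', 0.6, 0.6, device, warmup)  # -> detection.RTMDet(...)
    return pose_model, det_model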


def region_of_interest_warped(
    image: np.ndarray,
    box: np.ndarray,
    target_size=(288, 384),
    padding_scale: float = 1.25,
):
    start_x, start_y, end_x, end_y = box
    target_w, target_h = target_size

    # Calculate original bounding box width and height
    bbox_w = end_x - start_x
    bbox_h = end_y - start_y

    if bbox_w <= 0 or bbox_h <= 0:
        raise ValueError("Invalid bounding box!")

    # Calculate the aspect ratios
    bbox_aspect = bbox_w / bbox_h
    target_aspect = target_w / target_h

    # Adjust the bounding box to match the target aspect ratio
    if bbox_aspect > target_aspect:
        adjusted_h = bbox_w / target_aspect
        adjusted_w = bbox_w
    else:
        adjusted_w = bbox_h * target_aspect
        adjusted_h = bbox_h

    # Scale the bounding box by the padding_scale
    scaled_bbox_w = adjusted_w * padding_scale
    scaled_bbox_h = adjusted_h * padding_scale

    # Calculate the center of the original box
    center_x = (start_x + end_x) / 2.0
    center_y = (start_y + end_y) / 2.0

    # Calculate scaled bounding box coordinates
    new_start_x = center_x - scaled_bbox_w / 2.0
    new_start_y = center_y - scaled_bbox_h / 2.0
    new_end_x = center_x + scaled_bbox_w / 2.0
    new_end_y = center_y + scaled_bbox_h / 2.0

    # Define the new box coordinates
    new_box = np.array(
        [new_start_x, new_start_y, new_end_x, new_end_y], dtype=np.float32
    )
    scale = target_w / scaled_bbox_w

    # Define source and destination points for the affine transformation
    # See: /mmpose/structures/bbox/transforms.py
    src_pts = np.array(
        [
            [center_x, center_y],
            [new_start_x, center_y],
            [new_start_x, center_y + (center_x - new_start_x)],
        ],
        dtype=np.float32,
    )
    dst_pts = np.array(
        [
            [target_w * 0.5, target_h * 0.5],
            [0, target_h * 0.5],
            [0, target_h * 0.5 + (target_w * 0.5 - 0)],
        ],
        dtype=np.float32,
    )

    # Compute the affine transformation matrix
    M = cv2.getAffineTransform(src_pts, dst_pts)

    # Apply the affine transformation with border filling
    extracted_region = cv2.warpAffine(
        image,
        M,
        target_size,
        flags=cv2.INTER_LINEAR,
    )

    return extracted_region, new_box, scale
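

# Illustrative sketch (not part of the public API): how a detection box is turned
# into a fixed-size, aspect-ratio-preserving crop. The image and box values are
# made up for demonstration; target_size is (width, height) for cv2.warpAffine,
# so the crop comes back as a (384, 288, 3) array.
def _example_region_of_interest_warped():
    image = np.zeros((720, 1280, 3), dtype=np.uint8)        # dummy BGR frame
    box = np.array([300, 100, 500, 600], dtype=np.float32)  # x0, y0, x1, y1
    crop, padded_box, scale = region_of_interest_warped(image, box)
    print(crop.shape, padded_box, scale)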


class TopDown:
    def __init__(self,
                 pose_model_name,
                 pose_model_decoder,
                 det_model_name,
                 conf_threshold=0.6,
                 iou_threshold=0.6,
                 device='CUDA',
                 warmup=30):
        if not pose_model_name.endswith('.onnx') and pose_model_name not in AvailablePoseModels.POSE_MODELS:
            raise ValueError(
                'The {} human pose estimation model is not in the model repository.'.format(pose_model_name))
        if not pose_model_name.endswith('.onnx') and pose_model_decoder not in AvailablePoseModels.POSE_MODELS[pose_model_name]:
            raise ValueError(
                'No {} decoding head for the {} model was found in the model repository.'.format(pose_model_decoder,
                                                                                                  pose_model_name))
        if not det_model_name.endswith('.onnx') and det_model_name not in AvailableDetModels.DET_MODELS:
            raise ValueError(
                'The {} detection model is not in the model repository.'.format(det_model_name))

        if not pose_model_name.endswith('.onnx'):
            pose_model_dir = get_model_path(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                                            detection_model=False)
            pose_model_path = os.path.join(pose_model_dir,
                                           AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder])
        else:
            pose_model_path = pose_model_name

        if os.path.exists(pose_model_path):
            try:
                self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)
            except Exception:
                url = get_url(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                              detection_model=False)
                download(url, pose_model_dir)
                self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)
        else:
            url = get_url(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                          detection_model=False)
            download(url, pose_model_dir)
            self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)

        if not det_model_name.endswith('.onnx'):
            det_model_dir = get_model_path(AvailableDetModels.DET_MODELS[det_model_name]['file_name'],
                                           detection_model=True)
            det_model_path = os.path.join(det_model_dir,
                                          AvailableDetModels.DET_MODELS[det_model_name]['file_name'])
            det_model_type = AvailableDetModels.DET_MODELS[det_model_name]['model_type']
        else:
            det_model_path = det_model_name
            if "rtmdet" in det_model_name:
                det_model_type = 'RTMDet'

        if os.path.exists(det_model_path):
            try:
                self.det_model = get_det_model(det_model_path,
                                               det_model_type,
                                               conf_threshold,
                                               iou_threshold,
                                               device,
                                               warmup)
            except Exception:
                url = get_url(AvailableDetModels.DET_MODELS[det_model_name]['file_name'],
                              detection_model=True)
                download(url, det_model_dir)
                self.det_model = get_det_model(det_model_path,
                                               det_model_type,
                                               conf_threshold,
                                               iou_threshold,
                                               device,
                                               warmup)
        else:
            url = get_url(AvailableDetModels.DET_MODELS[det_model_name]['file_name'],
                          detection_model=True)
            download(url, det_model_dir)
            self.det_model = get_det_model(det_model_path,
                                           det_model_type,
                                           conf_threshold,
                                           iou_threshold,
                                           device,
                                           warmup)

    def predict(self, image):
        boxes = self.det_model(image)
        results = []
        for i in range(boxes.shape[0]):
            p = Person()
            p.box = boxes[i]
            region, p.box, _ = region_of_interest_warped(image, p.box)
            kp = self.pose_model(region)

            # Map keypoints from the warped crop back to original-image coordinates.
            # See: /mmpose/models/pose_estimators/topdown.py - add_pred_to_datasample()
            th, tw = region.shape[:2]
            bw, bh = [p.box[2] - p.box[0], p.box[3] - p.box[1]]
            kp[:, :2] = kp[:, :2] / np.array([tw, th]) * np.array([bw, bh])
            kp[:, :2] += np.array([p.box[0] + bw / 2, p.box[1] + bh / 2])
            kp[:, :2] -= 0.5 * np.array([bw, bh])

            p.keypoints = kp
            results.append(p)
        return results
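

# Illustrative sketch of typical multi-person use of TopDown. The model, decoder
# and detector names below are placeholders, not guaranteed repository keys; use
# entries that actually exist in AvailablePoseModels.POSE_MODELS and
# AvailableDetModels.DET_MODELS (or paths to your own .onnx files).
def _example_topdown():
    estimator = TopDown(pose_model_name='rtmpose-l',   # placeholder repository key
                        pose_model_decoder='SimCC',    # placeholder decoder name
                        det_model_name='rtmdet-m',     # placeholder detector key
                        device='CUDA')
    frame = cv2.imread('people.jpg')                   # placeholder image path
    for person in estimator.predict(frame):
        print(person.box, person.keypoints)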


class Pose:
    def __init__(self,
                 pose_model_name,
                 pose_model_decoder,
                 device='CUDA',
                 warmup=30):
        if pose_model_name not in AvailablePoseModels.POSE_MODELS:
            raise ValueError(
                'The {} human pose estimation model is not in the model repository.'.format(pose_model_name))
        if pose_model_decoder not in AvailablePoseModels.POSE_MODELS[pose_model_name]:
            raise ValueError(
                'No {} decoding head for the {} model was found in the model repository.'.format(pose_model_decoder,
                                                                                                  pose_model_name))

        pose_model_dir = get_model_path(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                                        detection_model=False)
        pose_model_path = os.path.join(pose_model_dir,
                                       AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder])

        if os.path.exists(pose_model_path):
            try:
                self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)
            except Exception:
                url = get_url(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                              detection_model=False)
                download(url, pose_model_dir)
                self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)
        else:
            url = get_url(AvailablePoseModels.POSE_MODELS[pose_model_name][pose_model_decoder],
                          detection_model=False)
            download(url, pose_model_dir)
            self.pose_model = get_pose_model(pose_model_path, pose_model_decoder, device, warmup)

    def predict(self, image):
        p = Person()
        box = np.array([0, 0, image.shape[3], image.shape[2], 1, 0])
        p.box = box
        p.keypoints = self.pose_model(image)
        return p
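

# Illustrative sketch of single-person use without a detector. The names are
# placeholders. Note that predict() reads image.shape[2] / image.shape[3] when it
# builds the full-frame box, so a batched NCHW array is assumed here; adapt the
# input layout to what the selected pose model actually expects.
def _example_pose():
    estimator = Pose(pose_model_name='rtmpose-l',   # placeholder repository key
                     pose_model_decoder='SimCC',    # placeholder decoder name
                     device='CUDA')
    image = np.zeros((1, 3, 384, 288), dtype=np.float32)
    person = estimator.predict(image)
    print(person.box, person.keypoints)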


class CustomTopDown:
    def __init__(self,
                 pose_model,
                 det_model,
                 pose_decoder=None,
                 device='CUDA',
                 iou_threshold=0.6,
                 conf_threshold=0.6,
                 warmup=30):
        if isinstance(pose_model, model.BaseModel):
            self.pose_model = pose_model
        elif isinstance(pose_model, str):
            if pose_model not in AvailablePoseModels.POSE_MODELS:
                raise ValueError(
                    'The {} human pose estimation model is not in the model repository.'.format(pose_model))
            if pose_decoder not in AvailablePoseModels.POSE_MODELS[pose_model]:
                raise ValueError(
                    'No {} decoding head for the {} model was found in the model repository.'.format(pose_decoder,
                                                                                                      pose_model))

            pose_model_dir = get_model_path(AvailablePoseModels.POSE_MODELS[pose_model][pose_decoder],
                                            detection_model=False)
            pose_model_path = os.path.join(pose_model_dir,
                                           AvailablePoseModels.POSE_MODELS[pose_model][pose_decoder])

            if os.path.exists(pose_model_path):
                try:
                    self.pose_model = get_pose_model(pose_model_path, pose_decoder, device, warmup)
                except Exception:
                    url = get_url(AvailablePoseModels.POSE_MODELS[pose_model][pose_decoder],
                                  detection_model=False)
                    download(url, pose_model_dir)
                    self.pose_model = get_pose_model(pose_model_path, pose_decoder, device, warmup)
            else:
                url = get_url(AvailablePoseModels.POSE_MODELS[pose_model][pose_decoder],
                              detection_model=False)
                download(url, pose_model_dir)
                self.pose_model = get_pose_model(pose_model_path, pose_decoder, device, warmup)
        else:
            raise TypeError("Invalid type for pose model. Please write a custom model based on 'BaseModel'.")

        if isinstance(det_model, model.BaseModel):
            self.det_model = det_model
        elif isinstance(det_model, str):
            if det_model not in AvailableDetModels.DET_MODELS:
                raise ValueError(
                    'The {} detection model is not in the model repository.'.format(det_model))

            det_model_dir = get_model_path(AvailableDetModels.DET_MODELS[det_model]['file_name'],
                                           detection_model=True)
            det_model_path = os.path.join(det_model_dir,
                                          AvailableDetModels.DET_MODELS[det_model]['file_name'])
            det_model_type = AvailableDetModels.DET_MODELS[det_model]['model_type']
            if os.path.exists(det_model_path):
                try:
                    self.det_model = get_det_model(det_model_path,
                                                   det_model_type,
                                                   conf_threshold,
                                                   iou_threshold,
                                                   device,
                                                   warmup)
                except Exception:
                    url = get_url(AvailableDetModels.DET_MODELS[det_model]['file_name'],
                                  detection_model=True)
                    download(url, det_model_dir)
                    self.det_model = get_det_model(det_model_path,
                                                   det_model_type,
                                                   conf_threshold,
                                                   iou_threshold,
                                                   device,
                                                   warmup)
            else:
                url = get_url(AvailableDetModels.DET_MODELS[det_model]['file_name'],
                              detection_model=True)
                download(url, det_model_dir)
                self.det_model = get_det_model(det_model_path,
                                               det_model_type,
                                               conf_threshold,
                                               iou_threshold,
                                               device,
                                               warmup)
        else:
            raise TypeError("Invalid type for detection model. Please write a custom model based on 'BaseModel'.")

    def predict(self, image):
        boxes = self.det_model(image)
        results = []
        for i in range(boxes.shape[0]):
            p = Person()
            p.box = boxes[i]
            region = region_of_interest(image, p.box)
            kp = self.pose_model(region)
            p.keypoints = restore_keypoints(p.box, kp)
            results.append(p)
        return results
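

# Illustrative sketch of plugging user-supplied models into CustomTopDown. Both
# arguments are assumed to be easypose model.BaseModel instances that are callable
# on an image (detector -> boxes array, pose model -> keypoints array); check
# easypose.model.BaseModel for the actual interface. Strings fall back to the
# repository lookup shown in __init__ above.
def _example_custom_topdown(custom_pose_model, custom_det_model):
    estimator = CustomTopDown(pose_model=custom_pose_model,
                              det_model=custom_det_model)
    frame = cv2.imread('people.jpg')   # placeholder image path
    return estimator.predict(frame)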


class CustomSinglePose:
    def __init__(self, pose_model):
        if isinstance(pose_model, model.BaseModel):
            self.pose_model = pose_model
        else:
            raise TypeError("Invalid type for pose model. Please write a custom model based on 'BaseModel'.")

    def predict(self, image):
        p = Person()
        box = np.array([0, 0, image.shape[3], image.shape[2], 1, 0])
        p.box = box
        p.keypoints = self.pose_model(image)
        return p
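

# Illustrative sketch: CustomSinglePose skips detection and runs the supplied
# BaseModel on the whole input. As in Pose.predict, the shape[2] / shape[3]
# indexing suggests a batched NCHW input; the dummy array below mirrors that
# assumption and should be adapted to the wrapped model.
def _example_custom_single_pose(custom_pose_model):
    estimator = CustomSinglePose(custom_pose_model)
    image = np.zeros((1, 3, 384, 288), dtype=np.float32)
    return estimator.predict(image)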