import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align


# ==================================================================================================

base_path = "/RapidPoseTriangulation/extras/mmdeploy/exports/"

det_target_size = (320, 320)
pose_target_size = (384, 288)


# ==================================================================================================

class Letterbox(nn.Module):
    def __init__(self, target_size, fill_value=128):
        """Resize and pad image while keeping aspect ratio"""
        super(Letterbox, self).__init__()

        self.target_size = target_size
        self.fill_value = fill_value

    def calc_params(self, ishape):
        # ishape: [batch, H, W, C]
        ih, iw = ishape[1], ishape[2]
        th, tw = self.target_size

        # Uniform scale that fits the image inside the target while keeping the aspect ratio
        scale = torch.min(tw / iw, th / ih)
        nw = torch.round(iw * scale)
        nh = torch.round(ih * scale)

        # Split the remaining space into (nearly) symmetric padding
        pad_w = tw - nw
        pad_h = th - nh
        pad_left = pad_w // 2
        pad_top = pad_h // 2
        pad_right = pad_w - pad_left
        pad_bottom = pad_h - pad_top
        paddings = (pad_left, pad_right, pad_top, pad_bottom)

        return paddings, scale, (nw, nh)

    def forward(self, img):
        paddings, _, (nw, nh) = self.calc_params(img.shape)

        # Resize the image
        img = img.to(torch.float32)
        img = img.permute(0, 3, 1, 2)
        img = F.interpolate(
            img,
            size=(nh, nw),
            mode="bilinear",
            align_corners=False,
        )
        img = img.permute(0, 2, 3, 1)
        img = img.round()

        # Pad the image
        img = F.pad(
            img.permute(0, 3, 1, 2),
            pad=paddings,
            mode="constant",
            value=self.fill_value,
        )
        img = img.permute(0, 2, 3, 1)

        return img

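
# Illustrative sketch (not part of the exported graph): the same letterbox geometry as
# Letterbox.calc_params, written with plain Python numbers so it can be checked by hand. For a
# 720x1280 input mapped to 320x320, scale = min(320/1280, 320/720) = 0.25, the resized image is
# 320x180, and 70 px of padding land on the top and bottom.
def _letterbox_geometry_sketch(ih, iw, th, tw):
    scale = min(tw / iw, th / ih)
    nw, nh = round(iw * scale), round(ih * scale)
    pad_w, pad_h = tw - nw, th - nh
    pad_left, pad_top = pad_w // 2, pad_h // 2
    return (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top), scale, (nw, nh)


# e.g. _letterbox_geometry_sketch(720, 1280, 320, 320) -> ((0, 0, 70, 70), 0.25, (320, 180))
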
# ==================================================================================================


class BoxCrop(nn.Module):
    def __init__(self, target_size):
        """Crop bounding box from image"""
        super(BoxCrop, self).__init__()

        self.target_size = target_size
        self.padding_scale = 1.25

    def calc_params(self, bbox):
        start_x, start_y, end_x, end_y = bbox[0, 0], bbox[0, 1], bbox[0, 2], bbox[0, 3]
        target_h, target_w = self.target_size

        # Calculate original bounding box width, height and center
        bbox_w = end_x - start_x
        bbox_h = end_y - start_y
        center_x = (start_x + end_x) / 2.0
        center_y = (start_y + end_y) / 2.0

        # Calculate the aspect ratios
        bbox_aspect = bbox_w / bbox_h
        target_aspect = target_w / target_h

        # Adjust the scaled bounding box to match the target aspect ratio
        if bbox_aspect > target_aspect:
            adjusted_h = bbox_w / target_aspect
            adjusted_w = bbox_w
        else:
            adjusted_w = bbox_h * target_aspect
            adjusted_h = bbox_h

        # Scale the bounding box by the padding_scale
        scaled_bbox_w = adjusted_w * self.padding_scale
        scaled_bbox_h = adjusted_h * self.padding_scale

        # Calculate scaled bounding box coordinates
        new_start_x = center_x - scaled_bbox_w / 2.0
        new_start_y = center_y - scaled_bbox_h / 2.0
        new_end_x = center_x + scaled_bbox_w / 2.0
        new_end_y = center_y + scaled_bbox_h / 2.0

        # Define the new box coordinates and the per-axis scale from image pixels to crop pixels
        new_box = torch.stack((new_start_x, new_start_y, new_end_x, new_end_y), dim=0)
        new_box = new_box.unsqueeze(0)
        scale = torch.stack(
            ((target_w / scaled_bbox_w), (target_h / scaled_bbox_h)), dim=0
        )

        return scale, new_box

    def forward(self, img, bbox):
        _, bbox = self.calc_params(bbox)

        # Prepend the batch index expected by roi_align
        batch_indices = torch.zeros(bbox.shape[0], 1)
        rois = torch.cat([batch_indices, bbox], dim=1)

        # Resize and crop in a single roi_align call
        img = img.to(torch.float32)
        img = img.permute(0, 3, 1, 2)
        img = roi_align(
            img,
            rois,
            output_size=self.target_size,
            spatial_scale=1.0,
            sampling_ratio=0,
        )
        img = img.permute(0, 2, 3, 1)
        img = img.round()

        return img

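
# Illustrative sketch (not part of the exported graph): the box expansion done by
# BoxCrop.calc_params with plain Python numbers. The box is first widened or heightened to the
# target aspect ratio, then scaled by padding_scale around its center. Note that the equivalent
# `if` in calc_params compares traced tensors, so trace-based ONNX export will typically bake in
# the branch chosen by the dummy bbox.
def _boxcrop_geometry_sketch(bbox, target_hw, padding_scale=1.25):
    (sx, sy, ex, ey), (th, tw) = bbox, target_hw
    w, h = ex - sx, ey - sy
    cx, cy = (sx + ex) / 2.0, (sy + ey) / 2.0
    if w / h > tw / th:
        h = w / (tw / th)
    else:
        w = h * (tw / th)
    w, h = w * padding_scale, h * padding_scale
    scale = (tw / w, th / h)
    return scale, (cx - w / 2.0, cy - h / 2.0, cx + w / 2.0, cy + h / 2.0)


# e.g. _boxcrop_geometry_sketch((352, 339, 518, 594), (384, 288)) gives
# scale ~ (1.2047, 1.2047) and an expanded box of ~ (315.5, 307.1, 554.5, 625.9)
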
# ==================================================================================================


class DetPreprocess(nn.Module):
    def __init__(self, target_size, fill_value=114):
        super(DetPreprocess, self).__init__()
        self.letterbox = Letterbox(target_size, fill_value)

    def forward(self, img):
        # img: torch.Tensor of shape [batch, H, W, C], dtype=torch.uint8
        img = self.letterbox(img)
        return img

# ==================================================================================================


class DetPostprocess(nn.Module):
    def __init__(self, target_size):
        super(DetPostprocess, self).__init__()

        self.target_size = target_size
        self.letterbox = Letterbox(target_size)

    def forward(self, img, boxes):
        paddings, scale, _ = self.letterbox.calc_params(img.shape)

        # Remove the letterbox padding offset from the box coordinates
        boxes = boxes.float()
        boxes[:, :, 0] -= paddings[0]
        boxes[:, :, 2] -= paddings[0]
        boxes[:, :, 1] -= paddings[2]
        boxes[:, :, 3] -= paddings[2]

        # Clamp to the resized (unpadded) image area
        zero = torch.tensor(0)
        boxes = torch.max(boxes, zero)

        th, tw = self.target_size
        pad_w = paddings[0] + paddings[1]
        pad_h = paddings[2] + paddings[3]
        max_w = tw - pad_w - 1
        max_h = th - pad_h - 1
        b0 = boxes[:, :, 0]
        b1 = boxes[:, :, 1]
        b2 = boxes[:, :, 2]
        b3 = boxes[:, :, 3]
        b0 = torch.min(b0, max_w)
        b1 = torch.min(b1, max_h)
        b2 = torch.min(b2, max_w)
        b3 = torch.min(b3, max_h)
        boxes[:, :, 0] = b0
        boxes[:, :, 1] = b1
        boxes[:, :, 2] = b2
        boxes[:, :, 3] = b3

        # Undo the letterbox scale to get original-image coordinates
        boxes[:, :, 0:4] /= scale
        return boxes

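
# Illustrative sketch (not part of the exported graph): DetPostprocess maps detector boxes from
# letterboxed coordinates back to the original image by removing the padding offset, clamping to
# the resized area, and dividing by the letterbox scale. With the 720x1280 -> 320x320 example
# above (scale 0.25, top padding 70), a box (10, 80, 100, 200) in letterbox coordinates becomes
# (10, 10, 100, 130) after the shift and (40, 40, 400, 520) in original-image pixels.
def _unletterbox_box_sketch(box, paddings, scale):
    # Clamping to the resized area is omitted here for brevity
    pad_left, _, pad_top, _ = paddings
    x0, y0, x1, y1 = box
    return (
        (x0 - pad_left) / scale,
        (y0 - pad_top) / scale,
        (x1 - pad_left) / scale,
        (y1 - pad_top) / scale,
    )
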
# ==================================================================================================


class PosePreprocess(nn.Module):
    def __init__(self, target_size, fill_value=114):
        super(PosePreprocess, self).__init__()
        self.boxcrop = BoxCrop(target_size)

    def forward(self, img, bbox):
        # img: torch.Tensor of shape [1, H, W, C], dtype=torch.uint8
        # bbox: torch.Tensor of shape [1, 4], dtype=torch.float32
        img = self.boxcrop(img, bbox)
        return img

# ==================================================================================================


class PosePostprocess(nn.Module):
    def __init__(self, target_size):
        super(PosePostprocess, self).__init__()
        self.boxcrop = BoxCrop(target_size)
        self.target_size = target_size

    def forward(self, img, bbox, keypoints):
        scale, bbox = self.boxcrop.calc_params(bbox)

        # Map keypoints from crop coordinates back to the original image
        kp = keypoints.float()
        kp[:, :, 0:2] /= scale
        kp[:, :, 0] += bbox[0, 0]
        kp[:, :, 1] += bbox[0, 1]

        # Clamp to the original image bounds
        zero = torch.tensor(0)
        kp = torch.max(kp, zero)

        max_w = img.shape[2] - 1
        max_h = img.shape[1] - 1
        k0 = kp[:, :, 0]
        k1 = kp[:, :, 1]
        k0 = torch.min(k0, max_w)
        k1 = torch.min(k1, max_h)
        kp[:, :, 0] = k0
        kp[:, :, 1] = k1

        return kp

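
# Illustrative sketch (not part of the exported graph): PosePostprocess inverts the BoxCrop
# geometry, mapping a keypoint from crop pixels back to the original image. With the dummy bbox
# from main() (expanded box origin ~ (315.5, 307.1), scale ~ 1.2047), the crop center (144, 192)
# maps back to ~ (435.0, 466.5), which is the center of the original bbox.
def _uncrop_keypoint_sketch(kp_xy, crop_origin_xy, scale_xy):
    return (
        kp_xy[0] / scale_xy[0] + crop_origin_xy[0],
        kp_xy[1] / scale_xy[1] + crop_origin_xy[1],
    )
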
# ==================================================================================================


def main():

    img_path = "/RapidPoseTriangulation/scripts/../data/h1/54138969-img_003201.jpg"
    image = cv2.imread(img_path, 3)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Initialize the DetPreprocess module
    preprocess_model = DetPreprocess(target_size=det_target_size)
    det_dummy_input_a0 = torch.from_numpy(image).unsqueeze(0)

    # Export to ONNX
    torch.onnx.export(
        preprocess_model,
        det_dummy_input_a0,
        base_path + "det_preprocess.onnx",
        opset_version=11,
        input_names=["input_image"],
        output_names=["preprocessed_image"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "preprocessed_image": {0: "batch_size"},
        },
    )

    # Initialize the DetPostprocess module
    postprocess_model = DetPostprocess(target_size=det_target_size)
    det_dummy_input_b0 = torch.from_numpy(image).unsqueeze(0)
    det_dummy_input_b1 = torch.rand(1, 10, 5)

    # Export to ONNX
    torch.onnx.export(
        postprocess_model,
        (det_dummy_input_b0, det_dummy_input_b1),
        base_path + "det_postprocess.onnx",
        opset_version=11,
        input_names=["input_image", "boxes"],
        output_names=["output_boxes"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "boxes": {0: "batch_size", 1: "num_boxes"},
            "output_boxes": {0: "batch_size", 1: "num_boxes"},
        },
    )

    # Initialize the PosePreprocess module
    preprocess_model = PosePreprocess(target_size=pose_target_size)
    det_dummy_input_c0 = torch.from_numpy(image).unsqueeze(0)
    det_dummy_input_c1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32)

    # Export to ONNX
    torch.onnx.export(
        preprocess_model,
        (det_dummy_input_c0, det_dummy_input_c1),
        base_path + "pose_preprocess.onnx",
        opset_version=11,
        input_names=["input_image", "bbox"],
        output_names=["preprocessed_image"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "preprocessed_image": {0: "batch_size"},
        },
    )

    # Initialize the PosePostprocess module
    postprocess_model = PosePostprocess(target_size=pose_target_size)
    det_dummy_input_d0 = torch.from_numpy(image).unsqueeze(0)
    det_dummy_input_d1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32)
    det_dummy_input_d2 = torch.rand(1, 17, 2)

    # Export to ONNX
    torch.onnx.export(
        postprocess_model,
        (det_dummy_input_d0, det_dummy_input_d1, det_dummy_input_d2),
        base_path + "pose_postprocess.onnx",
        opset_version=11,
        input_names=["input_image", "bbox", "keypoints"],
        output_names=["output_keypoints"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "output_keypoints": {0: "batch_size"},
        },
    )

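
# Illustrative sketch (not run by this script): loading one of the exported graphs with
# onnxruntime, assuming onnxruntime is installed and main() has already written the .onnx files.
# The detector preprocess expects a uint8 [batch, H, W, C] image, matching the dummy input above.
def _run_det_preprocess_sketch(image_rgb):
    import onnxruntime as ort

    session = ort.InferenceSession(
        base_path + "det_preprocess.onnx", providers=["CPUExecutionProvider"]
    )
    (letterboxed,) = session.run(None, {"input_image": image_rgb[None]})
    return letterboxed  # float32, shape [1, 320, 320, 3]
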
# ==================================================================================================


if __name__ == "__main__":
    main()