Moved pose pre/post-processing into the ONNX graph.
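For orientation, a minimal sketch of how the merged graph could be driven after this change, assuming onnxruntime is installed and the pose model was written with the "_extra-steps.onnx" suffix produced by add_steps_to_onnx below. The path, input order, and dummy shapes here are illustrative assumptions, not part of the commit:

import numpy as np
import onnxruntime as ort

# Hypothetical path; add_steps_to_onnx writes "<model>_extra-steps.onnx".
sess = ort.InferenceSession("pose_extra-steps.onnx")

# Raw inputs go straight in, since pre/post-processing now lives in the graph.
image = np.zeros((1, 480, 640, 3), dtype=np.uint8)   # [batch, H, W, C], uint8
bbox = np.array([[10, 10, 90, 40]], dtype=np.int32)  # [1, 4] person box, int32

feeds = dict(zip([i.name for i in sess.get_inputs()], (image, bbox)))
keypoints = sess.run(None, feeds)[0]  # keypoints in original image coordinates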
@@ -121,14 +121,55 @@ def add_steps_to_onnx(model_path):
 
         # Update nodes from postprocess model to use the input of the main network
         pp2_input_image_name = pp2_model.graph.input[0].name
-        main_input_name = model.graph.input[0].name
+        main_input_image_name = model.graph.input[0].name
         for node in model.graph.node:
             for idx, name in enumerate(node.input):
                 if name == pp2_input_image_name:
-                    node.input[idx] = main_input_name
+                    node.input[idx] = main_input_image_name
         model.graph.input.pop(1)
 
-    # Set input type to int8
+    if "pose" in model_path:
+        # Add preprocess model to main network
+        pp1_model = onnx.load(base_path + "pose_preprocess.onnx")
+        model = compose.add_prefix(model, prefix="main_")
+        pp1_model = compose.add_prefix(pp1_model, prefix="preprocess_")
+        model = compose.merge_models(
+            pp1_model,
+            model,
+            io_map=[
+                (pp1_model.graph.output[0].name, model.graph.input[0].name),
+            ],
+        )
+
+        # Add postprocess model
+        pp2_model = onnx.load(base_path + "pose_postprocess.onnx")
+        pp2_model = compose.add_prefix(pp2_model, prefix="postprocess_")
+        model = compose.merge_models(
+            model,
+            pp2_model,
+            io_map=[
+                (model.graph.output[0].name, pp2_model.graph.input[2].name),
+            ],
+        )
+
+        # Update nodes from postprocess model to use the input of the main network
+        pp2_input_image_name = pp2_model.graph.input[0].name
+        pp2_input_bbox_name = pp2_model.graph.input[1].name
+        main_input_image_name = model.graph.input[0].name
+        main_input_bbox_name = model.graph.input[1].name
+        for node in model.graph.node:
+            for idx, name in enumerate(node.input):
+                if name == pp2_input_image_name:
+                    node.input[idx] = main_input_image_name
+                if name == pp2_input_bbox_name:
+                    node.input[idx] = main_input_bbox_name
+        model.graph.input.pop(2)
+        model.graph.input.pop(2)
+
+        # Set input box type to int32
+        model.graph.input[1].type.tensor_type.elem_type = TensorProto.INT32
+
+    # Set input image type to int8
     model.graph.input[0].type.tensor_type.elem_type = TensorProto.UINT8
 
     path = model_path.replace(".onnx", "_extra-steps.onnx")
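The manual rewiring of node inputs and the two input.pop calls above can leave a dangling or duplicated graph input if an index is off, so validating the merged proto is cheap insurance. A hedged sketch using the standard onnx checker, assuming the model is saved to `path` afterwards (shape inference is optional):

import onnx
from onnx import shape_inference

merged = onnx.load(path)
onnx.checker.check_model(merged)               # raises if the composed graph is malformed
merged = shape_inference.infer_shapes(merged)  # optionally propagate shapes for inspection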
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 
 base_path = "/RapidPoseTriangulation/extras/mmdeploy/exports/"
 det_target_size = (320, 320)
+pose_target_size = (384, 288)
 
 # ==================================================================================================
 
@@ -19,10 +20,37 @@ class Letterbox(nn.Module):
         self.target_size = target_size
         self.fill_value = fill_value
 
-    def calc_params(self, img):
-        ih, iw = img.shape[1:3]
+    def calc_params_and_crop(self, img, bbox=None):
+        ih0, iw0 = img.shape[1:3]
         th, tw = self.target_size
 
+        if bbox is not None:
+            bbox = bbox[0].float()
+            x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+
+            # Slightly increase bbox size
+            factor = 1.25
+            w = x2 - x1
+            h = y2 - y1
+            x1 -= w * (factor - 1) / 2
+            x2 += w * (factor - 1) / 2
+            y1 -= h * (factor - 1) / 2
+            y2 += h * (factor - 1) / 2
+
+            zero = torch.tensor(0)
+            x1 = torch.max(x1, zero).to(torch.int64)
+            y1 = torch.max(y1, zero).to(torch.int64)
+            x2 = torch.min(x2, iw0).to(torch.int64)
+            y2 = torch.min(y2, ih0).to(torch.int64)
+            bbox = torch.stack((x1, y1, x2, y2), dim=0).unsqueeze(0)
+
+            img = img.to(torch.float32)
+            img = img[:, y1:y2, x1:x2, :]
+            ih = y2 - y1
+            iw = x2 - x1
+        else:
+            ih, iw = ih0, iw0
+
         scale = torch.min(tw / iw, th / ih)
         nw = torch.round(iw * scale)
         nh = torch.round(ih * scale)
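To make the bbox expansion concrete: with the dummy box [10, 10, 90, 40] used in the export code further down (illustrative values only), factor 1.25 grows the box by 25% in width and height, half on each side:

# Worked example of the factor-1.25 expansion (eager tensors, illustrative only).
import torch

x1, y1, x2, y2 = torch.tensor([10.0, 10.0, 90.0, 40.0])
factor = 1.25
w, h = x2 - x1, y2 - y1                                        # 80, 30
x1, x2 = x1 - w * (factor - 1) / 2, x2 + w * (factor - 1) / 2  # 0.0, 100.0
y1, y2 = y1 - h * (factor - 1) / 2, y2 + h * (factor - 1) / 2  # 6.25, 43.75
# After clamping to the image and the int64 cast, the crop box becomes [0, 6, 100, 43].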
@@ -35,15 +63,18 @@ class Letterbox(nn.Module):
         pad_bottom = pad_h - pad_top
         paddings = (pad_left, pad_right, pad_top, pad_bottom)
 
-        return paddings, scale, (nw, nh)
+        return img, paddings, scale, (nw, nh), bbox
 
-    def forward(self, img):
-        paddings, _, (nw, nh) = self.calc_params(img)
+    def forward(self, img, bbox=None):
+        img, paddings, _, (nw, nh), _ = self.calc_params_and_crop(img, bbox)
 
         # Resize the image
         img = img.to(torch.float32)
         img = F.interpolate(
-            img.permute(0, 3, 1, 2), size=(nh, nw), mode="bilinear", align_corners=False
+            img.permute(0, 3, 1, 2),
+            size=(nh, nw),
+            mode="bilinear",
+            align_corners=False,
         )
         img = img.permute(0, 2, 3, 1)
         img = img.round()
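The letterbox parameters are a few lines of arithmetic; a worked example with illustrative numbers (the even top/bottom split assumes pad_top = pad_h // 2, which this hunk does not show):

# Letterbox math for a hypothetical 640x480 frame into pose_target_size (384, 288).
th, tw = 384, 288
ih, iw = 480, 640
scale = min(tw / iw, th / ih)                  # min(0.45, 0.8) = 0.45
nw, nh = round(iw * scale), round(ih * scale)  # 288 x 216 of resized content
pad_w, pad_h = tw - nw, th - nh                # 0 and 168 px of fill
pad_top = pad_h // 2                           # 84 (assumed split)
pad_bottom = pad_h - pad_top                   # 84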
@@ -71,7 +102,7 @@ class DetPreprocess(nn.Module):
 
     def forward(self, img):
         # img: torch.Tensor of shape [batch, H, W, C], dtype=torch.uint8
-        img = self.letterbox(img)
+        img = self.letterbox(img, None)
         return img
 
 
@@ -81,36 +112,97 @@ class DetPreprocess(nn.Module):
 class DetPostprocess(nn.Module):
     def __init__(self, target_size):
         super(DetPostprocess, self).__init__()
 
+        self.target_size = target_size
         self.letterbox = Letterbox(target_size)
 
     def forward(self, img, boxes):
-        paddings, scale, _ = self.letterbox.calc_params(img)
+        _, paddings, scale, _, _ = self.letterbox.calc_params_and_crop(img, None)
 
         boxes = boxes.float()
         boxes[:, :, 0] -= paddings[0]
         boxes[:, :, 2] -= paddings[0]
         boxes[:, :, 1] -= paddings[2]
         boxes[:, :, 3] -= paddings[2]
-        boxes[:, :, 0:4] /= scale
 
-        ih, iw = img.shape[1:3]
-        boxes = torch.max(boxes, torch.tensor(0))
+        zero = torch.tensor(0)
+        boxes = torch.max(boxes, zero)
+
+        th, tw = self.target_size
+        pad_w = paddings[0] + paddings[1]
+        pad_h = paddings[2] + paddings[3]
+        max_w = tw - pad_w - 1
+        max_h = th - pad_h - 1
         b0 = boxes[:, :, 0]
         b1 = boxes[:, :, 1]
         b2 = boxes[:, :, 2]
         b3 = boxes[:, :, 3]
-        b0 = torch.min(b0, iw - 1)
-        b1 = torch.min(b1, ih - 1)
-        b2 = torch.min(b2, iw - 1)
-        b3 = torch.min(b3, ih - 1)
+        b0 = torch.min(b0, max_w)
+        b1 = torch.min(b1, max_h)
+        b2 = torch.min(b2, max_w)
+        b3 = torch.min(b3, max_h)
         boxes = torch.stack((b0, b1, b2, b3, boxes[:, :, 4]), dim=2)
+
+        boxes[:, :, 0:4] /= scale
         return boxes
 
 
 # ==================================================================================================
 
 
+class PosePreprocess(nn.Module):
+    def __init__(self, target_size, fill_value=114):
+        super(PosePreprocess, self).__init__()
+        self.letterbox = Letterbox(target_size, fill_value)
+
+    def forward(self, img, bbox):
+        # img: torch.Tensor of shape [1, H, W, C], dtype=torch.uint8
+        # bbox: torch.Tensor of shape [1, 4], dtype=torch.float32
+        img = self.letterbox(img, bbox)
+        return img
+
+
+# ==================================================================================================
+
+
+class PosePostprocess(nn.Module):
+    def __init__(self, target_size):
+        super(PosePostprocess, self).__init__()
+
+        self.target_size = target_size
+        self.letterbox = Letterbox(target_size)
+
+    def forward(self, img, bbox, keypoints):
+        _, paddings, scale, _, bbox = self.letterbox.calc_params_and_crop(img, bbox)
+
+        kp = keypoints.float()
+        kp[:, :, 0] -= paddings[0]
+        kp[:, :, 1] -= paddings[2]
+
+        zero = torch.tensor(0)
+        kp = torch.max(kp, zero)
+
+        th, tw = self.target_size
+        pad_w = paddings[0] + paddings[1]
+        pad_h = paddings[2] + paddings[3]
+        max_w = tw - pad_w - 1
+        max_h = th - pad_h - 1
+        k0 = kp[:, :, 0]
+        k1 = kp[:, :, 1]
+        k0 = torch.min(k0, max_w)
+        k1 = torch.min(k1, max_h)
+        kp = torch.stack((k0, k1), dim=2)
+
+        kp[:, :, 0:2] /= scale
+
+        kp[:, :, 0] += bbox[0, 0]
+        kp[:, :, 1] += bbox[0, 1]
+        return kp
+
+
+# ==================================================================================================
 
 
 def main():
 
     img_path = "/RapidPoseTriangulation/scripts/../data/h1/54138969-img_003201.jpg"
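PosePostprocess undoes the letterboxing in reverse: subtract the padding, clamp to the resized content area, divide by the scale, then add the crop origin back on. A numeric sketch with illustrative values:

# Back-mapping one keypoint from letterbox space to original image coordinates.
# Hypothetical 160x100 crop letterboxed into (384, 288):
#   scale = min(288/160, 384/100) = 1.8, content 288x180, pad_h = 204 (102 top, 102 bottom).
kp_x, kp_y = 150.0, 200.0       # network output in letterbox space
pad_left, pad_top = 0, 102
scale = 1.8
x1, y1 = 40, 60                 # top-left of the (expanded) crop in the original image

img_x = (kp_x - pad_left) / scale + x1  # 150 / 1.8 + 40 = 123.33
img_y = (kp_y - pad_top) / scale + y1   # 98 / 1.8 + 60 = 114.44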
@@ -154,6 +246,45 @@ def main():
         },
     )
 
+    # Initialize the PosePreprocess module
+    preprocess_model = PosePreprocess(target_size=pose_target_size)
+    det_dummy_input_c0 = torch.from_numpy(image).unsqueeze(0)
+    det_dummy_input_c1 = torch.tensor([[10, 10, 90, 40]])
+
+    # Export to ONNX
+    torch.onnx.export(
+        preprocess_model,
+        (det_dummy_input_c0, det_dummy_input_c1),
+        base_path + "pose_preprocess.onnx",
+        opset_version=11,
+        input_names=["input_image", "bbox"],
+        output_names=["preprocessed_image"],
+        dynamic_axes={
+            "input_image": {0: "batch_size", 1: "height", 2: "width"},
+            "preprocessed_image": {0: "batch_size"},
+        },
+    )
+
+    # Initialize the PosePostprocess module
+    postprocess_model = PosePostprocess(target_size=pose_target_size)
+    det_dummy_input_d0 = torch.from_numpy(image).unsqueeze(0)
+    det_dummy_input_d1 = torch.tensor([[10, 10, 90, 40]])
+    det_dummy_input_d2 = torch.rand(1, 17, 3)
+
+    # Export to ONNX
+    torch.onnx.export(
+        postprocess_model,
+        (det_dummy_input_d0, det_dummy_input_d1, det_dummy_input_d2),
+        base_path + "pose_postprocess.onnx",
+        opset_version=11,
+        input_names=["input_image", "bbox", "keypoints"],
+        output_names=["output_keypoints"],
+        dynamic_axes={
+            "input_image": {0: "batch_size", 1: "height", 2: "width"},
+            "output_keypoints": {0: "batch_size"},
+        },
+    )
+
 
 # ==================================================================================================
 
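Before add_steps_to_onnx stitches these helper graphs onto the main network, it can be worth confirming that the exported I/O names, dtypes, and shapes came out as intended. A small hedged sketch using onnxruntime only, with the paths as exported above:

import onnxruntime as ort

for name in ("pose_preprocess.onnx", "pose_postprocess.onnx"):
    sess = ort.InferenceSession(base_path + name)
    print(name,
          [(i.name, i.type, i.shape) for i in sess.get_inputs()],
          [(o.name, o.type, o.shape) for o in sess.get_outputs()])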