Moved detector pre/post-processing into onnx graph.

This commit is contained in:
Daniel
2024-12-03 11:52:55 +01:00
parent 36781e616b
commit 742d2386c7
5 changed files with 213 additions and 104 deletions

View File

@ -1,6 +1,6 @@
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx import TensorProto, compose, helper, numpy_helper
# ==================================================================================================
@ -97,6 +97,37 @@ def add_steps_to_onnx(model_path):
for i, j in enumerate([0, 3, 1, 2]):
input_shape[j].dim_value = dims[i]
if "det" in model_path:
# Add preprocess model to main network
pp1_model = onnx.load(base_path + "det_preprocess.onnx")
model = compose.add_prefix(model, prefix="main_")
pp1_model = compose.add_prefix(pp1_model, prefix="preprocess_")
model = compose.merge_models(
pp1_model,
model,
io_map=[(pp1_model.graph.output[0].name, model.graph.input[0].name)],
)
# Add postprocess model
pp2_model = onnx.load(base_path + "det_postprocess.onnx")
pp2_model = compose.add_prefix(pp2_model, prefix="postprocess_")
model = compose.merge_models(
model,
pp2_model,
io_map=[
(model.graph.output[0].name, pp2_model.graph.input[1].name),
],
)
# Update nodes from postprocess model to use the input of the main network
pp2_input_image_name = pp2_model.graph.input[0].name
main_input_name = model.graph.input[0].name
for node in model.graph.node:
for idx, name in enumerate(node.input):
if name == pp2_input_image_name:
node.input[idx] = main_input_name
model.graph.input.pop(1)
# Set input type to uint8
model.graph.input[0].type.tensor_type.elem_type = TensorProto.UINT8

View File

@ -3,3 +3,7 @@ _base_ = ["../_base_/base_static.py", "../../_base_/backends/onnxruntime.py"]
# ONNX export settings: the exported model uses a fixed 320x320 input.
onnx_config = dict(
input_shape=[320, 320],
)
# Post-processing thresholds baked into the export.
# NOTE(review): presumably these override the score/IoU thresholds from the
# base config listed in `_base_` - confirm against that file.
codebase_config = dict(
post_processing=dict(score_threshold=0.3, iou_threshold=0.3),
)

View File

@ -3,3 +3,7 @@ _base_ = ["../_base_/base_static.py", "../../_base_/backends/onnxruntime-fp16.py
# ONNX export settings: the exported model uses a fixed 320x320 input.
onnx_config = dict(
input_shape=[320, 320],
)
# Post-processing thresholds baked into the export.
# NOTE(review): presumably these override the score/IoU thresholds from the
# base config listed in `_base_` - confirm against that file.
codebase_config = dict(
post_processing=dict(score_threshold=0.3, iou_threshold=0.3),
)

View File

@ -0,0 +1,161 @@
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
# ==================================================================================================
base_path = "/RapidPoseTriangulation/extras/mmdeploy/exports/"
det_target_size = (320, 320)
# ==================================================================================================
class Letterbox(nn.Module):
    """Resize an image to fit a target size and pad the remainder.

    The aspect ratio is kept: the image is scaled by the smaller of the two
    target/source ratios and the leftover border is filled with a constant.
    Images are indexed as channels-last, i.e. ``img.shape[1:3]`` is read as
    (height, width).

    NOTE(review): ``calc_params`` passes the two ratios straight to
    ``torch.min``, which needs tensor arguments - presumably this module is
    only run while being traced for ONNX export, where shape entries are
    tensors. Confirm before using it in plain eager mode.
    """

    def __init__(self, target_size, fill_value=128):
        """Store the (height, width) target size and the padding fill value."""
        super().__init__()
        self.target_size = target_size
        self.fill_value = fill_value

    def calc_params(self, img):
        """Compute the letterbox geometry for ``img``.

        Returns:
            A tuple ``(paddings, scale, (new_w, new_h))`` where ``paddings``
            is ``(left, right, top, bottom)``, ``scale`` is the resize factor
            and ``(new_w, new_h)`` is the rounded size after resizing.
        """
        src_h, src_w = img.shape[1:3]
        dst_h, dst_w = self.target_size

        # Use the smaller ratio so the resized image fits inside the target.
        ratio = torch.min(dst_w / src_w, dst_h / src_h)
        new_w = torch.round(src_w * ratio)
        new_h = torch.round(src_h * ratio)

        # Split the leftover border; the extra pixel of an odd remainder goes
        # to the right / bottom side.
        extra_w = dst_w - new_w
        extra_h = dst_h - new_h
        left = extra_w // 2
        top = extra_h // 2
        right = extra_w - left
        bottom = extra_h - top

        return (left, right, top, bottom), ratio, (new_w, new_h)

    def forward(self, img):
        """Return ``img`` resized and padded to the target size (as float32)."""
        borders, _, (new_w, new_h) = self.calc_params(img)

        # Resize: interpolate works on channels-first data, so convert to
        # NCHW, resize, convert back, then round to whole pixel values.
        as_float = img.to(torch.float32)
        resized = F.interpolate(
            as_float.permute(0, 3, 1, 2),
            size=(new_h, new_w),
            mode="bilinear",
            align_corners=False,
        )
        resized = resized.permute(0, 2, 3, 1).round()

        # Pad: same channels-first round trip around the constant padding.
        padded = F.pad(
            resized.permute(0, 3, 1, 2),
            pad=borders,
            mode="constant",
            value=self.fill_value,
        )
        return padded.permute(0, 2, 3, 1)
# ==================================================================================================
class DetPreprocess(nn.Module):
    """Detector pre-processing step: letterbox the input image.

    The input is expected as a [batch, H, W, C] uint8 tensor.
    """

    def __init__(self, target_size, fill_value=114):
        """Build the letterbox stage with the given target size and fill value."""
        super().__init__()
        self.letterbox = Letterbox(target_size, fill_value)

    def forward(self, img):
        """Return the letterboxed version of ``img``."""
        return self.letterbox(img)
# ==================================================================================================
class DetPostprocess(nn.Module):
    """Map detector boxes from letterboxed coordinates back to the source image.

    The letterbox padding is subtracted, the resize scale is undone, and the
    coordinates are clamped to the source image bounds.
    """

    def __init__(self, target_size):
        """Create the letterbox helper used to recompute padding and scale."""
        super().__init__()
        self.letterbox = Letterbox(target_size)

    def forward(self, img, boxes):
        """Return boxes expressed in the coordinate frame of ``img``.

        Args:
            img: Source image, indexed as ``img.shape[1:3] == (height, width)``.
            boxes: Tensor indexed as ``[batch, box, column]``. Columns 0 and 2
                are shifted by the left padding and clamped to the width;
                columns 1 and 3 are shifted by the top padding and clamped to
                the height. Column 4 is passed through (presumably the
                detection score - confirm against the detector output).

        Returns:
            A ``[batch, box, 5]`` float tensor. The input ``boxes`` tensor is
            left unmodified.
        """
        paddings, scale, _ = self.letterbox.calc_params(img)

        # ``float()`` returns the very same tensor when the input is already
        # float32, so clone first: the in-place edits below must not leak
        # into the caller's tensor.
        boxes = boxes.float().clone()
        boxes[:, :, 0] -= paddings[0]
        boxes[:, :, 2] -= paddings[0]
        boxes[:, :, 1] -= paddings[2]
        boxes[:, :, 3] -= paddings[2]
        boxes[:, :, 0:4] /= scale

        ih, iw = img.shape[1:3]

        # Lower clamp applies to every column, including column 4.
        boxes = torch.max(boxes, torch.tensor(0))
        b0 = boxes[:, :, 0]
        b1 = boxes[:, :, 1]
        b2 = boxes[:, :, 2]
        b3 = boxes[:, :, 3]

        # Upper clamp: x columns to the last pixel column, y columns to the
        # last pixel row.
        b0 = torch.min(b0, iw - 1)
        b1 = torch.min(b1, ih - 1)
        b2 = torch.min(b2, iw - 1)
        b3 = torch.min(b3, ih - 1)

        boxes = torch.stack((b0, b1, b2, b3, boxes[:, :, 4]), dim=2)
        return boxes
# ==================================================================================================
def main():
    """Export the detector pre- and post-processing modules to ONNX files.

    A sample image is loaded to serve as the dummy input for tracing. Two
    files are written below ``base_path``: ``det_preprocess.onnx`` and
    ``det_postprocess.onnx``.

    Raises:
        FileNotFoundError: If the sample image cannot be read.
    """
    img_path = "/RapidPoseTriangulation/scripts/../data/h1/54138969-img_003201.jpg"
    image = cv2.imread(img_path, 3)
    if image is None:
        # cv2.imread reports a missing or unreadable file by returning None;
        # fail here with a clear message instead of inside torch.from_numpy.
        raise FileNotFoundError(f"Could not read sample image: {img_path}")

    # Batched image tensor, used as the dummy input for both exports
    dummy_image = torch.from_numpy(image).unsqueeze(0)

    # Initialize the DetPreprocess module
    preprocess_model = DetPreprocess(target_size=det_target_size)

    # Export to ONNX
    torch.onnx.export(
        preprocess_model,
        dummy_image,
        base_path + "det_preprocess.onnx",
        opset_version=11,
        input_names=["input_image"],
        output_names=["preprocessed_image"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "preprocessed_image": {0: "batch_size"},
        },
    )

    # Initialize the DetPostprocess module
    postprocess_model = DetPostprocess(target_size=det_target_size)
    dummy_boxes = torch.rand(1, 10, 5)

    # Export to ONNX
    torch.onnx.export(
        postprocess_model,
        (dummy_image, dummy_boxes),
        base_path + "det_postprocess.onnx",
        opset_version=11,
        input_names=["input_image", "boxes"],
        output_names=["output_boxes"],
        dynamic_axes={
            "input_image": {0: "batch_size", 1: "height", 2: "width"},
            "boxes": {0: "batch_size", 1: "num_boxes"},
            "output_boxes": {0: "batch_size", 1: "num_boxes"},
        },
    )
# ==================================================================================================
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
main()