Moved detector pre/post-processing into onnx graph.

This commit is contained in:
Daniel
2024-12-03 11:52:55 +01:00
parent 36781e616b
commit 742d2386c7
5 changed files with 213 additions and 104 deletions
+32 -1
View File
@@ -1,6 +1,6 @@
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx import TensorProto, compose, helper, numpy_helper
# ==================================================================================================
@@ -97,6 +97,37 @@ def add_steps_to_onnx(model_path):
for i, j in enumerate([0, 3, 1, 2]):
input_shape[j].dim_value = dims[i]
if "det" in model_path:
# Add preprocess model to main network
pp1_model = onnx.load(base_path + "det_preprocess.onnx")
model = compose.add_prefix(model, prefix="main_")
pp1_model = compose.add_prefix(pp1_model, prefix="preprocess_")
model = compose.merge_models(
pp1_model,
model,
io_map=[(pp1_model.graph.output[0].name, model.graph.input[0].name)],
)
# Add postprocess model
pp2_model = onnx.load(base_path + "det_postprocess.onnx")
pp2_model = compose.add_prefix(pp2_model, prefix="postprocess_")
model = compose.merge_models(
model,
pp2_model,
io_map=[
(model.graph.output[0].name, pp2_model.graph.input[1].name),
],
)
# Update nodes from postprocess model to use the input of the main network
pp2_input_image_name = pp2_model.graph.input[0].name
main_input_name = model.graph.input[0].name
for node in model.graph.node:
for idx, name in enumerate(node.input):
if name == pp2_input_image_name:
node.input[idx] = main_input_name
model.graph.input.pop(1)
# Set input type to int8
model.graph.input[0].type.tensor_type.elem_type = TensorProto.UINT8