Merge model outputs directly in graph.

This commit is contained in:
Daniel
2025-01-17 18:33:19 +01:00
parent 97ea039f7d
commit 8f2322694a
3 changed files with 115 additions and 57 deletions

View File

@ -145,6 +145,64 @@ def add_steps_to_onnx(model_path):
# Update the output's data type info
output.type.tensor_type.elem_type = TensorProto.FLOAT
# Merge the two outputs
if "det" in model_path:
r1_output = "dets"
r2_output = "labels"
out_name = "bboxes"
out_dim = 6
if "pose" in model_path:
r1_output = "kpts"
r2_output = "scores"
out_name = "keypoints"
out_dim = 3
if "det" in model_path or "pose" in model_path:
# Node to expand
r2_expanded = r2_output + "_expanded"
unsqueeze_node = helper.make_node(
"Unsqueeze",
inputs=[r2_output],
outputs=[r2_expanded],
axes=[2],
name="Unsqueeze",
)
# Node to concatenate
r12_merged = out_name
concat_node = helper.make_node(
"Concat",
inputs=[r1_output, r2_expanded],
outputs=[r12_merged],
axis=2,
name="Merged",
)
# Define the new concatenated output
merged_output = helper.make_tensor_value_info(
r12_merged,
TensorProto.FLOAT,
[
(
graph.input[0].type.tensor_type.shape.dim[0].dim_value
if graph.input[0].type.tensor_type.shape.dim[0].dim_value > 0
else None
),
(
graph.output[0].type.tensor_type.shape.dim[1].dim_value
if graph.output[0].type.tensor_type.shape.dim[1].dim_value > 0
else None
),
out_dim,
],
)
# Update the graph
graph.node.append(unsqueeze_node)
graph.node.append(concat_node)
graph.output.pop()
graph.output.pop()
graph.output.append(merged_output)
path = re.sub(r"(x)(\d+)x(\d+)x(\d+)", r"\1\3x\4x\2", model_path)
path = path.replace(".onnx", "_extra-steps.onnx")
onnx.save(model, path)