Small improvements and fixes.

This commit is contained in:
Daniel
2025-01-17 16:24:37 +01:00
parent 1bd58deede
commit 8a249a2f16
3 changed files with 40 additions and 11 deletions

View File

@ -54,6 +54,7 @@ def add_steps_to_onnx(model_path):
inputs=[input_name],
outputs=[casted_output],
to=cast_type,
name="Cast_Input",
)
# Node to transpose
@ -118,6 +119,32 @@ def add_steps_to_onnx(model_path):
# Set input image type to int8
model.graph.input[0].type.tensor_type.elem_type = TensorProto.UINT8
# Cast all outputs to fp32 to avoid half precision issues in cpp code
for output in graph.output:
orig_output_name = output.name
internal_output_name = orig_output_name + "_internal"
# Rename the output tensor
for node in model.graph.node:
for idx, name in enumerate(node.output):
if name == orig_output_name:
node.output[idx] = internal_output_name
# Insert a Cast node that casts the internal output to fp32
cast_fp32_name = orig_output_name
cast_node_output = helper.make_node(
"Cast",
inputs=[internal_output_name],
outputs=[cast_fp32_name],
to=1,
name="Cast_Output_" + orig_output_name,
)
# Append the cast node to the graph
graph.node.append(cast_node_output)
# Update the output's data type info
output.type.tensor_type.elem_type = TensorProto.FLOAT
path = re.sub(r"(x)(\d+)x(\d+)x(\d+)", r"\1\3x\4x\2", model_path)
path = path.replace(".onnx", "_extra-steps.onnx")
onnx.save(model, path)