Export pose postprocessing steps as well.
This commit is contained in:
@ -50,14 +50,8 @@ class SimCC(BaseModel):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def postprocess(self, tensor: List[np.ndarray]):
|
def postprocess(self, tensor: List[np.ndarray]):
|
||||||
simcc_x, simcc_y = tensor
|
kpts = tensor[0][0]
|
||||||
simcc_x = np.squeeze(simcc_x, axis=0)
|
scores = np.expand_dims(tensor[1][0], axis=-1)
|
||||||
simcc_y = np.squeeze(simcc_y, axis=0)
|
keypoints = np.concatenate([kpts, scores], axis=-1)
|
||||||
keypoints = simcc_decoder(simcc_x,
|
|
||||||
simcc_y,
|
|
||||||
self.input_shape[2:],
|
|
||||||
self.dx,
|
|
||||||
self.dy,
|
|
||||||
self.scale)
|
|
||||||
|
|
||||||
return keypoints
|
return keypoints
|
||||||
|
|||||||
@ -2,7 +2,7 @@ _base_ = ["./pose-detection_static.py", "../_base_/backends/onnxruntime.py"]
|
|||||||
|
|
||||||
onnx_config = dict(
|
onnx_config = dict(
|
||||||
input_shape=[288, 384],
|
input_shape=[288, 384],
|
||||||
output_names=["simcc_x", "simcc_y"],
|
output_names=["kpts", "scores"],
|
||||||
)
|
)
|
||||||
|
|
||||||
codebase_config = dict(export_postprocess=False) # do not export get_simcc_maximum
|
codebase_config = dict(export_postprocess=True) # export get_simcc_maximum
|
||||||
|
|||||||
@ -2,7 +2,7 @@ _base_ = ["./pose-detection_static.py", "../_base_/backends/onnxruntime-fp16.py"
|
|||||||
|
|
||||||
onnx_config = dict(
|
onnx_config = dict(
|
||||||
input_shape=[288, 384],
|
input_shape=[288, 384],
|
||||||
output_names=["simcc_x", "simcc_y"],
|
output_names=["kpts", "scores"],
|
||||||
)
|
)
|
||||||
|
|
||||||
codebase_config = dict(export_postprocess=False) # do not export get_simcc_maximum
|
codebase_config = dict(export_postprocess=True) # export get_simcc_maximum
|
||||||
|
|||||||
Reference in New Issue
Block a user