From 41de2d47ccec05a4b113f884319f1d5aa9cbf073 Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Thu, 13 Oct 2022 14:59:18 +0800 Subject: [PATCH] Fix gaitedge, fixes #91 --- opengait/modeling/models/gaitedge.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/opengait/modeling/models/gaitedge.py b/opengait/modeling/models/gaitedge.py index 459521d..9d34dd7 100644 --- a/opengait/modeling/models/gaitedge.py +++ b/opengait/modeling/models/gaitedge.py @@ -90,11 +90,10 @@ class GaitEdge(GaitGL): self._load_ckpt(save_name) def preprocess(self, sils): - - dilated_mask = (morph.dilation(sils, self.kernel).detach() + dilated_mask = (morph.dilation(sils, self.kernel.to(sils.device)).detach() ) > 0.5 # Dilation - eroded_mask = (morph.erosion(sils, self.kernel).detach() - ) > 0.5 # Dilation + eroded_mask = (morph.erosion(sils, self.kernel.to(sils.device)).detach() + ) > 0.5 # Erosion edge_mask = dilated_mask ^ eroded_mask return edge_mask, eroded_mask