diff --git a/opengait/modeling/models/gaitedge.py b/opengait/modeling/models/gaitedge.py index 459521d..9d34dd7 100644 --- a/opengait/modeling/models/gaitedge.py +++ b/opengait/modeling/models/gaitedge.py @@ -90,11 +90,10 @@ class GaitEdge(GaitGL): self._load_ckpt(save_name) def preprocess(self, sils): - - dilated_mask = (morph.dilation(sils, self.kernel).detach() + dilated_mask = (morph.dilation(sils, self.kernel.to(sils.device)).detach() ) > 0.5 # Dilation - eroded_mask = (morph.erosion(sils, self.kernel).detach() - ) > 0.5 # Dilation + eroded_mask = (morph.erosion(sils, self.kernel.to(sils.device)).detach() + ) > 0.5 # Erosion edge_mask = dilated_mask ^ eroded_mask return edge_mask, eroded_mask