Fix gaitedge, fixes #91

This commit is contained in:
darkliang
2022-10-13 14:59:18 +08:00
parent c2a48191e3
commit 41de2d47cc
+3 -4
View File
@@ -90,11 +90,10 @@ class GaitEdge(GaitGL):
self._load_ckpt(save_name)
def preprocess(self, sils):
dilated_mask = (morph.dilation(sils, self.kernel).detach()
dilated_mask = (morph.dilation(sils, self.kernel.to(sils.device)).detach()
) > 0.5 # Dilation
eroded_mask = (morph.erosion(sils, self.kernel).detach()
) > 0.5 # Dilation
eroded_mask = (morph.erosion(sils, self.kernel.to(sils.device)).detach()
) > 0.5 # Erosion
edge_mask = dilated_mask ^ eroded_mask
return edge_mask, eroded_mask