Fix gaitedge, fixes #91
This commit is contained in:
@@ -90,11 +90,10 @@ class GaitEdge(GaitGL):
|
||||
self._load_ckpt(save_name)
|
||||
|
||||
def preprocess(self, sils):
|
||||
|
||||
dilated_mask = (morph.dilation(sils, self.kernel).detach()
|
||||
dilated_mask = (morph.dilation(sils, self.kernel.to(sils.device)).detach()
|
||||
) > 0.5 # Dilation
|
||||
eroded_mask = (morph.erosion(sils, self.kernel).detach()
|
||||
) > 0.5 # Dilation
|
||||
eroded_mask = (morph.erosion(sils, self.kernel.to(sils.device)).detach()
|
||||
) > 0.5 # Erosion
|
||||
edge_mask = dilated_mask ^ eroded_mask
|
||||
return edge_mask, eroded_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user