diff --git a/opengait/modeling/models/sconet.py b/opengait/modeling/models/sconet.py index c61c147..8438da4 100644 --- a/opengait/modeling/models/sconet.py +++ b/opengait/modeling/models/sconet.py @@ -20,7 +20,8 @@ class ScoNet(BaseModel): # Label mapping: negative->0, neutral->1, positive->2 label_ids = np.array([{'negative': 0, 'neutral': 1, 'positive': 2}[status] for status in labels]) - + label_ids = torch.from_numpy(label_ids).cuda().long() + sils = ipts[0] if len(sils.size()) == 4: sils = sils.unsqueeze(1)