From dd150d65a0c8e33c02d9ef05e733faf6243d7888 Mon Sep 17 00:00:00 2001 From: Zzier Date: Wed, 27 Aug 2025 20:54:12 +0800 Subject: [PATCH] Optimize parameter naming, fix label index error --- opengait/modeling/models/sconet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/opengait/modeling/models/sconet.py b/opengait/modeling/models/sconet.py index a977e81..c61c147 100644 --- a/opengait/modeling/models/sconet.py +++ b/opengait/modeling/models/sconet.py @@ -16,10 +16,10 @@ class ScoNet(BaseModel): self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']) def forward(self, inputs): - ipts, labs, class_id, _, seqL = inputs + ipts, pids, labels, _, seqL = inputs - class_id_int = np.array([1 if status == 'positive' else 2 if status == 'neutral' else 0 for status in class_id]) - class_id = torch.tensor(class_id_int).cuda() + # Label mapping: negative->0, neutral->1, positive->2 + label_ids = np.array([{'negative': 0, 'neutral': 1, 'positive': 2}[status] for status in labels]) sils = ipts[0] if len(sils.size()) == 4: @@ -40,8 +40,8 @@ class ScoNet(BaseModel): embed = embed_1 retval = { 'training_feat': { - 'triplet': {'embeddings': embed, 'labels': labs}, - 'softmax': {'logits': logits, 'labels': class_id}, + 'triplet': {'embeddings': embed, 'labels': pids}, + 'softmax': {'logits': logits, 'labels': label_ids}, }, 'visual_summary': { 'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')