Optimize parameter naming, fix label index error
This commit is contained in:
@@ -16,10 +16,10 @@ class ScoNet(BaseModel):
|
|||||||
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
ipts, labs, class_id, _, seqL = inputs
|
ipts, pids, labels, _, seqL = inputs
|
||||||
|
|
||||||
class_id_int = np.array([1 if status == 'positive' else 2 if status == 'neutral' else 0 for status in class_id])
|
# Label mapping: negative->0, neutral->1, positive->2
|
||||||
class_id = torch.tensor(class_id_int).cuda()
|
label_ids = np.array([{'negative': 0, 'neutral': 1, 'positive': 2}[status] for status in labels])
|
||||||
|
|
||||||
sils = ipts[0]
|
sils = ipts[0]
|
||||||
if len(sils.size()) == 4:
|
if len(sils.size()) == 4:
|
||||||
@@ -40,8 +40,8 @@ class ScoNet(BaseModel):
|
|||||||
embed = embed_1
|
embed = embed_1
|
||||||
retval = {
|
retval = {
|
||||||
'training_feat': {
|
'training_feat': {
|
||||||
'triplet': {'embeddings': embed, 'labels': labs},
|
'triplet': {'embeddings': embed, 'labels': pids},
|
||||||
'softmax': {'logits': logits, 'labels': class_id},
|
'softmax': {'logits': logits, 'labels': label_ids},
|
||||||
},
|
},
|
||||||
'visual_summary': {
|
'visual_summary': {
|
||||||
'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')
|
'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')
|
||||||
|
|||||||
Reference in New Issue
Block a user