ScoNet_V1
This commit is contained in:
@@ -5,7 +5,7 @@ from utils import get_msg_mgr, mkdir
|
||||
|
||||
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
|
||||
from .re_rank import re_ranking
|
||||
|
||||
from sklearn.metrics import confusion_matrix, accuracy_score
|
||||
|
||||
def de_diag(acc, each_angle=False):
|
||||
# Exclude identical-view cases
|
||||
@@ -415,3 +415,45 @@ def evaluate_CCPG(data, dataset, metric='euc'):
|
||||
msg_mgr.log_info('DN: {}'.format(de_diag(acc[2, :, :, i], True)))
|
||||
msg_mgr.log_info('BG: {}'.format(de_diag(acc[3, :, :, i], True)))
|
||||
return result_dict
|
||||
|
||||
def evaluate_scoliosis(data, dataset, metric='euc'):
|
||||
msg_mgr = get_msg_mgr()
|
||||
|
||||
feature, label, class_id, view = data['embeddings'], data['labels'], data['types'], data['views']
|
||||
|
||||
label = np.array(label)
|
||||
class_id = np.array(class_id)
|
||||
|
||||
# Update class_id with integer labels based on status
|
||||
class_id_int = np.array([1 if status == 'positive' else 2 if status == 'critical' else 0 for status in class_id])
|
||||
print('class_id=', class_id_int)
|
||||
|
||||
features = np.array(feature)
|
||||
c_id_int = np.argmax(features.mean(-1), axis=-1)
|
||||
print('predicted_labels', c_id_int)
|
||||
|
||||
# Calculate sensitivity and specificity
|
||||
cm = confusion_matrix(class_id_int, c_id_int, labels=[0, 1, 2])
|
||||
FP = cm.sum(axis=0) - np.diag(cm)
|
||||
FN = cm.sum(axis=1) - np.diag(cm)
|
||||
TP = np.diag(cm)
|
||||
TN = cm.sum() - (FP + FN + TP)
|
||||
|
||||
# Sensitivity, hit rate, recall, or true positive rate
|
||||
TPR = TP / (TP + FN)
|
||||
# Specificity or true negative rate
|
||||
TNR = TN / (TN + FP)
|
||||
accuracy = accuracy_score(class_id_int, c_id_int)
|
||||
|
||||
result_dict = {}
|
||||
result_dict["scalar/test_accuracy/"] = accuracy
|
||||
result_dict["scalar/test_sensitivity/"] = TPR
|
||||
result_dict["scalar/test_specificity/"] = TNR
|
||||
|
||||
# Printing the sensitivity and specificity
|
||||
for i, cls in enumerate(['Positive']):
|
||||
print(f"{cls} Sensitivity (Recall): {TPR[i] * 100:.2f}%")
|
||||
print(f"{cls} Specificity: {TNR[i] * 100:.2f}%")
|
||||
print(f"Accuracy: {accuracy * 100:.2f}%")
|
||||
|
||||
return result_dict
|
||||
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
class ScoNet(BaseModel):
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||
self.TP = PackSequenceWrapper(torch.max)
|
||||
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, class_id, _, seqL = inputs
|
||||
|
||||
class_id_int = np.array([1 if status == 'positive' else 2 if status == 'critical' else 0 for status in class_id])
|
||||
class_id = torch.tensor(class_id_int).cuda()
|
||||
|
||||
sils = ipts[0]
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(1)
|
||||
else:
|
||||
sils = rearrange(sils, 'n s c h w -> n c s h w')
|
||||
|
||||
del ipts
|
||||
outs = self.Backbone(sils) # [n, c, s, h, w]
|
||||
|
||||
# Temporal Pooling, TP
|
||||
outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
|
||||
# Horizontal Pooling Matching, HPM
|
||||
feat = self.HPP(outs) # [n, c, p]
|
||||
|
||||
embed_1 = self.FCs(feat) # [n, c, p]
|
||||
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
embed = embed_1
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embed, 'labels': labs},
|
||||
'softmax': {'logits': logits, 'labels': class_id},
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': logits
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user