ScoNet_V1

This commit is contained in:
Zzier
2024-06-28 17:34:32 +08:00
parent 01daf44061
commit dc2616c0e0
8 changed files with 6227 additions and 1 deletions
+43 -1
View File
@@ -5,7 +5,7 @@ from utils import get_msg_mgr, mkdir
from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank, evaluate_many
from .re_rank import re_ranking
from sklearn.metrics import confusion_matrix, accuracy_score
def de_diag(acc, each_angle=False):
# Exclude identical-view cases
@@ -415,3 +415,45 @@ def evaluate_CCPG(data, dataset, metric='euc'):
msg_mgr.log_info('DN: {}'.format(de_diag(acc[2, :, :, i], True)))
msg_mgr.log_info('BG: {}'.format(de_diag(acc[3, :, :, i], True)))
return result_dict
def evaluate_scoliosis(data, dataset, metric='euc'):
msg_mgr = get_msg_mgr()
feature, label, class_id, view = data['embeddings'], data['labels'], data['types'], data['views']
label = np.array(label)
class_id = np.array(class_id)
# Update class_id with integer labels based on status
class_id_int = np.array([1 if status == 'positive' else 2 if status == 'critical' else 0 for status in class_id])
print('class_id=', class_id_int)
features = np.array(feature)
c_id_int = np.argmax(features.mean(-1), axis=-1)
print('predicted_labels', c_id_int)
# Calculate sensitivity and specificity
cm = confusion_matrix(class_id_int, c_id_int, labels=[0, 1, 2])
FP = cm.sum(axis=0) - np.diag(cm)
FN = cm.sum(axis=1) - np.diag(cm)
TP = np.diag(cm)
TN = cm.sum() - (FP + FN + TP)
# Sensitivity, hit rate, recall, or true positive rate
TPR = TP / (TP + FN)
# Specificity or true negative rate
TNR = TN / (TN + FP)
accuracy = accuracy_score(class_id_int, c_id_int)
result_dict = {}
result_dict["scalar/test_accuracy/"] = accuracy
result_dict["scalar/test_sensitivity/"] = TPR
result_dict["scalar/test_specificity/"] = TNR
# Printing the sensitivity and specificity
for i, cls in enumerate(['Positive']):
print(f"{cls} Sensitivity (Recall): {TPR[i] * 100:.2f}%")
print(f"{cls} Specificity: {TNR[i] * 100:.2f}%")
print(f"Accuracy: {accuracy * 100:.2f}%")
return result_dict