混淆矩阵是评判模型结果的一种指标,属于模型评估的一部分,一般用作评判分类器的优劣
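- 以二分类为例
- 一级指标:TP、FP、FN、TN(分别为真正例、假正例、假负例、真负例的样本数)
- 二级指标(四个):准确率 $Accuracy=\frac{TP+TN}{TP+TN+FP+FN}$,精确率 $Precision=\frac{TP}{TP+FP}$,召回率 $Recall=\frac{TP}{TP+FN}$,特异度 $Specificity=\frac{TN}{TN+FP}$
- 三级指标:F1 Score,$F1=\frac{2\cdot Precision\cdot Recall}{Precision+Recall}$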
举例(以三分类为例,加深对上面公式的理解)
- 所有类别整体的 Accuracy(准确率)
- 单一类别的Precision精确率
- 单一类别的Recall召回率
- 单一类别的Specificity特异度
同样,对上面每个单一类别也可以分别计算各自的三级指标(F1 Score)。
在自己撰写论文时,可以在分类模型评价一节中采用类似下面的混淆矩阵图来展示模型的分类效果。
代码实战
import os
import json
import torch
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model_v2 import MobileNetV2
class ConfusionMatrix(object):
    """Accumulate a confusion matrix and report accuracy and per-class metrics.

    Layout: rows index the *predicted* class and columns the *true* class, so
    ``matrix[p, t]`` counts samples predicted as ``p`` whose ground truth is
    ``t`` (see ``update``; ``plot`` labels its axes the same way).
    """

    def __init__(self, num_classes: int, labels: list):
        """
        Args:
            num_classes: number of classes; the matrix is num_classes x num_classes.
            labels: display names indexed by class index (used by summary/plot).
        """
        self.matrix = np.zeros((num_classes, num_classes))  # float64 counts
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        """Add one count per aligned (predicted, true) class-index pair.

        Args:
            preds: iterable of predicted class indices.
            labels: iterable of true class indices, aligned with ``preds``.
        """
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    @staticmethod
    def _safe_div(num, den):
        """Return ``num / den``, or 0.0 when ``den`` is zero.

        A class that is never predicted (or never present) would otherwise
        produce 0/0 -> nan together with a numpy RuntimeWarning.
        """
        return num / den if den else 0.0

    def _class_metrics(self, i):
        """Return (precision, recall, specificity) of class ``i``, each rounded to 3 decimals."""
        tp = self.matrix[i, i]
        fp = np.sum(self.matrix[i, :]) - tp  # predicted as i, but truly another class
        fn = np.sum(self.matrix[:, i]) - tp  # truly i, but predicted as another class
        tn = np.sum(self.matrix) - tp - fp - fn
        precision = round(self._safe_div(tp, tp + fp), 3)
        recall = round(self._safe_div(tp, tp + fn), 3)
        specificity = round(self._safe_div(tn, tn + fp), 3)
        return precision, recall, specificity

    def summary(self):
        """Print overall accuracy and a per-class Precision/Recall/Specificity table."""
        # Accuracy = correctly classified (diagonal) / all samples.
        acc = self._safe_div(np.trace(self.matrix), np.sum(self.matrix))
        print("the model accuracy is ", acc)
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i in range(self.num_classes):
            table.add_row([self.labels[i], *self._class_metrics(i)])
        print(table)

    def plot(self):
        """Print the raw matrix and display it as an annotated heat map."""
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')
        # White text on dark (high-count) cells keeps the numbers readable.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                info = int(matrix[y, x])  # row y = predicted class, column x = true class
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()
if __name__ == '__main__':
    import sys  # used only by this script's error-exit paths

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # Evaluation preprocessing; Normalize uses the commonly used ImageNet channel mean/std.
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406],
                                                              [0.229, 0.224, 0.225])])

    # The dataset is expected two directories above the current working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform)
    val_num = len(validate_dataset)
    print("using {} images for validation.".format(val_num))

    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)

    # Single source of truth for the class count (model head and confusion matrix).
    num_classes = 5
    net = MobileNetV2(num_classes=num_classes)
    model_weight_path = "./Mobilenetv2.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    json_label_path = './class_indices.json'
    if not os.path.exists(json_label_path):
        print("cannot find {} file".format(json_label_path))
        sys.exit(-1)
    try:
        # `with` guarantees the file handle is closed (the original leaked it).
        with open(json_label_path, 'r', encoding='utf-8') as json_file:
            class_indict = json.load(json_file)
    except (OSError, json.JSONDecodeError) as e:
        print(e)
        sys.exit(-1)

    # NOTE(review): assumes the JSON object's value order matches the class-index
    # order used by the model and ConfusionMatrix.labels — confirm against the writer.
    labels = list(class_indict.values())
    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels)

    net.eval()
    with torch.no_grad():
        for val_images, val_labels in validate_loader:
            outputs = net(val_images.to(device))
            # softmax is monotonic, so argmax over raw logits picks the same class.
            preds = torch.argmax(outputs, dim=1)
            confusion.update(preds.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()
    confusion.summary()
- 结果
参考链接:https://blog.csdn.net/qq_37541097?spm=1001.2014.3001.5509 推荐博文:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839
|