如何在Python中使用字符串轴而不是整数绘制混淆矩阵?
阅读更多:Python 教程
什么是混淆矩阵?
在机器学习中,我们通常会使用训练集和测试集作为模型评估的基础。混淆矩阵(Confusion Matrix)是一种常用的评估方法,用于评估分类模型的准确性。它将预测结果和真实结果以二维表格的形式呈现,可以更直观地了解模型的精准率、召回率等评估指标。
混淆矩阵通常长这样:
预测正例 | 预测反例 | |
---|---|---|
真实正例 | TP | FN |
真实反例 | FP | TN |
模型的分类结果包括四种情况:
- True Positive(真正例):实际为正例,模型预测为正例
- False Positive(假正例):实际为反例,模型预测为正例
- False Negative(假反例):实际为正例,模型预测为反例
- True Negative(真反例):实际为反例,模型预测为反例
Python绘制混淆矩阵
我们通常使用Python的Matplotlib库来绘制混淆矩阵。在Matplotlib中,绘制混淆矩阵需要用到imshow函数。因此,我们首先需要导入该函数。
import matplotlib.pyplot as plt
接下来,我们需要将混淆矩阵中的各项指标(即TP, FP, FN, TN)转化为Python中的二维数组形式。默认情况下,imshow函数将使用整数轴绘制混淆矩阵,这在可视化时缺乏直观性。因此,我们需要将整数轴转换为字符串轴。 转换之后,我们就可以使用下面这个例子程序绘制一张混淆矩阵网络图:
import numpy as np
import itertools
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
该函数使用矩阵方式绘制混淆矩阵
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
# Test
cm = np.array([[50, 10], [5, 35]])
classes = ['Cat', 'Dog']
plot_confusion_matrix(cm, classes, normalize=True)
plt.show()
输出结果应该类似于下面这样:
Normalized confusion matrix
[[0.83333333 0.16666667]
[0.125 0.875 ]]
从整数轴到字符串轴的转换
我们将混淆矩阵中的类别名称从‘0’和‘1’(或其他整数)转换为字符串,则需要将类别名称与整数对应起来。我们可以利用字典来实现这个过程。例如,假设我们有4个类别,我们可以用下面的代码将它们编码为整数:
class_to_idx = {'class1': 0, 'class2': 1, 'class3': 2, 'class4': 3}
接下来,我们需要将整数轴转换为字符串轴。为此,我们可以使用pyplot中的xticks和yticks函数。
fig, ax = plt.subplots()
ax.imshow(cm, cmap='Blues')
# ax.grid(False) # 可以用此语句取消网格
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)
plt.show()
在这里,我们输入的是混淆矩阵矩阵cm和类别名称列表classes。我们首先获取子图和子图轴对象,并使用imshow函数将cm绘制为矩阵。我们随后将刻度位置设置为类别名称,并设置其显示标签。
两种方法的对比
使用pyplot绘图,在xticks和yticks中仍然需要将类别名称映射到整数,因此,对于一些需要在整数轴和字符串轴之间切换的场景,使用matplotlib的面向对象API会更加方便。从图像的效果上来看,使用面向对象API有更多的绘图自由度,包括更好的字体选项,更少的重叠,更好的轴标签位置等。
结论
本文介绍了如何在Python中使用字符串轴绘制混淆矩阵。我们首先介绍了混淆矩阵的概念,然后提供了两种方法:使用pyplot绘图和使用面向对象绘图。并且给出了在从整数轴到字符串轴的转换时所需注意的事项。