Matplotlib 如何绘制Scikit-learn分类报告
在机器学习领域,使用Scikit-learn进行分类任务是常见的选择。当我们完成了一个分类任务后,我们需要评估模型的性能。这个过程可以使用Scikit-learn的分类报告来完成。分类报告是一个性能指标汇总的表格,提供精确度、召回率和F1-score等指标,非常适合评估分类模型的性能。在这篇文章中,我们将学习如何使用Matplotlib将Scikit-learn的分类报告可视化,以便更直观地分析和理解性能指标。
阅读更多:Matplotlib 教程
Scikit-learn分类报告简介
Scikit-learn分类报告用于评估分类模型的性能。它提供了多个指标来评估模型的分类结果。这些指标包括准确率(precission)、召回率(recall)、f1分数,以及每个类别的支持数(support)。
分类报告是一个基于表格的布局,它总结了每个类别的分类指标。其中包括了以下信息:
- 准确率(precision): 用于预测为正的样本中的真正正样本的比例
- 召回率(recall): 正确预测为正的样本数占所有真正正样本的比例
- f1分数(f1-score): 准确率和召回率的加权平均数,其值介于0和1之间
- 支持(support): 给定类别的真实样本数
在Scikit-learn中,我们可以使用classification_report
函数生成分类报告。以下是一个例子:
from sklearn.metrics import classification_report
y_pred = [0, 1, 2, 3]
y_true = [0, 1, 2, 2]
target_names = ['class 0', 'class 1', 'class 2', 'class 3']
print(classification_report(y_true, y_pred, target_names=target_names))
输出结果为:
precision recall f1-score support
class 0 1.00 1.00 1.00 1
class 1 1.00 1.00 1.00 1
class 2 0.67 0.67 0.67 3
class 3 0.00 0.00 0.00 0
accuracy 0.75 5
macro avg 0.67 0.67 0.67 5
weighted avg 0.80 0.75 0.77 5
Matplotlib绘制分类报告
Scikit-learn生成的分类报告功能很强大,但它的输出结果是以纯文本形式呈现的,这可能让我们的分析和理解变得棘手。我们可以使用Matplotlib来绘制Scikit-learn生成的分类报告以可视化展示。这可以让我们更好地理解不同类别之间的性能差异。
以下是Matplotlib的一些绘图功能:
- 使用直方图(histogram)来可视化数据的分布
- 使用饼图(pie chart)来可视化给定类别的分类结果
- 使用热力图(heatmap)来绘制全局性能指标
我们将逐步介绍如何使用Matplotlib来绘制Scikit-learn分类报告。
直方图
我们可以使用Matplotlib的hist
函数来绘制直方图。这个函数可以绘制一组数据的分布情况。我们可以将每个类别的分类指标作为直方图的输入数据,并通过改变颜色和位置来区分不同的类别。
import matplotlib.pyplot as plt
# 生成分类报告
report = classification_report(y_true,y_pred, target_names=target_names, output_dict=True)
# 将分类指标转换为数据框
import pandas as pd
data = pd.DataFrame(report).transpose()
# 绘制直方图
fig, ax = plt.subplots(figsize=(10,6))
for index, row in data.iloc[:-3].iterrows():
ax.bar(index + 0.2, row['precision'], color='r', width=0.2, label='Precision')
ax.bar(index + 0.4, row['recall'], color='g', width=0.2, label='Recall')
ax.bar(index + 0.6, row['f1-score'], color='b', width=0.2, label='F1-score')
ax.set_xticks(range(len(data.iloc[:-3])))
ax.set_xticklabels(data.iloc[:-3].index)
ax.legend(loc='upper right')
plt.show()
这将绘制一个每个类别的精确度、召回率和f1分数的直方图。
饼图
我们可以使用Matplotlib的pie
函数来绘制饼图。这个函数可以用于可视化给定类别的分类结果。我们可以将每个类别的support
作为饼图的输入数据,并使用饼图的颜色和标签来区分不同的类别。
# 绘制饼图
fig, ax = plt.subplots(figsize=(8,8))
colors = ['r', 'g', 'b', 'c']
for i, (index, row) in enumerate(data.iloc[:-3].iterrows()):
sizes = row['support']
ax.pie(sizes, labels=target_names[i:i+1],
colors=colors[i%len(colors)],
startangle=90,
counterclock=False,
autopct='%1.1f%%')
# 添加图例
ax.legend(target_names, loc='upper right')
plt.show()
这将绘制每个类别的分类结果的饼图。
热力图
最后,我们可以使用Matplotlib的heatmap
函数来绘制热力图。这个函数可以用于绘制全局性能指标。我们可以将每个类别的生成的性能指标作为矩阵输入数据,并使用热力图的颜色来显示模型的性能。
# 绘制热力图
fig, ax = plt.subplots(figsize=(8,6))
im = ax.imshow(data.iloc[:-3, :-1], cmap='PuBu')
# 添加坐标轴标签
ax.set_xticks(range(len(data.iloc[:-3, :-1].columns)))
ax.set_xticklabels(data.iloc[:-3, :-1].columns)
ax.set_yticks(range(len(data.iloc[:-3, :-1].index)))
ax.set_yticklabels(data.iloc[:-3, :-1].index)
# 添加注释
for i in range(len(data.iloc[:-3, :-1].index)):
for j in range(len(data.iloc[:-3, :-1].columns)):
text = ax.text(j, i, data.iloc[i, j],
ha='center', va='center', color='black')
# 添加图例
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Performance', rotation=-90, va="bottom")
plt.show()
这将绘制一个热力图,其中列是性能指标,行是每个类别。
总结
通过本文,我们学习了如何使用Matplotlib来绘制Scikit-learn的分类报告,以更好地理解分类模型的性能。我们使用了直方图、饼图和热力图等不同的图表类型来可视化分类指标。这些图表能够帮助我们更直观地分析和理解分类模型的性能,以便更好地调整模型和改进性能。