Matplotlib 如何绘制Scikit-learn分类报告

Matplotlib 如何绘制Scikit-learn分类报告

在机器学习领域,使用Scikit-learn进行分类任务是常见的选择。当我们完成了一个分类任务后,我们需要评估模型的性能。这个过程可以使用Scikit-learn的分类报告来完成。分类报告是一个性能指标汇总的表格,提供精确度、召回率和F1-score等指标,非常适合评估分类模型的性能。在这篇文章中,我们将学习如何使用Matplotlib将Scikit-learn的分类报告可视化,以便更直观地分析和理解性能指标。

阅读更多:Matplotlib 教程

Scikit-learn分类报告简介

Scikit-learn分类报告用于评估分类模型的性能。它提供了多个指标来评估模型的分类结果。这些指标包括准确率(precission)、召回率(recall)、f1分数,以及每个类别的支持数(support)。

分类报告是一个基于表格的布局,它总结了每个类别的分类指标。其中包括了以下信息:

  • 准确率(precision): 用于预测为正的样本中的真正正样本的比例
  • 召回率(recall): 正确预测为正的样本数占所有真正正样本的比例
  • f1分数(f1-score): 准确率和召回率的加权平均数,其值介于0和1之间
  • 支持(support): 给定类别的真实样本数

在Scikit-learn中,我们可以使用classification_report函数生成分类报告。以下是一个例子:

from sklearn.metrics import classification_report
y_pred = [0, 1, 2, 3]
y_true = [0, 1, 2, 2]
target_names = ['class 0', 'class 1', 'class 2', 'class 3']
print(classification_report(y_true, y_pred, target_names=target_names))

输出结果为:

              precision    recall  f1-score   support

     class 0       1.00      1.00      1.00         1
     class 1       1.00      1.00      1.00         1
     class 2       0.67      0.67      0.67         3
     class 3       0.00      0.00      0.00         0

    accuracy                           0.75         5
   macro avg       0.67      0.67      0.67         5
weighted avg       0.80      0.75      0.77         5

Matplotlib绘制分类报告

Scikit-learn生成的分类报告功能很强大,但它的输出结果是以纯文本形式呈现的,这可能让我们的分析和理解变得棘手。我们可以使用Matplotlib来绘制Scikit-learn生成的分类报告以可视化展示。这可以让我们更好地理解不同类别之间的性能差异。

以下是Matplotlib的一些绘图功能:

  • 使用直方图(histogram)来可视化数据的分布
  • 使用饼图(pie chart)来可视化给定类别的分类结果
  • 使用热力图(heatmap)来绘制全局性能指标

我们将逐步介绍如何使用Matplotlib来绘制Scikit-learn分类报告。

直方图

我们可以使用Matplotlib的hist函数来绘制直方图。这个函数可以绘制一组数据的分布情况。我们可以将每个类别的分类指标作为直方图的输入数据,并通过改变颜色和位置来区分不同的类别。

import matplotlib.pyplot as plt

# 生成分类报告
report = classification_report(y_true,y_pred, target_names=target_names, output_dict=True)

# 将分类指标转换为数据框
import pandas as pd
data = pd.DataFrame(report).transpose()

# 绘制直方图
fig, ax = plt.subplots(figsize=(10,6))

for index, row in data.iloc[:-3].iterrows():
    ax.bar(index + 0.2, row['precision'], color='r', width=0.2, label='Precision')
    ax.bar(index + 0.4, row['recall'], color='g', width=0.2, label='Recall')
    ax.bar(index + 0.6, row['f1-score'], color='b', width=0.2, label='F1-score')

ax.set_xticks(range(len(data.iloc[:-3])))
ax.set_xticklabels(data.iloc[:-3].index)
ax.legend(loc='upper right')
plt.show()

这将绘制一个每个类别的精确度、召回率和f1分数的直方图。

饼图

我们可以使用Matplotlib的pie函数来绘制饼图。这个函数可以用于可视化给定类别的分类结果。我们可以将每个类别的support作为饼图的输入数据,并使用饼图的颜色和标签来区分不同的类别。

# 绘制饼图
fig, ax = plt.subplots(figsize=(8,8))
colors = ['r', 'g', 'b', 'c']

for i, (index, row) in enumerate(data.iloc[:-3].iterrows()):
    sizes = row['support']
    ax.pie(sizes, labels=target_names[i:i+1], 
            colors=colors[i%len(colors)], 
            startangle=90, 
            counterclock=False, 
            autopct='%1.1f%%')

# 添加图例
ax.legend(target_names, loc='upper right')
plt.show()

这将绘制每个类别的分类结果的饼图。

热力图

最后,我们可以使用Matplotlib的heatmap函数来绘制热力图。这个函数可以用于绘制全局性能指标。我们可以将每个类别的生成的性能指标作为矩阵输入数据,并使用热力图的颜色来显示模型的性能。

# 绘制热力图
fig, ax = plt.subplots(figsize=(8,6))
im = ax.imshow(data.iloc[:-3, :-1], cmap='PuBu')

# 添加坐标轴标签
ax.set_xticks(range(len(data.iloc[:-3, :-1].columns)))
ax.set_xticklabels(data.iloc[:-3, :-1].columns)
ax.set_yticks(range(len(data.iloc[:-3, :-1].index)))
ax.set_yticklabels(data.iloc[:-3, :-1].index)

# 添加注释
for i in range(len(data.iloc[:-3, :-1].index)):
    for j in range(len(data.iloc[:-3, :-1].columns)):
        text = ax.text(j, i, data.iloc[i, j],
                       ha='center', va='center', color='black')

# 添加图例
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Performance', rotation=-90, va="bottom")
plt.show()

这将绘制一个热力图,其中列是性能指标,行是每个类别。

总结

通过本文,我们学习了如何使用Matplotlib来绘制Scikit-learn的分类报告,以更好地理解分类模型的性能。我们使用了直方图、饼图和热力图等不同的图表类型来可视化分类指标。这些图表能够帮助我们更直观地分析和理解分类模型的性能,以便更好地调整模型和改进性能。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程