当前位置:网站首页>CNN中的混淆矩阵 | PyTorch系列(二十三)
CNN中的混淆矩阵 | PyTorch系列(二十三)
2022-07-28 01:44:00 【51CTO】
文 |AI_study

原标题:CNN Confusion Matrix With PyTorch - Neural Network Programming
在这节课中,我们将建立一些函数,让我们能够得到训练集中每个样本的预测张量。然后,我们会看到如何使用这个预测张量,以及每个样本的标签,来创建一个混淆矩阵。这个混淆矩阵将允许我们查看我们的网络中哪些类别相互混淆。
- 准备数据
- 建立模型
- 训练模型
- 分析模型的结果
- 构建、绘制和解释一个混淆矩阵
有关所有代码设置细节,请参阅本课程的前一节。
混淆矩阵要求
要为整个数据集创建一个混淆矩阵,我们需要一个与训练集长度相同的一维预测张量。
这个预测张量将包含我们训练集中每个样本的10个预测(每个服装类别一个)。在我们得到这个张量之后,我们可以使用标签张量来生成一个混淆矩阵。
一个混淆矩阵将告诉我们模型在哪里被混淆了。更具体地说,混淆矩阵将显示模型正确预测的类别和模型不正确预测的类别。对于不正确的预测,我们将能够看到模型预测的类别,这将告诉我们哪些类别使模型混乱。
获取整个训练集的预测
为了得到所有训练集样本的预测,我们需要通过网络传递所有样本。为此,可以创建一个batch_size=1的DataLoader。这将一次性向网络传递一批数据,并为所有训练集样本提供所需的预测张量。
然而,根据计算资源和训练集的大小,如果我们在不同的数据集上训练,我们需要一种方法来预测更小的批量并收集结果。为了收集结果,我们将使用torch.cat()函数将输出张量连接在一起,以获得单个预测张量。我们来建立一个函数。
建立一个函数来获得所有样本的预测
我们将创建一个名为get_all_preds()的函数,并传递一个模型和一个数据加载器。该模型将用于获取预测,而数据加载器将用于提供来自训练集的批次。
所有函数需要做的就是遍历数据加载器,将批处理传递给模型,并将每个批处理的结果连接到一个预测张量,该张量将返回给调用者。
此函数的植入会创建一个空张量,all_preds来保存输出预测。然后,迭代来自数据加载器的批处理,并将输出预测与all_preds张量连接在一起。最后,所有预测all_preds将返回给调用方。
请注意,在顶部,我们已使用@ torch.no_grad() PyTorch装饰对函数进行了注释。这是因为我们希望该函数执行忽略梯度跟踪。
这是因为梯度跟踪占用内存,并且在推理(在不训练的情况下获得预测)期间,无需跟踪计算图。装饰器是在执行特定功能时局部关闭梯度跟踪功能的一种方法。
本地禁用PyTorch梯度跟踪
我们现在准备调用以获取训练集的预测。我们需要做的就是创建一个具有合理批处理大小的数据加载器,并将模型和数据加载器传递给get_all_preds() 函数。
在上一节中,我们了解了在不需要时如何使用PyTorch的梯度跟踪功能,并在开始训练过程时将其重新打开。
每当我们要使用Backward()函数计算梯度时,我们特别需要梯度计算功能。否则,将其关闭是一个好主意,因为将其关闭会减少计算的内存消耗,例如 当我们使用网络进行预测(推理)时。
这两个选项均有效。让我们保留所有这些并获得我们的预测。
使用预测张量
现在,有了预测张量,我们可以将其传递给我们在上一节中创建的get_num_correct()函数以及训练集标签,以获取正确预测的总数。
我们可以看到正确预测的总数,并通过除以训练集中的样本数来打印准确性。
建立混淆矩阵
我们构建混淆矩阵的任务是将预测值的数量与真实值(目标)进行比较。
这将创建一个充当热图的矩阵,告诉我们预测值相对于真实值的下降位置。
为此,我们需要具有目标张量和train_preds张量中的预测标签。
现在,如果我们逐元素比较两个张量,我们可以看到预测的标签是否与目标匹配。此外,如果我们要计算预测标签与目标标签的数量,则两个张量内的值将作为矩阵的坐标。让我们沿着第二维堆叠这两个张量,以便我们可以有60,000个有序对。
现在,我们可以遍历这些对,并计算矩阵中每个位置的出现次数。让我们创建矩阵。由于我们有十个预测类别,因此将有一个十乘十的矩阵。检查此处以了解stack()函数。
https://deeplizard.com/learn/video/kF2AlpykJGY
现在,我们将遍历预测目标对,并在每次发生特定位置时向矩阵内的值添加一个。
这为我们提供了以下混淆矩阵张量。
请注意,下面的示例将具有不同的值,因为这两个示例是在不同的时间创建的。
绘制混淆矩阵
为了将实际的混淆矩阵生成为numpy.ndarray,我们使用sklearn.metrics库中的confusion_matrix()函数。让我们将其与其他需要的导入一起导入。
对于最后一次导入,请注意plotcm是一个文件plotcm.py,位于当前目录中的资源文件夹中。在plotcm.py文件中,有一个称为plot_confusion_matrix()的函数,我们将调用该函数。您将需要在系统上实现此功能。我们将在稍后讨论如何执行此操作。首先,让我们生成混淆矩阵。
我们可以像这样生成混淆矩阵:
PyTorch张量是类似于数组的Python对象,因此我们可以将它们直接传递给confusion_matrix()函数。我们相对于train_preds张量的第一维传递训练集标签张量(targets)和argmax,这为我们提供了混淆矩阵数据结构。
要实际绘制混淆矩阵,我们需要一些自定义代码,这些代码已放入名为plotcm的本地文件中。该函数称为plot_confusion_matrix()。plotcm.py文件需要包含以下内容,并且位于当前目录的resources文件夹中。
请注意,您也可以只将此代码复制到笔记本中,或避免导入的任何内容。
plotcm.py:
来源-scikit-learn.org
对于导入,我们这样做:
我们已经准备好绘制混淆矩阵,但是首先我们需要创建一个预测类名称列表,以传递给plot_confusion_matrix()函数。下表给出了我们的预测类及其相应的索引:

这使我们可以调用以绘制矩阵:

解释混淆矩阵
混淆矩阵具有三个轴:
- 预测标签(类)
- 真实标签
- 热图值(彩色)
预测标签和真实标签向我们显示了我们正在处理的预测类。矩阵对角线表示矩阵中预测和真值相同的位置,因此我们希望此处的热图更暗。
任何不在对角线上的值都是不正确的预测,因为预测和真实标签不匹配。要读取该图,我们可以使用以下步骤:
- 在水平轴上选择一个预测标签。
- 检查此标签的对角线位置以查看正确的总数。
- 检查其他非对角线位置以查看网络混乱之处。
例如,网络正在将T恤/上衣与衬衫混淆,但并未将T恤/上衣与以下物质混淆:
- Ankle boot
- Sneaker
- Sandal
如果我们考虑一下,这很有意义。随着我们模型的学习,我们将看到对角线之外的数字越来越小。
在本系列的这一点上,我们已经完成了许多在PyTorch中构建和训练CNN的工作。恭喜!
文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈,支持一下我 ^_^
英文原文链接是:
https://deeplizard.com/learn/video/0LhiS6yu2qQ>




边栏推荐
- How do you use the jar package sent by others (how to use the jar package sent by others)
- 【TA-霜狼_may-《百人计划》】图形3.5 Early-z 和 Z-prepass
- [self growth website collection]
- How is insert locked in MySQL? (glory Collection Edition)
- POC simulation attack weapon - Introduction to nucleus (I)
- Pytorch optimizer settings
- [solution] solve the problem of SSH connection being inactive for a long time and being stuck and disconnected
- JVM tuning -xms -xmx -xmn -xss
- [data processing] boxplot drawing
- Lombok prompts variable log error when using JUnit test in idea
猜你喜欢

Which users are suitable for applying for rapidssl certificate

Deep Residual Learning for Image Recognition浅读与实现

数据中台夯实数据基础

【图像隐藏】基于DCT、DWT、LHA、LSB的数字图像信息隐藏系统含各类攻击和性能参数附matlab代码

Some shortest path problems solved by hierarchical graph

程序里随处可见的interface,真的有用吗?真的用对了吗?

Canvas from getting started to persuading friends to give up (graphic version)

第三章 队列

怎么简单实现菜单拖拽排序的功能

【软件测试】—— 自动化测试之unittest框架
随机推荐
Interviewer: what is the factory method mode?
Center-based 3D Object Detection and Tracking(基于中心的3D目标检测和跟踪 / CenterPoint)论文笔记
[hcip] BGP Foundation
Flutter神操作学习之(满级攻略)
Share an esp32 relay
[hcip] routing strategy, strategic routing
How do you use the jar package sent by others (how to use the jar package sent by others)
Explore flex basis
第二季度邮件安全报告:邮件攻击暴增4倍,利用知名品牌获取信任
【ELM分类】基于核极限学习机和极限学习机实现UCI数据集分类附matlab代码
[TA frost wolf \u may - hundred people plan] Figure 3.7 TP (d) r architecture of mobile terminal
PS simple to use
Is it safe to buy funds on Alipay? I want to make a fixed investment in the fund
软件产品第三方测试费用为什么没有统一的报价?
JVM tuning -xms -xmx -xmn -xss
selenium+pytest+allure综合练习
0动态规划中等 LeetCode873. 最长的斐波那契子序列的长度
【软件测试】—— 自动化测试之unittest框架
MySQL blocking monitoring script
【 图像去雾】基于暗通道和非均值滤波实现图像去雾附matlab代码