当前位置:网站首页>完整的模型验证(测试,demo)套路
完整的模型验证(测试,demo)套路
2022-07-07 23:11:00 【booze-J】
文章
网络模型训练与保存参考利用GPU训练网络模型,并且加载的网络模型也是该篇文章中的代码训练出来的。
验证模型示例代码:
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
# 获取图片存放路径
image_path = "./images/img.png"
# 读取图片
image = Image.open(image_path)
# 读取的图片类型为 <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=296x183 at 0x1FBD755E340>
print("image:\n",image)
image = image.convert("RGB")
'''接下来要对image进行通道转换,因为png格式是四通道的,除RGB三通道外,还有一个透明度通道。所以我们调用image = image.convert("RGB"),保留其颜色通道 当然,如果图片本来就是三个颜色通道,经过此操作,不变。加上这一步之后可以适应png,jpg各种格式的图片、 '''
# [Resize](https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html?highlight=resize#torchvision.transforms.Resize)
# 为什么要进行Resize这一步呢?是因为我们这个网络模型的要求输入是32*32大小的图片
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
# 将image转化为合适的类型
image = transform(image)
print("image:\n",image.shape) # torch.Size([3, 32, 32])
# 搭建神经网络(单独开一个文件存放网络模型)
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
self.model = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model(x)
return x
# 加载网络模型
model = torch.load("./model/obj_0.pth")
# 将图片转化为四维的(3,32,32) -> (1,3,32,32)
image = torch.reshape(image,(1,3,32,32))
# 定义测试所用的设备 模型的训练方式的不同(GPU训练、CPU训练) 测试时的数据类型也不一样
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 查看模型
print(model)
model.eval() # 测试的时候,这一步大家也许经常遗忘
with torch.no_grad(): # # 这一步大家也许经常遗忘
# 模型如果是使用cuda训练的,则测试的时候也需要使用cuda类型的数据进行测试
output = model(image.to(device))
''' tensor([[-1.1961, 0.1016, 0.6076, 0.5585, 0.4856, 0.4466, 0.4176, 0.3158, -1.4603, -0.5929]], device='cuda:0') 可以看到output含10个数据,每个数据代表着测试图片属于该类的一个概率 '''
print(output)
# 打印出测试图片使用网络模型预测的类别 发现和实际类别不同哈 原因是因为该网络模型的训练次数较少 增加训练批次和调整学习率可以得到更精准的网络模型
print(output.argmax(1).item())
一些需要注意的点在代码中的注释有详细的描述。
注意
- png格式是四通道的,除RGB三通道外,还有一个透明度通道。所以我们调用
image = image.convert("RGB")
,保留其颜色通道当然,如果图片本来就是三个颜色通道,经过此操作,不变。 - 在将待预测图片传入网络模型中预测之前,需要先进行预处理,图片大小是否符合输入要求,图片格式和维度是否符合要求等等。
- 模型如果是使用cuda训练的,则测试的时候也需要使用cuda类型的数据进行测试
- 打印出测试图片使用网络模型预测的类别,发现和实际类别不同,原因是因为该网络模型的训练次数较少或者学习率不合适,增加训练批次和调整学习率可以得到更精准的网络模型。
边栏推荐
- SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
- Lecture 1: the entry node of the link in the linked list
- STL -- common function replication of string class
- Introduction to ML regression analysis of AI zhetianchuan
- Qt不同类之间建立信号槽,并传递参数
- ReentrantLock 公平锁源码 第0篇
- The method of server defense against DDoS, Hangzhou advanced anti DDoS IP section 103.219.39 x
- Cve-2022-28346: Django SQL injection vulnerability
- 华为交换机S5735S-L24T4S-QA2无法telnet远程访问
- 接口测试进阶接口脚本使用—apipost(预/后执行脚本)
猜你喜欢
51 communicates with the Bluetooth module, and 51 drives the Bluetooth app to light up
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出
How to learn a new technology (programming language)
v-for遍历元素样式失效
5g NR system messages
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
8.优化器
13.模型的保存和載入
华为交换机S5735S-L24T4S-QA2无法telnet远程访问
随机推荐
Introduction to paddle - using lenet to realize image classification method I in MNIST
ReentrantLock 公平锁源码 第0篇
Su embedded training - Day3
13.模型的保存和载入
Introduction to paddle - using lenet to realize image classification method II in MNIST
基于卷积神经网络的恶意软件检测方法
Jemter distributed
How to learn a new technology (programming language)
v-for遍历元素样式失效
Fofa attack and defense challenge record
An error is reported during the process of setting up ADG. Rman-03009 ora-03113
What does interface testing test?
语义分割模型库segmentation_models_pytorch的详细使用介绍
Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
Qt不同类之间建立信号槽,并传递参数
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
The standby database has been delayed. Check that the MRP is wait_ for_ Log, apply after restarting MRP_ Log but wait again later_ for_ log
Interface test advanced interface script use - apipost (pre / post execution script)
C # generics and performance comparison
"An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points