Complete model verification (test, demo) routine
2022-07-08 01:01:00 【booze-J】
For how the network model is trained and saved, see the earlier post on training network models with a GPU; the model loaded below was trained with the code from that post.
Sample code for validating the model:
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
# Path of the image to classify
image_path = "./images/img.png"
# Read the image
image = Image.open(image_path)
# The object read here is <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=296x183 at 0x1FBD755E340>
print("image:\n",image)
# PNG images can have four channels: the three RGB channels plus a transparency (alpha) channel.
# convert("RGB") keeps only the color channels. An image that already has three color channels
# is left unchanged, so this step lets the script handle PNG, JPG, and other formats alike.
image = image.convert("RGB")
# [Resize](https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html?highlight=resize#torchvision.transforms.Resize)
# Why resize? Because the network model expects 32x32 input images
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
# Resize the image and convert it to a tensor
image = transform(image)
print("image:\n",image.shape) # torch.Size([3, 32, 32])
# Define the neural network (normally kept in a separate file).
# The class definition must be available here so that torch.load can rebuild the saved model.
class Booze(nn.Module):
    def __init__(self):
        super(Booze, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
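# Define the device used for testing. A model trained on the GPU and one trained on the CPU
# hold different tensor types, so the input must live on the same device as the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the network model. map_location places the saved weights on the chosen device,
# so a model saved after GPU training can also be loaded on a CPU-only machine.
# On recent PyTorch releases, where torch.load defaults to weights_only=True, loading a whole
# pickled model like this one also needs weights_only=False (only for files you trust).
model = torch.load("./model/obj_0.pth", map_location=device)
# Add a batch dimension: (3,32,32) -> (1,3,32,32)
image = torch.reshape(image,(1,3,32,32))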
# Inspect the model
print(model)
model.eval() # Switch to evaluation mode; this step is easy to forget when testing
with torch.no_grad(): # Disable gradient tracking; this step is also easy to forget
    # If the model was trained with CUDA, the input must be a CUDA tensor as well
    output = model(image.to(device))
# Example output:
# tensor([[-1.1961, 0.1016, 0.6076, 0.5585, 0.4856, 0.4466, 0.4176, 0.3158, -1.4603, -0.5929]], device='cuda:0')
# output contains 10 values, one raw score per class (not yet a probability).
# The class with the highest score is the prediction.
print(output)
# Print the class predicted for the test image. Here it differs from the actual class because
# the model was trained for too few epochs. Training for more epochs and tuning the learning
# rate gives a more accurate model.
print(output.argmax(1).item())
The key points are explained in detail in the code comments.
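The ten values printed by the script are raw class scores, so they can be negative and do not sum to 1. As a minimal sketch, assuming the model was trained with a cross-entropy loss (the usual choice for this kind of classifier, but not stated in this post), a softmax turns them into probabilities. The example uses only the standard library and the scores copied from the sample output; the `softmax` helper is written here for illustration.

```python
import math

# Raw scores copied from the sample output of the script above.
logits = [-1.1961, 0.1016, 0.6076, 0.5585, 0.4856,
          0.4466, 0.4176, 0.3158, -1.4603, -0.5929]


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(scores)  # subtracting the maximum keeps exp() numerically stable
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
# Softmax preserves the ordering, so the most probable class is the one with the highest score.
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted)  # 2
print([round(p, 3) for p in probs])
```

In PyTorch the same conversion is `torch.softmax(output, dim=1)`. Because softmax does not change the ordering, `output.argmax(1)` gives the same class with or without it.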
Notes
- PNG images can have four channels: besides the three RGB channels there is also a transparency channel. Calling image = image.convert("RGB") keeps only the color channels; if the image already has three color channels, it is left unchanged.
- Before an image is passed to the network model for prediction, it must be preprocessed: check that the image size, format, and dimensions meet the model's input requirements.
- If the model was trained with CUDA, the test data must also be CUDA tensors.
- The class predicted for the test image differs from the actual class because the model was trained for too few epochs or with an unsuitable learning rate. Training for more epochs and tuning the learning rate gives a more accurate model.
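The 32x32 input requirement mentioned in the notes follows from the `Linear(1024, 64)` layer in the network. The sketch below works through that arithmetic with the standard output-size formulas for convolution and pooling; the helper names `conv_out`, `pool_out`, and `flattened_features` are made up for this illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output side length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output side length of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


def flattened_features(side):
    """Number of features reaching Flatten() for a square input of the given side."""
    channels = 3
    for out_channels in (32, 32, 64):  # the three Conv2d layers of the network
        side = conv_out(side, kernel=5, padding=2)  # padding=2 keeps the size
        side = pool_out(side, kernel=2)             # MaxPool2d(2) halves it
        channels = out_channels
    return channels * side * side


print(flattened_features(32))  # 1024, matches Linear(1024, 64)
print(flattened_features(64))  # 4096, would not fit Linear(1024, 64)
```

A 32x32 image shrinks to 4x4 after three poolings, and 64 channels of 4x4 give exactly 1024 features. Any other input size produces a different count, which is why the script resizes the image before the forward pass.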