Complete model verification (test, demo) routine
2022-07-08 01:01:00 【booze-J】
For how the network model was trained and saved, see the article "Training a network model with the GPU". The model loaded below was trained with the code from that article.
Sample code for validating the model:
```python
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear

# Path of the image to classify
image_path = "./images/img.png"
# Read the image
image = Image.open(image_path)
# The object read is e.g. <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=296x183 at 0x1FBD755E340>
print("image:\n", image)
# PNG images have four channels: RGB plus a transparency (alpha) channel.
# convert("RGB") keeps only the three color channels; an image that is already
# RGB is left unchanged, so this step works for PNG, JPG and other formats.
image = image.convert("RGB")

# Resize: https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html?highlight=resize#torchvision.transforms.Resize
# Why resize? The network model expects 32x32 input images.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
# Convert the PIL image into a tensor of shape (C, H, W)
image = transform(image)
print("image:\n", image.shape)  # torch.Size([3, 32, 32])

# Network definition. Because the whole model object was saved with torch.save(model),
# this class must be defined (or imported from a separate file) before torch.load runs.
class Booze(nn.Module):
    def __init__(self):
        super(Booze, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

# Device used for testing. The input must be on the same device as the model's
# parameters: a model trained on the GPU is still on the GPU after loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the network model. map_location lets a GPU-trained model load on a CPU-only machine.
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model object
# (the argument exists from PyTorch 1.13 on; remove it for older versions).
model = torch.load("./model/obj_0.pth", map_location=device, weights_only=False)
model = model.to(device)
# Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)
image = torch.reshape(image, (1, 3, 32, 32))
# Inspect the model
print(model)

model.eval()  # switch to evaluation mode; easy to forget when testing
with torch.no_grad():  # disable gradient tracking; also easy to forget
    # The input tensor must be on the same device as the model
    output = model(image.to(device))
# Example output:
# tensor([[-1.1961,  0.1016,  0.6076,  0.5585,  0.4856,  0.4466,  0.4176,  0.3158,
#          -1.4603, -0.5929]], device='cuda:0')
# The 10 values are raw scores (logits), one per class, not probabilities;
# torch.softmax(output, dim=1) converts them into probabilities.
print(output)
# Index of the predicted class (the highest score). Here it differs from the actual class
# because the model was trained for only a few epochs; training longer and tuning the
# learning rate should give a more accurate model.
print(output.argmax(1).item())
```
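The ten numbers printed for output are raw scores (logits), which is why some of them are negative. A minimal sketch in plain Python, so it runs without the trained model file, of how softmax turns the example scores into probabilities and how argmax picks the predicted class. The helper names softmax and predict are hypothetical; in PyTorch itself, torch.softmax(output, dim=1) and output.argmax(1) do the same job.

```python
import math

def softmax(scores):
    """Convert a list of raw scores (logits) into probabilities that sum to 1."""
    # Subtract the maximum before exponentiating for numerical stability;
    # this does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict(scores):
    """Return the index of the highest score, like output.argmax(1).item()."""
    return max(range(len(scores)), key=lambda i: scores[i])

# The 10 logits from the example output (one per class)
logits = [-1.1961, 0.1016, 0.6076, 0.5585, 0.4856,
          0.4466, 0.4176, 0.3158, -1.4603, -0.5929]

probs = softmax(logits)
print(round(sum(probs), 6))  # 1.0
print(predict(logits))       # 2
print(predict(probs))        # 2, because softmax preserves the ranking
```

Since softmax is monotonic, taking argmax of the logits or of the probabilities gives the same class; softmax is only needed when the confidence values themselves are of interest.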
The points that need attention are explained in detail in the code comments.
Notes
- PNG images have four channels: in addition to the three RGB channels there is a transparency (alpha) channel. Calling image = image.convert("RGB") keeps only the color channels; if the image already has three color channels, it is left unchanged.
- Before an image is passed to the network for prediction, it must be preprocessed: check that its size, format, and number of dimensions match what the model expects.
- If the model was trained with CUDA, the test data must also be CUDA tensors. In general, the input has to be on the same device as the model's parameters.
- The class predicted for the test image differs from its actual class. This is because the network was trained for only a few epochs or with an unsuitable learning rate; training for more epochs and tuning the learning rate should give a more accurate model.