Changing a simple fully connected (MLP) neural network model into a convolutional neural network (CNN) model
2022-07-28 22:52:00 【Meteor shower ADI】
In my last blog post I showed how to build a neural network model based on an MLP. This time I will show how to change that MLP-based model into one based on a simple CNN.
Link to the previous post:
Step one: Import the packages. This is essentially the same as for the previous model.

```python
# Import packages
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
```

Step two: Define the hyperparameters. Compared with the MLP version, three hyperparameter settings have been removed.
```python
# Define hyperparameters
num_epochs = 3
batch_size = 256
learning_rate = 0.001
```

Step three: Load the dataset. Replace it with whatever dataset you need; MNIST is used here.
```python
# Load the MNIST data (set download=True on the first run if the files are not yet in dataset/)
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=False)
# Lower-case loader names to match the training and test loops
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# The whole test set is evaluated as a single batch
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
```

Step four: Set up the training configuration.
```python
# Training configuration: use the GPU if one is available
cuda = torch.cuda.is_available()
```

Step five: Build the CNN.
```python
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)    # 1 input channel (grayscale), 16 filters, 5x5 kernel
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)        # 10 digit classes

    def forward(self, x):
        x = F.relu(self.conv1(x))   # input (1, 28, 28) -> output (16, 24, 24)
        x = self.pool1(x)           # output (16, 12, 12)
        x = F.relu(self.conv2(x))   # output (32, 8, 8)
        x = self.pool2(x)           # output (32, 4, 4)
        x = x.view(-1, 32 * 4 * 4)  # flatten: output (32*4*4) = (512)
        x = F.relu(self.fc1(x))     # output (120)
        x = F.relu(self.fc2(x))     # output (84)
        x = self.fc3(x)             # output (10), raw class scores (logits)
        return x
```
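The flattened size `32*4*4` used by `fc1` follows from the usual output-size rule for a convolution or pooling layer without padding, `out = (in - kernel) // stride + 1`, applied to a 28x28 MNIST image. A minimal sketch that traces the spatial size through each layer and cross-checks it with PyTorch on a dummy batch (the helper `out_size` is a name introduced here for illustration, not a PyTorch function):

```python
import torch
import torch.nn as nn


def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


# Trace a 28x28 MNIST image through the layers of `net`
s = 28
s = out_size(s, kernel=5)            # conv1: 28 -> 24
s = out_size(s, kernel=2, stride=2)  # pool1: 24 -> 12
s = out_size(s, kernel=5)            # conv2: 12 -> 8
s = out_size(s, kernel=2, stride=2)  # pool2: 8 -> 4
print(s, 32 * s * s)                 # 4 512

# Cross-check: the same conv/pool stack applied to a dummy batch of 2 images
features = nn.Sequential(
    nn.Conv2d(1, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
)
x = torch.zeros(2, 1, 28, 28)
print(features(x).shape)             # torch.Size([2, 32, 4, 4])
```

For 32x32 inputs (for example CIFAR-10) the same trace ends at 5x5, which is why `32*5*5` appears in LeNet-style examples written for 32x32 images. With 28x28 MNIST images the flattened size must be `32*4*4`.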
Step six: Initialize the model, then define the loss function and optimizer.

```python
# Initialize the model (moved to the GPU when one is available)
model = net()
if cuda:
    model = model.cuda()
# Define the loss function and optimizer
lossF = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
```
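`nn.CrossEntropyLoss` takes raw, unnormalized scores (logits) of shape `(N, 10)` and integer class labels of shape `(N,)`, and applies log-softmax internally; this is why `fc3` has no activation. A small check on made-up logits (the values are random and serve only as an illustration) that the loss equals log-softmax followed by the mean negative log-likelihood of the true class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # a fake batch: 4 samples, 10 class scores each
labels = torch.tensor([3, 0, 9, 1])   # integer class indices, not one-hot vectors

lossF = nn.CrossEntropyLoss()
loss = lossF(logits, labels)

# Same value via log-softmax + negative log-likelihood loss
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
# Written out: minus the mean log-probability assigned to each true label
explicit = -F.log_softmax(logits, dim=1)[torch.arange(4), labels].mean()

print(loss.item(), manual.item(), explicit.item())
```

Feeding softmax probabilities instead of logits into `nn.CrossEntropyLoss` would apply softmax twice and distort the loss, so the network's last layer should output logits.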
Step seven: Training

```python
# Training
for epoch in range(num_epochs):
    for batch_index, (images, labels) in enumerate(train_loader):
        if cuda:
            images = images.cuda()
            labels = labels.cuda()
        outputs = model(images)
        # Compute the loss
        loss = lossF(outputs, labels)
        # Backpropagation
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print('[{}/{}],[{}/{}],loss:{:.4}'.format(epoch, num_epochs, batch_index, len(train_loader), loss.item()))
```

Step eight: Testing
```python
# Testing
with torch.no_grad():
    correct_num = 0
    total_num = 0
    for images, labels in test_loader:
        if cuda:
            images = images.cuda()
            labels = labels.cuda()
        outputs = model(images)
        # The index of the largest score is the predicted class
        _, predictions = torch.max(outputs, 1)
        correct_num += (predictions == labels).sum()
        total_num += predictions.size(0)
    print("The accuracy of the test set is :{}%".format(correct_num / total_num * 100))
```

The results are as follows:
```text
C:\Users\25566\miniconda3\python.exe C:/Users/25566/Desktop/Pytorch-Basics-main/FashionMnist.py
[0/3],[0/235],loss:2.308
[0/3],[100/235],loss:0.2817
[0/3],[200/235],loss:0.1224
[1/3],[0/235],loss:0.1086
[1/3],[100/235],loss:0.1185
[1/3],[200/235],loss:0.1118
[2/3],[0/235],loss:0.07762
[2/3],[100/235],loss:0.04928
[2/3],[200/235],loss:0.04956
The accuracy of the test set is :98.05999755859375%

Process finished with exit code 0
```
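The test step above evaluates the entire test set as a single batch, which works for MNIST on most GPUs but can run out of memory on larger datasets. A minimal sketch of the same accuracy computation done batch by batch; the function name `evaluate` and its `device` argument are illustrative choices, not part of the original code, and the tiny random dataset only demonstrates the call:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return the classification accuracy (in %) of `model` over every batch in `loader`."""
    model.eval()  # no effect without dropout/batch norm, but a good habit before testing
    correct_num, total_num = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)  # same as torch.max(outputs, 1)[1]
            correct_num += (predictions == labels).sum().item()
            total_num += labels.size(0)
    return 100.0 * correct_num / total_num


# Demo on random data standing in for MNIST: 100 fake 1x28x28 images, 10 classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fake = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
loader = DataLoader(fake, batch_size=32, shuffle=False)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
acc = evaluate(toy_model, loader, device)
print("accuracy: {:.2f}%".format(acc))
```

With the objects from this post, the call would pass the trained `model` and a test `DataLoader` built with a smaller `batch_size` (for example 256). Because `evaluate` switches the model to eval mode, call `model.train()` before any further training.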