Changing a simple fully connected (MLP) neural network model into a CNN (convolutional neural network) model
2022-07-28 22:52:00 【Meteor shower ADI】
In the previous post I described how to build a neural network model based on an MLP (fully connected layers). This post shows how to change that MLP-based model into one based on a simple CNN.
Link to the previous post:
Step 1: Import the packages. This part is largely the same as in the previous model.
# Imports
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
Step 2: Define the hyperparameters. Compared with the previous model, three hyperparameter settings are no longer needed and have been removed.
# Define the hyperparameters
num_epochs=3
batch_size=256
learning_rate=0.001
Step 3: Load the dataset. Change the dataset according to your own needs; here we use MNIST.
# Load the MNIST dataset (download=False assumes the files are already under dataset/; set download=True on the first run)
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# The test loader uses a single batch containing the whole test set
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
Step 4: Set up the training configuration.
# Training configuration: check whether a GPU is available
cuda = True if torch.cuda.is_available() else False
Step 5: Build the CNN.
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # input (1, 28, 28), output (16, 24, 24)
        x = self.pool1(x)           # output (16, 12, 12)
        x = F.relu(self.conv2(x))   # output (32, 8, 8)
        x = self.pool2(x)           # output (32, 4, 4)
        x = x.view(-1, 32 * 4 * 4)  # flatten to (batch, 32*4*4)
        x = F.relu(self.fc1(x))     # output (120)
        x = F.relu(self.fc2(x))     # output (84)
        x = self.fc3(x)             # output (10)
        return x
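To double-check the feature-map sizes noted in the comments (a 5x5 convolution with no padding shrinks each side by 4, and a 2x2 max-pool halves it), you can push a dummy MNIST-sized batch through the layers. This is just an illustrative sketch that reuses the imports and the net class defined above; the names m and dummy are made up for the example.
# Optional shape check with a dummy batch of one 1x28x28 image
m = net()
dummy = torch.zeros(1, 1, 28, 28)
x = F.relu(m.conv1(dummy))
print(x.shape)   # torch.Size([1, 16, 24, 24])
x = m.pool1(x)
print(x.shape)   # torch.Size([1, 16, 12, 12])
x = F.relu(m.conv2(x))
print(x.shape)   # torch.Size([1, 32, 8, 8])
x = m.pool2(x)
print(x.shape)   # torch.Size([1, 32, 4, 4]) -> flattened to 32*4*4 = 512 features for fc1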
Step 6: Instantiate the model and define the loss function and optimizer.
# Instantiate the model and move it to the GPU
model = net().cuda()
# Define the loss function and optimizer
lossF = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
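Note that the cuda flag from step 4 is never actually used: the code calls .cuda() directly, so it only runs on a machine with a GPU. If you also want it to run on a CPU, a device-based variant roughly like the sketch below could replace the hard-coded .cuda() calls (this is an alternative sketch, not the original code):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = net().to(device)
# ...and in the training and test loops, move the batches the same way:
# images = images.to(device)
# labels = labels.to(device)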
Step 7: Train the model.
# Training
for epoch in range(num_epochs):
    for batch_index, (images, labels) in enumerate(train_loader):
        images = images.cuda()
        labels = labels.cuda()
        outputs = model(images)
        # Compute the loss
        loss = lossF(outputs, labels)
        # Backpropagation
        optimizer.zero_grad()  # clear the old gradients
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print('[{}/{}],[{}/{}],loss:{:.4}'.format(epoch, num_epochs, batch_index, len(train_loader), loss))
Step 8: Test the model.
# Test
model.eval()  # switch to evaluation mode (good practice, though this model has no dropout or batch-norm layers)
with torch.no_grad():
    correct_num = 0
    total_num = 0
    for images, labels in test_loader:
        images = images.cuda()
        labels = labels.cuda()
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        correct_num += (predictions == labels).sum()
        total_num += predictions.size(0)
    print("The accuracy of the test set is: {}%".format(correct_num / total_num * 100))
The results are as follows:
C:\Users\25566\miniconda3\python.exe C:/Users/25566/Desktop/Pytorch-Basics-main/FashionMnist.py
[0/3],[0/235],loss:2.308
[0/3],[100/235],loss:0.2817
[0/3],[200/235],loss:0.1224
[1/3],[0/235],loss:0.1086
[1/3],[100/235],loss:0.1185
[1/3],[200/235],loss:0.1118
[2/3],[0/235],loss:0.07762
[2/3],[100/235],loss:0.04928
[2/3],[200/235],loss:0.04956
The accuracy of the test set is: 98.05999755859375%
Process finished with exit code 0
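A small optional tweak: the long decimal in the last line comes from printing the accuracy while it is still a float32 tensor. Converting it to a plain Python number with .item() and adding a format specifier gives a tidier figure; a minimal sketch of the final print:
accuracy = 100.0 * correct_num.item() / total_num
print("The accuracy of the test set is: {:.2f}%".format(accuracy))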