A simple neural network model based on MLP fully connected layers
2022-07-28 22:52:00 【Meteor shower ADI】
Now let's build the simplest possible neural network. The inspiration and template come from a Bilibili uploader, Lavita; the relevant video link is as follows:
Below is the simplest neural network classification model that runs end to end. In later posts I will gradually show how to turn this simple MLP model into a CNN (convolutional neural network) model to greatly improve accuracy, and how to build a variety of mainstream architectures, including VGG, LeNet, AlexNet, ResNet, Transformer, and so on.
The first step is the imports. Below are the common packages needed to build a neural network; these provide the APIs we will use.
####################################################
# MNIST classification #
####################################################
# Guide pack
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
The second step: define the hyperparameters. For an MLP-type neural network, the hyperparameters usually need to be set in advance.
# Define hyperparameters
num_epochs = 3
batch_size = 256
learning_rate = 0.001
num_input = 28 * 28   # each MNIST image is 28x28, flattened to 784
num_hidden = 500      # hidden-layer width; an adjustable choice, 500 is a common default
num_class = 10        # MNIST has 10 digit classes
The third step: load the MNIST dataset.
# Set download=True on the first run if the dataset is not already under dataset/
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# The test loader serves the entire test set as a single batch
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
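As a quick sanity check (a minimal sketch, not part of the original post), you can pull one batch from train_loader and confirm the tensor shapes before training; this also explains the reshape to 28*28 used later.
# Peek at one batch to confirm shapes (optional sanity check)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([256, 1, 28, 28]) -> flattened to [256, 784] later
print(labels.shape)  # torch.Size([256])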
Step four: set up the training configuration.
# Set up the training configuration: use the GPU if available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Step five: build the network. A simple MLP of fully connected layers handles the classification task. As shown below, the network consists of three layers: an input layer, a hidden layer, and an output layer.
# Build the neural network architecture
class NeuralNet(nn.Module):
    def __init__(self, num_input, num_hidden, num_class):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(num_input, num_hidden)   # input layer -> hidden layer
        self.fc2 = nn.Linear(num_hidden, num_class)   # hidden layer -> output layer
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)   # raw logits; no softmax here (see the loss function below)
        return x
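To verify the architecture before training (a minimal sketch, assuming the hyperparameters defined above), you can run a dummy batch through the network and check that the output has one logit per class:
# Forward a dummy batch of flattened images; expect logits of shape [4, 10]
dummy = torch.randn(4, num_input)
net = NeuralNet(num_input, num_hidden, num_class)
print(net(dummy).shape)  # torch.Size([4, 10])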
Step six: instantiate the model, then define the loss function and the optimizer.
# Instantiate the model on the selected device
model = NeuralNet(num_input, num_hidden, num_class).to(device)
# Define the loss function and optimizer
LossF = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
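Note that nn.CrossEntropyLoss combines log-softmax and negative log-likelihood internally, which is why forward returns raw logits without a softmax. A minimal illustration with made-up logits (not from the original post):
# CrossEntropyLoss takes raw logits of shape [N, C] and integer class labels of shape [N]
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
targets = torch.tensor([0, 1])
print(nn.CrossEntropyLoss()(logits, targets))  # a scalar loss tensor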
Step seven: training. MNIST images are 28*28, so each image is flattened into a 784-dimensional vector before being fed to the MLP.
# Training
for epoch in range(num_epochs):
    for batch_index, (images, labels) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-dimensional vector
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Calculate the loss
        loss = LossF(outputs, labels)
        # Backpropagate the gradient
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print('[{}/{}],[{}/{}],loss={:.4f}'.format(epoch, num_epochs, batch_index, len(train_loader), loss.item()))
Step eight: testing.
# Test
model.eval()  # switch to evaluation mode; good practice even for this plain MLP
with torch.no_grad():
    correct_num = 0
    total_num = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Take the class with the highest logit as the prediction
        _, predictions = torch.max(outputs, 1)
        correct_num += (predictions == labels).sum().item()
        total_num += predictions.size(0)
print("The accuracy of the test set is: {}%".format(correct_num / total_num * 100))
Finally, the output log from a run is attached:
C:\Users\25566\miniconda3\python.exe C:/Users/25566/Desktop/Pytorch-Basics-main/minist classification .py
[0/3],[0/235],loss=2.3095
[0/3],[100/235],loss=0.2861
[0/3],[200/235],loss=0.2120
[1/3],[0/235],loss=0.1810
[1/3],[100/235],loss=0.2953
[1/3],[200/235],loss=0.1477
[2/3],[0/235],loss=0.1991
[2/3],[100/235],loss=0.1295
[2/3],[200/235],loss=0.1428
The accuracy of the test set is :95.72000122070312%
Process finished with exit code 0
The code above reproduces these results when copied and run as-is, and the hyperparameters can be adjusted manually.