A simple neural network model based on MLP fully connected layers
2022-07-28 22:52:00 【Meteor shower ADI】
Now let's build the simplest possible neural network. The inspiration and template come from the bilibili uploader Lavita; the related video link is as follows:
Below is the simplest neural-network classification model that runs end to end. In later posts I will gradually show how to convert this simple MLP model into a CNN (convolutional neural network) model to greatly improve the accuracy, and how to build a variety of mainstream architectures, including VGG, LeNet, AlexNet, ResNet, Transformer, and so on (a minimal CNN sketch appears at the end of this post).
Step 1: import the packages. Below are the common packages needed to build a neural network and the APIs that will be used.
####################################################
# MNIST classification #
####################################################
# Import the packages
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
Step 2: define the hyperparameters. For an MLP-type network the hyperparameters usually need to be set in advance.
# Define the hyperparameters
num_input = 28*28        # each MNIST image is 28*28 pixels, flattened to a 784-dimensional vector
num_hidden = 500         # width of the hidden layer (adjust as you like)
num_class = 10           # ten digit classes, 0-9
num_epochs = 3
batch_size = 256
learning_rate = 0.001
Step 3: load the MNIST dataset.
# Load the MNIST dataset (download=True fetches the data the first time it is run)
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)   # evaluate the whole test set in one batch
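As a quick sanity check (not part of the original script), you can pull one batch from the loader; with batch_size=256 the image tensor should come out as [256, 1, 28, 28] and the labels as [256]:

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([256, 1, 28, 28])
print(labels.shape)   # torch.Size([256])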
Step 4: set up the training configuration.
# Set up the training configuration: use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Step 5: build the network. A simple MLP made of fully connected layers handles the classification task. As shown below, the network consists of three layers: an input layer, a hidden layer, and an output layer.
# Build the neural network architecture
class NeuralNet(nn.Module):
    def __init__(self, num_input, num_hidden, num_class):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(num_input, num_hidden)   # input layer -> hidden layer
        self.fc2 = nn.Linear(num_hidden, num_class)   # hidden layer -> output layer
        self.act = nn.ReLU()                          # non-linearity between the two layers
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x                                      # raw logits, one score per class
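Before training, a throwaway forward pass on random data is a handy way to confirm that the layer sizes line up (this check is my own addition; 784 comes from flattening a 28*28 image and 10 is the number of digit classes):

net = NeuralNet(num_input=784, num_hidden=500, num_class=10)
dummy = torch.randn(4, 784)   # a fake mini-batch of 4 flattened images
print(net(dummy).shape)       # torch.Size([4, 10]) -> one logit per class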
Step 6: instantiate the network and define the loss function and optimizer.
# Instantiate the model
model = NeuralNet(num_input, num_hidden, num_class).to(device)
# Define the loss function and optimizer
LossF = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
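Note that the network outputs raw logits with no final softmax; nn.CrossEntropyLoss applies log-softmax and the negative log-likelihood internally, so it expects logits plus integer class labels. A tiny illustration (my own example, not from the original post):

logits = torch.tensor([[2.0, 0.5, -1.0]])      # raw scores for 3 classes
target = torch.tensor([0])                     # the true class index
print(nn.CrossEntropyLoss()(logits, target))   # small loss, since class 0 already has the highest score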
Step 7: training. Each MNIST image is 28*28 pixels, so every image is flattened into a 784-dimensional vector before being fed to the fully connected layers.
# Training
for epoch in range(num_epochs):
    for batch_index, (images, labels) in enumerate(train_loader):
        # Flatten each 28*28 image into a 784-dimensional vector
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Compute the loss
        loss = LossF(outputs, labels)
        # Back-propagate the gradients and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print('[{}/{}],[{}/{}],loss={:.4f}'.format(epoch, num_epochs, batch_index, len(train_loader), loss.item()))
Step 8: testing.
# Test
model.eval()                                     # switch to evaluation mode
with torch.no_grad():                            # gradients are not needed for evaluation
    correct_num = 0
    total_num = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)   # index of the largest logit = predicted class
        correct_num += (predictions == labels).sum().item()
        total_num += predictions.size(0)
print("The accuracy of the test set is: {}%".format(correct_num / total_num * 100))
Finally, here is the output of one run:
C:\Users\25566\miniconda3\python.exe C:/Users/25566/Desktop/Pytorch-Basics-main/minist classification .py
[0/3],[0/235],loss=2.3095
[0/3],[100/235],loss=0.2861
[0/3],[200/235],loss=0.2120
[1/3],[0/235],loss=0.1810
[1/3],[100/235],loss=0.2953
[1/3],[200/235],loss=0.1477
[2/3],[0/235],loss=0.1991
[2/3],[100/235],loss=0.1295
[2/3],[200/235],loss=0.1428
The accuracy of the test set is :95.72000122070312%
Process finished with exit code 0
All of the code above can be reproduced by copying it block by block, and the hyperparameters can be tuned by hand.
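As mentioned in the introduction, the same training and test loops can be reused with a convolutional network. The sketch below is only a minimal illustration (the layer sizes are my own choices, not from this post); the other change needed is to feed images as [N, 1, 28, 28] instead of flattening them to 28*28.

class SimpleCNN(nn.Module):
    def __init__(self, num_class=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 1x28x28 -> 16x28x28
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 16x14x14 -> 32x14x14
        self.pool = nn.MaxPool2d(2)                               # halves the height and width
        self.fc = nn.Linear(32 * 7 * 7, num_class)
        self.act = nn.ReLU()
    def forward(self, x):                                          # x: [N, 1, 28, 28]
        x = self.pool(self.act(self.conv1(x)))                     # [N, 16, 14, 14]
        x = self.pool(self.act(self.conv2(x)))                     # [N, 32, 7, 7]
        x = x.reshape(x.size(0), -1)                               # flatten to [N, 32*7*7]
        return self.fc(x)

Swapping NeuralNet for SimpleCNN and removing the reshape(-1, 28*28) lines in the training and test loops is enough to train it with the code above.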