Build the first neural network with PyTorch and optimize it
2022-06-28 08:36:00 【Sol-itude】
I've been learning PyTorch, and this time I followed a tutorial to build a neural network for the most classic dataset, CIFAR10. Let's look at the model structure first.
The input is a 3-channel 32x32 image. It passes through 3 convolutions, 3 max-pooling layers, 1 flatten layer and two linear (fully connected) layers, which finally produce ten outputs, one score per class.
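To see where the 1024 input features of the first linear layer come from, here is a small pure-Python sketch of the shape arithmetic. It assumes PyTorch's usual output-size formulas (convolution with stride 1 and dilation 1, max pooling with stride equal to the kernel size and no padding); the helper names `conv_out`, `pool_out` and `trace_shapes` are hypothetical and only for illustration.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial output size of max pooling with stride == kernel and no padding."""
    return (size - kernel) // kernel + 1


def trace_shapes(size=32):
    """Follow (channels, height/width) through conv -> max pool, three times."""
    shapes = []
    for out_channels in (32, 32, 64):
        size = conv_out(size, kernel=5, padding=2)  # 5x5 kernel with padding 2 keeps the size
        shapes.append((out_channels, size))
        size = pool_out(size, kernel=2)             # 2x2 max pooling halves the size
        shapes.append((out_channels, size))
    return shapes


shapes = trace_shapes()
channels, side = shapes[-1]
flat_features = channels * side * side
print(shapes)         # [(32, 32), (32, 16), (32, 16), (32, 8), (64, 8), (64, 4)]
print(flat_features)  # 1024
```

So after the third pooling layer each image is a 64x4x4 tensor, and Flatten turns it into the 1024 values that `Linear(1024, 64)` expects.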
The code is as follows:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        # shapes are (channels, height, width) for a single image
        self.conv1 = Conv2d(3, 32, 5, padding=2)    # 3x32x32 -> 32x32x32
        self.maxpool1 = MaxPool2d(2)                # -> 32x16x16
        self.conv2 = Conv2d(32, 32, 5, padding=2)   # -> 32x16x16
        self.maxpool2 = MaxPool2d(2)                # -> 32x8x8
        self.conv3 = Conv2d(32, 64, 5, padding=2)   # -> 64x8x8
        self.maxpool3 = MaxPool2d(2)                # -> 64x4x4
        self.flatten = Flatten()                    # -> 1024 values
        self.linear1 = Linear(1024, 64)             # 1024 = 64*4*4
        self.linear2 = Linear(64, 10)               # 10 outputs, one per CIFAR10 class

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


network = NetWork()
print(network)
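`print(network)` lists the layers but not how many weights they hold. As a rough sanity check, here is a pure-Python sketch that counts the trainable parameters by hand, assuming every Conv2d and Linear layer has a bias (the PyTorch default); max pooling and Flatten have no parameters. The helper names are hypothetical. For the model above the total should agree with `sum(p.numel() for p in network.parameters())`.

```python
def conv_params(in_ch, out_ch, kernel):
    # weights: out_ch * in_ch * kernel * kernel, plus one bias per output channel
    return out_ch * (in_ch * kernel * kernel + 1)


def linear_params(in_features, out_features):
    # weight matrix plus one bias per output feature
    return out_features * (in_features + 1)


layers = {
    "conv1": conv_params(3, 32, 5),
    "conv2": conv_params(32, 32, 5),
    "conv3": conv_params(32, 64, 5),
    "linear1": linear_params(1024, 64),
    "linear2": linear_params(64, 10),
}
total = sum(layers.values())
for name, count in layers.items():
    print(f"{name}: {count}")
print("total:", total)  # total: 145578
```

Most of the weights sit in `linear1` and `conv3`, which is typical for small CNNs whose first fully connected layer sees the whole flattened feature map.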
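We can also look at the network in TensorBoard. Remember the imports:
import torch
from torch.utils.tensorboard import SummaryWriter

input = torch.ones((64, 3, 32, 32))  # dummy batch: 64 images, 3 channels, 32x32
output = network(input)
print(output.shape)                  # torch.Size([64, 10])
writer = SummaryWriter("logs_seq")
writer.add_graph(network, input)     # record the model graph in the log directory
writer.close()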
After starting TensorBoard with tensorboard --logdir=logs_seq, the graph looks like this:
Double-click the NetWork node to expand it.
You can zoom in to see the details.
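The network's outputs differ from the true labels, and this error is measured by a loss function. To reduce the error, we use gradient descent to adjust the network's parameters.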
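A tiny pure-Python sketch of the idea, fitting a single weight `w` so that `w * x` matches `y` under mean squared error. This is only an illustration under simplified assumptions: the gradient is written out by hand and computed over the whole dataset at once, whereas PyTorch's autograd computes gradients automatically and `torch.optim.SGD` updates after every mini-batch. The three steps still mirror `optim.zero_grad()`, `result_loss.backward()` and `optim.step()`. The function name `train` is hypothetical.

```python
def train(xs, ys, lr=0.01, epochs=200):
    """Fit y ~= w * x by plain gradient descent on the mean squared error."""
    w = 0.0
    n = len(xs)
    for _ in range(epochs):
        # the gradient is recomputed from scratch every step
        # (in PyTorch, optim.zero_grad() ensures old gradients don't accumulate)
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        # update step: move against the gradient, scaled by the learning rate
        w -= lr * grad
    return w


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated with w = 2
w = train(xs, ys)
print(round(w, 4))  # 2.0
```

The learning rate plays the same role as `lr=0.01` in the training loop: too large and the updates overshoot, too small and convergence is slow.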
The training code is as follows:
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

# train=False loads the 10,000-image CIFAR10 test split
dataset = torchvision.datasets.CIFAR10("./dataset2", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)


class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        # the same layers as before, bundled into one Sequential container
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),  # 1024 = 64*4*4
            Linear(64, 10)
        )

    def forward(self, x):
        # Sequential runs the layers in order, replacing the step-by-step forward
        return self.model1(x)


loss = nn.CrossEntropyLoss()
network = NetWork()
optim = torch.optim.SGD(network.parameters(), lr=0.01)  # stochastic gradient descent

for epoch in range(20):  # train for 20 epochs
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()       # clear the gradients from the previous step
        result_loss.backward()  # backpropagate to compute new gradients
        optim.step()            # update the parameters
        # .item() turns the loss into a Python float, so the autograd graph
        # of every batch is not kept alive by running_loss
        running_loss = running_loss + result_loss.item()
    print(running_loss)
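`nn.CrossEntropyLoss` takes the raw scores (logits) and the integer class index. For a single sample it computes the negative log of the softmax probability of the correct class; with the default `reduction='mean'` PyTorch then averages over the batch (here a batch of one). A pure-Python sketch for one sample, with a hypothetical helper name:

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# ten equal scores: every class gets probability 1/10, so the loss is ln(10)
uniform = [0.0] * 10
print(cross_entropy(uniform, 3))

# a confident, correct prediction gives a loss close to 0
confident = [0.0] * 10
confident[3] = 10.0
print(cross_entropy(confident, 3))
```

The uniform case is a useful reference point: a network that has learned nothing about CIFAR10's ten classes scores about ln(10), roughly 2.303, per image.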
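My computer's GPU is an RTX 2060, which is a fairly old card. Running three epochs took about a minute, and it was so slow that I didn't let it finish. Note that the code above never moves the model or the data to the GPU (for example with .to(device)), so it actually runs on the CPU.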
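One likely reason for the slowness is `batch_size=1`: the loop does a forward pass, a backward pass and an optimizer step for every single image. A quick pure-Python count of optimizer steps per epoch for a few batch sizes, assuming the 10,000-image test split and the DataLoader default `drop_last=False` (a final partial batch is kept); the helper name is hypothetical.

```python
import math

NUM_IMAGES = 10_000  # size of the CIFAR10 test split loaded with train=False


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches (optimizer steps) a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


for bs in (1, 64, 256):
    print(bs, steps_per_epoch(NUM_IMAGES, bs))
```

Larger batches mean far fewer steps per epoch and let the hardware process many images in parallel, although a different batch size usually also calls for retuning the learning rate.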
Output (running_loss after each of the first three epochs):
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
The loss keeps getting smaller, but 20 epochs is too few for real use. When my new computer arrived, I ran 100 epochs.
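Because `batch_size=1` and `running_loss` adds up one loss per image, dividing each printed total by 10,000 gives the average per-image cross-entropy. This can be compared with ln(10), the loss of a model that assigns equal probability to all ten classes. A small sketch using the three totals printed above:

```python
import math

NUM_IMAGES = 10_000
epoch_totals = [18733.7539, 16142.7451, 15420.9199]  # running_loss values printed above

chance_loss = math.log(10)  # cross-entropy of a uniform guess over 10 classes
for epoch, total in enumerate(epoch_totals, start=1):
    avg = total / NUM_IMAGES
    print(f"epoch {epoch}: average loss {avg:.4f} (chance level {chance_loss:.4f})")
```

Already after the first epoch the average loss is below the chance level, and it keeps falling, which is consistent with the network learning something. Note that these numbers are measured on the same images the network is trained on, so they say nothing about accuracy on unseen data.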