Build the first neural network with PyTorch and optimize it
2022-06-28 08:36:00 【Sol-itude】
I've been learning PyTorch, and this time I followed the tutorial to build a neural network for the classic CIFAR10 dataset. Let's look at the structure first.
The input is a 3-channel 32×32 image. It passes through 3 convolution layers, 3 max-pooling layers, 1 flatten layer, and 2 linear layers, and produces 10 outputs. Each 2×2 max pooling halves the spatial size (32 → 16 → 8 → 4), so after the last pooling the feature map is 64×4×4 = 1024 values, which is the input size of the first linear layer.
The code is as follows:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)  # 1024 = 64*4*4
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

network = NetWork()
print(network)
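As a quick sanity check of the 1024 = 64*4*4 comment, here is a minimal sketch (my own addition, not part of the tutorial) that pushes a dummy batch through the convolution and pooling layers and prints the shape right before the flatten:

import torch

network = NetWork()
dummy = torch.ones((64, 3, 32, 32))   # a fake batch of 64 CIFAR10-sized images
x = network.maxpool1(network.conv1(dummy))
x = network.maxpool2(network.conv2(x))
x = network.maxpool3(network.conv3(x))
print(x.shape)                    # torch.Size([64, 64, 4, 4]) -> 64*4*4 = 1024 per image
print(network.flatten(x).shape)   # torch.Size([64, 1024])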
We can also visualize the network with TensorBoard; remember to import torch and SummaryWriter:
import torch
from torch.utils.tensorboard import SummaryWriter

input = torch.ones((64, 3, 32, 32))
output = network(input)
writer = SummaryWriter("logs_seq")
writer.add_graph(network, input)
writer.close()
In TensorBoard the graph looks like this: expand the NetWork node and you can zoom in to see each layer.
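To open the visualization, launch TensorBoard against the same log directory that the writer used above:

tensorboard --logdir=logs_seq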
The network's predictions have error (loss), so we use gradient descent to reduce it.
The code is as follows:
import torchvision.datasets
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
import torch

dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor(),
                                        download=True)
dataloader = DataLoader(dataset, batch_size=1)

class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)  # 1024 = 64*4*4
        self.linear2 = Linear(64, 10)
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        # x = self.conv1(x)
        # x = self.maxpool1(x)
        # x = self.conv2(x)
        # x = self.maxpool2(x)
        # x = self.conv3(x)
        # x = self.maxpool3(x)
        # x = self.flatten(x)
        # x = self.linear1(x)
        # x = self.linear2(x)
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
network = NetWork()
optim = torch.optim.SGD(network.parameters(), lr=0.01)  # stochastic gradient descent as the optimizer

for epoch in range(20):  # loop 20 times
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()  # zero the gradients before each step
        result_loss.backward()
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)
My computer's GPU is an RTX 2060, which is fairly old by now; three epochs took about a minute, so I only let it run for a few of them.
Output:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
You can see the loss getting smaller and smaller. In practice, though, 20 epochs is far too few; once my new computer arrives I will run 100.
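Since the loop above runs entirely on the CPU, here is a minimal sketch of how the same training could be moved onto the GPU (my own addition, assuming CUDA is available; not part of the original tutorial):

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = NetWork().to(device)     # move the model's parameters to the GPU
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(network.parameters(), lr=0.01)

for epoch in range(20):
    running_loss = 0.0
    for imgs, targets in dataloader:
        imgs, targets = imgs.to(device), targets.to(device)  # move each batch as well
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss.item()  # .item() avoids holding on to the graph
    print(running_loss)

Using a larger batch_size in the DataLoader (for example 64 instead of 1) should also speed training up considerably.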