Building my first neural network with PyTorch and optimizing it
2022-06-28 08:36:00 【Sol-itude】
I've been learning PyTorch, and this time I built a neural network by following a tutorial, using the most classic dataset, CIFAR10. Let's first look at the model structure.
The input is a 3-channel 32×32 image. It passes through 3 convolutions, 3 max-pooling layers, 1 flatten layer and two linear layers, producing ten outputs (one score per CIFAR10 class).
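The sizes in this description can be checked with the usual output-size formula for convolution and pooling layers, where $H_{\text{in}}$ is the input height (width works the same way), $k$ the kernel size, $p$ the padding and $s$ the stride. The values plugged in are the ones used in this post's code: 5×5 convolutions with padding 2 and stride 1, and 2×2 max pooling, whose stride defaults to its kernel size in PyTorch. Convolutions keep the 32×32 size and each pooling halves it:

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} \right\rfloor + 1
$$

$$
\text{conv: } \frac{32 + 2\cdot 2 - 5}{1} + 1 = 32, \qquad \text{pool: } \left\lfloor \frac{32 - 2}{2} \right\rfloor + 1 = 16
$$

$$
3\times32\times32 \xrightarrow{\text{conv1}} 32\times32\times32 \xrightarrow{\text{pool}} 32\times16\times16 \xrightarrow{\text{conv2}} 32\times16\times16 \xrightarrow{\text{pool}} 32\times8\times8 \xrightarrow{\text{conv3}} 64\times8\times8 \xrightarrow{\text{pool}} 64\times4\times4
$$

$$
\text{flatten: } 64 \cdot 4 \cdot 4 = 1024 \;\to\; \text{Linear}(1024, 64) \;\to\; \text{Linear}(64, 10)
$$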
The code is as follows:
```python
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)    # 3x32x32  -> 32x32x32
        self.maxpool1 = MaxPool2d(2)                # -> 32x16x16
        self.conv2 = Conv2d(32, 32, 5, padding=2)   # -> 32x16x16
        self.maxpool2 = MaxPool2d(2)                # -> 32x8x8
        self.conv3 = Conv2d(32, 64, 5, padding=2)   # -> 64x8x8
        self.maxpool3 = MaxPool2d(2)                # -> 64x4x4
        self.flatten = Flatten()                    # -> 1024
        self.linear1 = Linear(1024, 64)             # 1024 = 64*4*4
        self.linear2 = Linear(64, 10)               # 10 class scores

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


network = NetWork()
print(network)
```
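As a quick sanity check on the model size, the parameter count of each layer can be worked out by hand. The sketch below is plain Python; the helpers conv2d_params and linear_params are hypothetical, not PyTorch functions, and it assumes every Conv2d and Linear keeps its default bias=True, as in the code above. In PyTorch, sum(p.numel() for p in network.parameters()) should report the same total.

```python
# Hypothetical helpers (not part of PyTorch) that count learnable parameters,
# assuming bias=True, which is the default for nn.Conv2d and nn.Linear.

def conv2d_params(in_channels, out_channels, kernel_size):
    # weight: out_channels x in_channels x k x k, plus one bias per output channel
    return out_channels * in_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    # weight: out_features x in_features, plus one bias per output feature
    return out_features * in_features + out_features


# MaxPool2d and Flatten have no learnable parameters, so only these five count.
layer_params = {
    "conv1": conv2d_params(3, 32, 5),
    "conv2": conv2d_params(32, 32, 5),
    "conv3": conv2d_params(32, 64, 5),
    "linear1": linear_params(1024, 64),
    "linear2": linear_params(64, 10),
}
total_params = sum(layer_params.values())

for name, count in layer_params.items():
    print(f"{name:8s} {count:7d}")
print(f"total    {total_params:7d}")
```

linear1 alone holds 65,600 of the 145,578 parameters (about 45%), even though it is only one of five layers with weights.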
We can also look at the network in TensorBoard; remember to add the imports.
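```python
import torch
from torch.utils.tensorboard import SummaryWriter

# Continues from the previous block (uses `network`)
input = torch.ones((64, 3, 32, 32))  # dummy batch: 64 images, 3 channels, 32x32
output = network(input)
print(output.shape)                  # torch.Size([64, 10])
writer = SummaryWriter("logs_seq")
writer.add_graph(network, input)     # record the model graph for TensorBoard
writer.close()
# Then view it with: tensorboard --logdir=logs_seq
```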
In TensorBoard it looks like this:
Open the NetWork node:
You can zoom in to see the details.
A neural network's predictions have an error (the loss), so we use gradient descent to reduce it.
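To make the idea concrete, here is a minimal gradient-descent sketch in plain Python rather than PyTorch, assuming a hypothetical one-parameter model y_hat = w * x with squared-error loss (not the CIFAR10 network). Each sample goes through the same three steps as a PyTorch training loop: reset the gradient, compute it, and move the weight against it by lr times the gradient.

```python
# Toy gradient descent on a hypothetical model y_hat = w * x with squared error.
# Plain Python for illustration; the commented steps mirror
# optim.zero_grad(), loss.backward() and optim.step() in PyTorch.

def train(xs, ys, lr=0.01, epochs=20, w=0.0):
    history = []                          # summed loss per epoch, like running_loss
    for _ in range(epochs):
        running_loss = 0.0
        for x, y in zip(xs, ys):          # batch size 1: update after every sample
            grad = 0.0                    # zero_grad(): PyTorch adds into .grad, so reset first
            pred = w * x
            loss = (pred - y) ** 2
            grad += 2 * (pred - y) * x    # backward(): d(loss)/dw
            w -= lr * grad                # step(): w <- w - lr * grad
            running_loss += loss
        history.append(running_loss)
    return w, history


w, history = train([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(round(w, 3), [round(h, 3) for h in history[:3]])
```

With data generated by y = 2x, w approaches 2 and the summed loss shrinks every epoch, the same downward pattern a real training run should show.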
The full training code is as follows:
```python
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

# CIFAR10 test split (10,000 images), converted to tensors
dataset = torchvision.datasets.CIFAR10("./dataset2", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)  # one image per optimization step


class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        # The same layers as before, chained in a Sequential container so that
        # forward() no longer calls conv1, maxpool1, ..., linear2 one by one
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),  # 1024 = 64*4*4
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
network = NetWork()
optim = torch.optim.SGD(network.parameters(), lr=0.01)  # stochastic gradient descent as the optimizer
for epoch in range(20):  # 20 epochs (full passes over the dataset)
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()       # clear the gradients left from the previous step
        result_loss.backward()  # backpropagation: compute new gradients
        optim.step()            # update the parameters using those gradients
        # Summing tensors keeps grad_fn (as in the output below);
        # result_loss.item() would accumulate a plain Python float instead
        running_loss = running_loss + result_loss
    print(running_loss)
```
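nn.CrossEntropyLoss scores the ten outputs against the true label. The sketch below computes the same per-sample quantity in plain Python (cross_entropy is a hypothetical helper written for illustration): the negative log of the softmax probability assigned to the correct class, using the usual max-subtraction trick for numerical stability.

```python
import math


def cross_entropy(logits, target):
    """Per-sample cross-entropy: -log(softmax(logits)[target]).

    nn.CrossEntropyLoss combines log-softmax and negative log-likelihood in
    the same way; with batch_size=1 its default 'mean' reduction is just this
    single value.
    """
    m = max(logits)  # subtract the max before exp() to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


uniform = [0.0] * 10                        # equal scores for all ten classes
print(cross_entropy(uniform, 3))            # equals ln(10)
print(cross_entropy([10.0, 0.0, 0.0], 0))   # confident and correct: close to 0
print(cross_entropy([10.0, 0.0, 0.0], 1))   # confident and wrong: about 10
```

A network that outputs equal scores for all ten classes gets ln 10 ≈ 2.30 per image, a useful baseline when reading the loss values.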
My computer's GPU is an RTX 2060, which is on the older side. Three epochs took about a minute; it was too slow, so I stopped the run there.
Output:
```
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
```
The loss keeps getting smaller. However, 20 epochs are too few for real use, so when my new computer arrives I will run 100 epochs.
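The printed numbers are sums over a whole epoch, so dividing by the number of steps makes them easier to interpret. This arithmetic sketch assumes what the training code implies: 10,000 CIFAR10 test images with batch_size=1, i.e. 10,000 per-image losses per epoch. Because each loss is measured while the weights are still being updated, the result is only a rough average training loss, not a clean evaluation.

```python
import math

# Summed losses printed after the first three epochs
epoch_sums = [18733.7539, 16142.7451, 15420.9199]
num_batches = 10000          # CIFAR10 test split: 10,000 images, batch_size=1
chance_loss = math.log(10)   # cross-entropy of a uniform guess over 10 classes

mean_losses = [total / num_batches for total in epoch_sums]
for epoch, mean in enumerate(mean_losses, start=1):
    print(f"epoch {epoch}: mean loss per image {mean:.4f} (uniform guess: {chance_loss:.4f})")
```

The means come out to about 1.87, 1.61 and 1.54, all already below the uniform-guess value of ln 10 ≈ 2.30.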