当前位置:网站首页>用Pytorch搭建第一個神經網絡且進行優化
用Pytorch搭建第一個神經網絡且進行優化
2022-06-28 08:36:00 【Sol-itude】
最近一直在學習pytorch,這次自己跟著教程搭了一個神經網絡,用的最經典的CIFAR10,先看一下原理
輸入3通道32*32,最後經過3個卷積,3個最大池化,還有1個flatten,和兩個線性化,得到十個輸出
程序如下:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
def forward(self,x):
x=self.conv1(x)
x=self.maxpool1(x)
x=self.conv2(x)
x=self.maxpool2(x)
x=self.conv3(x)
x=self.maxpool3(x)
x=self.flatten(x)
x=self.linear1(x)
x=self.linear2(x)
return x
network=NetWork()
print(network)
這裏我們還可以用tensorboard看一看,記得import
input=torch.ones((64,3,32,32))
output=network(input)
writer=SummaryWriter("logs_seq")
writer.add_graph(network,input)
writer.close()
在tensorboard中是這樣的
打開NetWork
可以放大查看
神經網絡都是有誤差的,所以我們采用梯度下降來减少誤差
代碼如下
import torchvision.datasets
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten,Linear
from torch.utils.data import DataLoader
import torch
dataset=torchvision.datasets.CIFAR10("./dataset2",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=1)
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
self.model1=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
# x=self.conv1(x)
# x=self.maxpool1(x)
# x=self.conv2(x)
# x=self.maxpool2(x)
# x=self.conv3(x)
# x=self.maxpool3(x)
# x=self.flatten(x)
# x=self.linear1(x)
# x=self.linear2(x)
x=self.model1(x)
return x
loss=nn.CrossEntropyLoss()
network=NetWork()
optim=torch.optim.SGD(network.parameters(),lr=0.01)##利用梯度下降作為優化器
for epoch in range(20):##循環20次
running_loss=0.0
for data in dataloader:
imgs, targets=data
outputs=network(imgs)
result_loss=loss(outputs, targets)
optim.zero_grad()##把每一次的下降值歸零
result_loss.backward()
optim.step()
running_loss=running_loss+result_loss
print(running_loss)
我電腦的GPU是RTX2060屬於比較老的了,跑了三遍大概花了1分鐘,實在太慢我就結束運行了
輸出結果:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
可以看出誤差是在越來越小的,但是在應用中跑20層實在太少了,等我新電腦到了我跑100層
边栏推荐
- B_ QuRT_ User_ Guide(26)
- Selenium+chromedriver cannot open Google browser page
- 【Go ~ 0到1 】 第二天 6月25 Switch语句,数组的声明与遍历
- Comment supprimer le crosstalk SiC MOSFET?
- About RAC modifying scan IP
- Chenglian premium products donated love materials for flood fighting and disaster relief to Yingde
- Redis deployment under Linux & redis startup
- DB
- Super Jumping! Jumping! Jumping!
- VMware Workstation related issues
猜你喜欢

抖音服务器带宽有多大,才能供上亿人同时刷?

MySQL8.0 忘记 root 密码

Set the icon for the title section of the page
![[cloud native | kubernetes] in depth understanding of pod (VI)](/img/ae/f16f5c090251ab603b88ddadff7eb3.png)
[cloud native | kubernetes] in depth understanding of pod (VI)

【无标题】

Superimposed ladder diagram and line diagram and merged line diagram and needle diagram

第六届智能家居亚洲峰会暨精品展(Smart Home Asia 2022)将于10月在沪召开

Large current and frequency range that can be measured by Rogowski coil

JS rounding tips

广州:金融新活水 文企新机遇
随机推荐
Tree
罗氏线圈工作原理
NPM clean cache
Operating principle of Rogowski coil
微内核Zephyr获众多厂家支持!
Resolution of Rac grid failure to start after server restart
Oracle RAC -- understanding of VIP
Redis deployment under Linux & redis startup
Basic twelve style classes for duilib
抖音服务器带宽有多大,才能供上亿人同时刷?
Solve NPM err! Unexpected end of JSON input while parsing near
centos mysql5.5配置文件在哪
11grac turn off archive log
How to choose an account opening broker? Is it safe to open an account online?
What are the advantages of a differential probe over a conventional probe
The Falling Leaves
MySQL8.0 忘记 root 密码
Mysql8.0 forgot the root password
The Falling Leaves
[learning notes] differential constraint