当前位置:网站首页>用Pytorch搭建第一個神經網絡且進行優化
用Pytorch搭建第一個神經網絡且進行優化
2022-06-28 08:36:00 【Sol-itude】
最近一直在學習pytorch,這次自己跟著教程搭了一個神經網絡,用的最經典的CIFAR10,先看一下原理
輸入3通道32*32,最後經過3個卷積,3個最大池化,還有1個flatten,和兩個線性化,得到十個輸出
程序如下:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
def forward(self,x):
x=self.conv1(x)
x=self.maxpool1(x)
x=self.conv2(x)
x=self.maxpool2(x)
x=self.conv3(x)
x=self.maxpool3(x)
x=self.flatten(x)
x=self.linear1(x)
x=self.linear2(x)
return x
network=NetWork()
print(network)
這裏我們還可以用tensorboard看一看,記得import
input=torch.ones((64,3,32,32))
output=network(input)
writer=SummaryWriter("logs_seq")
writer.add_graph(network,input)
writer.close()
在tensorboard中是這樣的
打開NetWork
可以放大查看
神經網絡都是有誤差的,所以我們采用梯度下降來减少誤差
代碼如下
import torchvision.datasets
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten,Linear
from torch.utils.data import DataLoader
import torch
dataset=torchvision.datasets.CIFAR10("./dataset2",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=1)
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
self.model1=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
# x=self.conv1(x)
# x=self.maxpool1(x)
# x=self.conv2(x)
# x=self.maxpool2(x)
# x=self.conv3(x)
# x=self.maxpool3(x)
# x=self.flatten(x)
# x=self.linear1(x)
# x=self.linear2(x)
x=self.model1(x)
return x
loss=nn.CrossEntropyLoss()
network=NetWork()
optim=torch.optim.SGD(network.parameters(),lr=0.01)##利用梯度下降作為優化器
for epoch in range(20):##循環20次
running_loss=0.0
for data in dataloader:
imgs, targets=data
outputs=network(imgs)
result_loss=loss(outputs, targets)
optim.zero_grad()##把每一次的下降值歸零
result_loss.backward()
optim.step()
running_loss=running_loss+result_loss
print(running_loss)
我電腦的GPU是RTX2060屬於比較老的了,跑了三遍大概花了1分鐘,實在太慢我就結束運行了
輸出結果:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
可以看出誤差是在越來越小的,但是在應用中跑20層實在太少了,等我新電腦到了我跑100層
边栏推荐
- 开户券商怎么选择?网上开户是否安全么?
- Selenium reptile
- ffmpeg推流报错Failed to update header with correct duration.
- Love analysis released the 2022 love analysis · it operation and maintenance manufacturer panorama report, and an Chao cloud was strongly selected!
- AWS saves data on the cloud (3)
- [learning notes] linear basis
- Children's unit of 2022 Paris fashion week ended successfully at Wuhan station on June 19
- Modifying the SSH default port when installing Oracle RAC makes CRS unable to install
- yaml json
- The maximum number of Rac open file descriptors, and the processing of hard check failure
猜你喜欢
随机推荐
Comment supprimer le crosstalk SiC MOSFET?
Redis02 -- an operation command of five data types for ending redis (it can be learned, reviewed, interviewed and collected for backup)
新唐NUC980使用记录:自制开发板(基于NUC980DK61YC)
与普通探头相比,差分探头有哪些优点
Introduction, compilation, installation and deployment of Doris learning notes
duilib 入门基础十二 样式类
Resolution of Rac grid failure to start after server restart
The preliminary round of the sixth season of 2022 perfect children's model Foshan competition area came to a successful conclusion
Selenium reptile
Little artist huangxinyang was invited to participate in the Wuhan station of children's unit of Paris Fashion Week
Two tips for block level elements
The RAC cannot connect to the database normally after modifying the scan IP. The ora-12514 problem is handled
Quelle est la largeur de bande du serveur de bavardage sonore pour des centaines de millions de personnes en même temps?
IO error in Oracle11g: got minus one from a read call
The maximum number of Rac open file descriptors, and the processing of hard check failure
Installing mysql5.7 under Windows
隐私计算FATE-----离线预测
Large current and frequency range that can be measured by Rogowski coil
[go ~ 0 to 1] the next day, June 25, switch statement, array declaration and traversal
个人究竟如何开户炒股?在线开户安全么?
![[learning notes] matroid](/img/e3/4e003f5d89752306ea901c70230deb.png)








