当前位置:网站首页>用Pytorch搭建第一個神經網絡且進行優化
用Pytorch搭建第一個神經網絡且進行優化
2022-06-28 08:36:00 【Sol-itude】
最近一直在學習pytorch,這次自己跟著教程搭了一個神經網絡,用的最經典的CIFAR10,先看一下原理
輸入3通道32*32,最後經過3個卷積,3個最大池化,還有1個flatten,和兩個線性化,得到十個輸出
程序如下:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
def forward(self,x):
x=self.conv1(x)
x=self.maxpool1(x)
x=self.conv2(x)
x=self.maxpool2(x)
x=self.conv3(x)
x=self.maxpool3(x)
x=self.flatten(x)
x=self.linear1(x)
x=self.linear2(x)
return x
network=NetWork()
print(network)
這裏我們還可以用tensorboard看一看,記得import
input=torch.ones((64,3,32,32))
output=network(input)
writer=SummaryWriter("logs_seq")
writer.add_graph(network,input)
writer.close()
在tensorboard中是這樣的
打開NetWork
可以放大查看
神經網絡都是有誤差的,所以我們采用梯度下降來减少誤差
代碼如下
import torchvision.datasets
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten,Linear
from torch.utils.data import DataLoader
import torch
dataset=torchvision.datasets.CIFAR10("./dataset2",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=1)
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
self.model1=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
# x=self.conv1(x)
# x=self.maxpool1(x)
# x=self.conv2(x)
# x=self.maxpool2(x)
# x=self.conv3(x)
# x=self.maxpool3(x)
# x=self.flatten(x)
# x=self.linear1(x)
# x=self.linear2(x)
x=self.model1(x)
return x
loss=nn.CrossEntropyLoss()
network=NetWork()
optim=torch.optim.SGD(network.parameters(),lr=0.01)##利用梯度下降作為優化器
for epoch in range(20):##循環20次
running_loss=0.0
for data in dataloader:
imgs, targets=data
outputs=network(imgs)
result_loss=loss(outputs, targets)
optim.zero_grad()##把每一次的下降值歸零
result_loss.backward()
optim.step()
running_loss=running_loss+result_loss
print(running_loss)
我電腦的GPU是RTX2060屬於比較老的了,跑了三遍大概花了1分鐘,實在太慢我就結束運行了
輸出結果:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
可以看出誤差是在越來越小的,但是在應用中跑20層實在太少了,等我新電腦到了我跑100層
边栏推荐
- [learning notes] linear basis
- [.Net6] GRP server and client development cases, as well as the access efficiency duel between the minimum API service, GRP service and traditional webapi service
- 【云原生 | Kubernetes篇】深入了解Pod(六)
- App automated testing appium Tutorial Part 1 - advanced supplementary content
- [untitled]
- FFMpeg (一) av_register_all()
- PMP从报考到拿证基本操作,了解PMP必看篇
- A - deep sea exploration
- Robot Rapping Results Report
- FatMouse and Cheese
猜你喜欢

Children's unit of 2022 Paris fashion week ended successfully at Wuhan station on June 19

887. egg drop

Redis02 -- an operation command of five data types for ending redis (it can be learned, reviewed, interviewed and collected for backup)

【云原生 | Kubernetes篇】深入了解Pod(六)

About using font icons in placeholder

DB

B_ QuRT_ User_ Guide(28)

Introduction, compilation, installation and deployment of Doris learning notes

Why MySQL cannot insert Chinese data in CMD

罗氏线圈工作原理
随机推荐
Comment supprimer le crosstalk SiC MOSFET?
罗氏线圈工作原理
抖音服务器带宽有多大,才能供上亿人同时刷?
FFMpeg (一) av_register_all()
[reprint] STM32 GPIO type
与普通探头相比,差分探头有哪些优点
yaml json
Loss损失函数
Usage record of Xintang nuc980: self made development board (based on nuc980dk61yc)
After installing NRM, the internal/validators js:124 throw new ERR_ INVALID_ ARG_ TYPE(name, ‘string‘, value)
开户券商怎么选择?网上开户是否安全么?
Children's unit of 2022 Paris fashion week ended successfully at Wuhan station on June 19
用Pytorch搭建第一个神经网络且进行优化
How do people over 40 allocate annuity insurance? Which product is more suitable?
The micro kernel zephyr is supported by many manufacturers!
[learning notes] simulation
VMware Workstation related issues
【云原生 | Kubernetes篇】深入了解Pod(六)
利尔达低代码数据大屏,铲平数据应用开发门槛
Unity - use of API related to Pico development input system ---c