当前位置:网站首页>用Pytorch搭建第一個神經網絡且進行優化
用Pytorch搭建第一個神經網絡且進行優化
2022-06-28 08:36:00 【Sol-itude】
最近一直在學習pytorch,這次自己跟著教程搭了一個神經網絡,用的最經典的CIFAR10,先看一下原理
輸入3通道32*32,最後經過3個卷積,3個最大池化,還有1個flatten,和兩個線性化,得到十個輸出
程序如下:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
def forward(self,x):
x=self.conv1(x)
x=self.maxpool1(x)
x=self.conv2(x)
x=self.maxpool2(x)
x=self.conv3(x)
x=self.maxpool3(x)
x=self.flatten(x)
x=self.linear1(x)
x=self.linear2(x)
return x
network=NetWork()
print(network)
這裏我們還可以用tensorboard看一看,記得import
input=torch.ones((64,3,32,32))
output=network(input)
writer=SummaryWriter("logs_seq")
writer.add_graph(network,input)
writer.close()
在tensorboard中是這樣的
打開NetWork
可以放大查看
神經網絡都是有誤差的,所以我們采用梯度下降來减少誤差
代碼如下
import torchvision.datasets
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten,Linear
from torch.utils.data import DataLoader
import torch
dataset=torchvision.datasets.CIFAR10("./dataset2",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=1)
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.conv1=Conv2d(3,32,5,padding=2)
self.maxpool1=MaxPool2d(2)
self.conv2=Conv2d(32,32,5,padding=2)
self.maxpool2=MaxPool2d(2)
self.conv3=Conv2d(32,64,5,padding=2)
self.maxpool3=MaxPool2d(2)
self.flatten=Flatten()
self.linear1=Linear(1024,64)#1024=64*4*4
self.linear2=Linear(64,10)
self.model1=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
# x=self.conv1(x)
# x=self.maxpool1(x)
# x=self.conv2(x)
# x=self.maxpool2(x)
# x=self.conv3(x)
# x=self.maxpool3(x)
# x=self.flatten(x)
# x=self.linear1(x)
# x=self.linear2(x)
x=self.model1(x)
return x
loss=nn.CrossEntropyLoss()
network=NetWork()
optim=torch.optim.SGD(network.parameters(),lr=0.01)##利用梯度下降作為優化器
for epoch in range(20):##循環20次
running_loss=0.0
for data in dataloader:
imgs, targets=data
outputs=network(imgs)
result_loss=loss(outputs, targets)
optim.zero_grad()##把每一次的下降值歸零
result_loss.backward()
optim.step()
running_loss=running_loss+result_loss
print(running_loss)
我電腦的GPU是RTX2060屬於比較老的了,跑了三遍大概花了1分鐘,實在太慢我就結束運行了
輸出結果:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
可以看出誤差是在越來越小的,但是在應用中跑20層實在太少了,等我新電腦到了我跑100層
边栏推荐
- App automated testing appium Tutorial Part 1 - advanced supplementary content
- 爱分析发布《2022爱分析 · IT运维厂商全景报告》 安超云强势入选!
- In flood fighting and disaster relief, the city donated 100000 yuan of love materials to help Yingde
- Kali Notes(1)
- CloudCompare&PCL 点云裁剪(基于封闭曲面或多边形)
- Avframe Memory Management API
- 【无标题】
- Why are function templates not partial specialization?
- 整数划分
- Selenium+chromedriver cannot open Google browser page
猜你喜欢

887. egg drop

AWS saves data on the cloud (3)

罗氏线圈可以测量的大电流和频率范围

VMware Workstation related issues

Kubernetes notes and the latest k3s installation introduction

AWS builds a virtual infrastructure including servers and networks (2)

What are the advantages of a differential probe over a conventional probe

B_ QuRT_ User_ Guide(28)

MySQL8.0 忘记 root 密码

Infinite penetration test
随机推荐
新唐NUC980使用记录:自制开发板(基于NUC980DK61YC)
【转载】STM32 GPIO类型
利尔达低代码数据大屏,铲平数据应用开发门槛
Superimposed ladder diagram and line diagram and merged line diagram and needle diagram
Selenium reptile
Set the icon for the title section of the page
Robot Rapping Results Report
MySQL8.0 忘记 root 密码
爱分析发布《2022爱分析 · IT运维厂商全景报告》 安超云强势入选!
Unity - use of API related to Pico development input system ---c
罗氏线圈工作原理
抖音服务器带宽有多大,才能供上亿人同时刷?
How do people over 40 allocate annuity insurance? Which product is more suitable?
TCP
Reverse mapping of anonymous pages
Kubernetes notes and the latest k3s installation introduction
Oracle view all tablespaces in the current library
Preparation for Oracle 11g RAC deployment on centos7
Super Jumping! Jumping! Jumping!
Solution: selenium common. exceptions. WebDriverException: Message: ‘chromedriver‘ execu