当前位置:网站首页>神经网络-使用Sequential搭建神经网络
神经网络-使用Sequential搭建神经网络
2022-07-01 04:35:00 【booze-J】
我们以这个神经网络图为例子,来搭建对比看看正常情况搭建神经网络和使用Sequential搭建神经网络的区别,以及搭建神经网络中一些要注意的点。
正常情况下搭建神经网络
搭建神经网络代码:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
# 1.根据网络图搭建网络的时候,有些参数网络图上没给,是需要自己去计算的,像是padding,stride等等
self.conv1 = Conv2d(3,32,5,padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32,32,5,padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32,64,5,padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
# 2.设置这个线性层的时候in_feature和out_feature可能也需要自己算,这个in_feature也可以通过打印flatten来查看
self.linear1 = Linear(1024,64)
self.linear2 = Linear(64,10)
def forward(self,x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
obj = Booze()
print(obj)
'''3.对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
上述代码中有一些要注意的点,需要单独的拿出来讲讲。
1. 根据网络图搭建网络的时候,有些参数网络图上没给,是需要自己去计算的,像是padding,stride等等
像是搭建第一个卷积层的时候,就需要自己去计算padding和stride。那么如何计算呢?这个时候我们就要用到官方文档提供的计算公式了。
2.搭建这个线性层的时候in_feature可能也需要自己算,这个in_feature也可以通过打印flatten来查看
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
Flatten可以通过官方文档中的介绍来使用。
# (batch_size,channels,H,W)=(32, 1, 5, 5)
input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
# torch.Size([32, 25]) batch_size=32
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
# torch.Size([160, 5]) batch_size=160
使用Sequential搭建神经网络
搭建神经网络代码:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
obj = Booze()
print(obj)
'''对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
'''对网络模型进行可视化'''
writer = SummaryWriter("logs")
writer.add_graph(obj,input)
writer.close()
上述代码中也有一些要注意的点,需要单独的拿出来讲讲。
3.搭建完了网络之后,需要对网络结构进行一个简单的检验
obj = Booze()
print(obj)
'''对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
就像上述代码一样,运行之后不会报错就行。
4.网络搭建完了之后,是可以使用tensorboard对网络模型进行可视化的
'''对网络模型进行可视化'''
writer = SummaryWriter("logs")
writer.add_graph(obj,input)
writer.close()
这里用到了add_graph这个方法,具体使用方法可以参考官方文档,其实使用方法和add_images和add_scalar差不多。
显示结果如下:
具体区别
其实看代码就很容易看出来哈。
正常情况:
def __init__(self):
super(Booze, self).__init__()
self.conv1 = Conv2d(3,32,5,padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32,32,5,padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32,64,5,padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
self.linear1 = Linear(1024,64)
self.linear2 = Linear(64,10)
def forward(self,x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
Sequential搭建:
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
```
边栏推荐
- What are permissions? What are roles? What are users?
- LM小型可编程控制器软件(基于CoDeSys)笔记十九:报错does not match the profile of the target
- The junior college students were angry for 32 days, four rounds of interviews, five hours of soul torture, and won Ali's offer with tears
- Daily algorithm & interview questions, 28 days of special training in large factories - the 13th day (array)
- How do I sort a list of strings in dart- How can I sort a list of strings in Dart?
- [pat (basic level) practice] - [simple simulation] 1064 friends
- 2022年上海市安全员C证考试题模拟考试题库及答案
- Pytorch(四) —— 可视化工具 Visdom
- How to ensure the idempotency of the high concurrency interface?
- Measurement of quadrature axis and direct axis inductance of three-phase permanent magnet synchronous motor
猜你喜欢

Dede collection plug-in does not need to write rules

Tip of edge browser: enter+ctrl can automatically convert the address bar into a web address

Programs and processes, process management, foreground and background processes

Basic usage, principle and details of session

2022 G2 power station boiler stoker examination question bank and G2 power station boiler stoker simulation examination question bank
![[learn C and fly] S1E20: two dimensional array](/img/68/34fad73ff23d3e0719ef364fc60cb5.jpg)
[learn C and fly] S1E20: two dimensional array

2022年上海市安全员C证考试题模拟考试题库及答案

Extension fragment
![[godot] unity's animator is different from Godot's animplayer](/img/51/48f40a7b6736d7f78040eabbbd3395.jpg)
[godot] unity's animator is different from Godot's animplayer

尺取法:有效三角形的个数
随机推荐
(12) Somersault cloud case (navigation bar highlights follow)
Daily question - line 10
How to view the changes and opportunities in the construction of smart cities?
CF1638E colorful operations
How to choose the right server for website data collection?
This sideline workload is small, 10-15k, free unlimited massage
Shell之分析服务器日志命令集锦
Tcp/ip explanation (version 2) notes / 3 link layer / 3.4 bridge and switch / 3.4.2 multiple registration protocol (MRP)
Shell之Unix运维常用命令
[ue4] event distribution mechanism of reflective event distributor and active call event mechanism
Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation 阅读笔记
2022 a special equipment related management (elevator) simulation test and a special equipment related management (elevator) certificate examination
All in all, the low code still needs to solve these four problems
Summary of testing experience - Testing Theory
Threejs opening
js 图片路径转换base64格式
Account sharing technology enables the farmers' market and reshapes the efficiency of transaction management services
JVM栈和堆简介
Programs and processes, process management, foreground and background processes
MySQL advanced -- you will have a new understanding of MySQL