当前位置:网站首页>神经网络-使用Sequential搭建神经网络
神经网络-使用Sequential搭建神经网络
2022-07-01 04:35:00 【booze-J】
我们以这个神经网络图为例子,来搭建对比看看正常情况搭建神经网络和使用Sequential搭建神经网络的区别,以及搭建神经网络中一些要注意的点。
正常情况下搭建神经网络
搭建神经网络代码:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
# 1.根据网络图搭建网络的时候,有些参数网络图上没给,是需要自己去计算的,像是padding,stride等等
self.conv1 = Conv2d(3,32,5,padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32,32,5,padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32,64,5,padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
# 2.设置这个线性层的时候in_feature和out_feature可能也需要自己算,这个in_feature也可以通过打印flatten来查看
self.linear1 = Linear(1024,64)
self.linear2 = Linear(64,10)
def forward(self,x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
obj = Booze()
print(obj)
'''3.对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
上述代码中有一些要注意的点,需要单独的拿出来讲讲。
1. 根据网络图搭建网络的时候,有些参数网络图上没给,是需要自己去计算的,像是padding,stride等等
像是搭建第一个卷积层的时候,就需要自己去计算padding和stride。那么如何计算呢?这个时候我们就要用到官方文档提供的计算公式了。
2.搭建这个线性层的时候in_feature可能也需要自己算,这个in_feature也可以通过打印flatten来查看
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
Flatten可以通过官方文档中的介绍来使用。
# (batch_size,channels,H,W)=(32, 1, 5, 5)
input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
# torch.Size([32, 25]) batch_size=32
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
# torch.Size([160, 5]) batch_size=160
使用Sequential搭建神经网络
搭建神经网络代码:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
obj = Booze()
print(obj)
'''对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
'''对网络模型进行可视化'''
writer = SummaryWriter("logs")
writer.add_graph(obj,input)
writer.close()
上述代码中也有一些要注意的点,需要单独的拿出来讲讲。
3.搭建完了网络之后,需要对网络结构进行一个简单的检验
obj = Booze()
print(obj)
'''对网络结构进行一个简单的检验'''
input = torch.ones((64,3,32,32))
output = obj(input)
print(output.shape)
就像上述代码一样,运行之后不会报错就行。
4.网络搭建完了之后,是可以使用tensorboard对网络模型进行可视化的
'''对网络模型进行可视化'''
writer = SummaryWriter("logs")
writer.add_graph(obj,input)
writer.close()
这里用到了add_graph这个方法,具体使用方法可以参考官方文档,其实使用方法和add_images和add_scalar差不多。
显示结果如下:
具体区别
其实看代码就很容易看出来哈。
正常情况:
def __init__(self):
super(Booze, self).__init__()
self.conv1 = Conv2d(3,32,5,padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32,32,5,padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32,64,5,padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
self.linear1 = Linear(1024,64)
self.linear2 = Linear(64,10)
def forward(self,x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
Sequential搭建:
class Booze(nn.Module):
def __init__(self):
super(Booze, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
```
边栏推荐
- Pytest automated testing - compare robotframework framework
- LM小型可编程控制器软件(基于CoDeSys)笔记二十:plc通过驱动器控制步进电机
- How to do the performance pressure test of "Health Code"
- Maixll-Dock 使用方法
- [godot] unity's animator is different from Godot's animplayer
- OSPF notes [multiple access, two multicast addresses with OSPF]
- Advanced application of ES6 modular and asynchronous programming
- Codeworks round 449 (Div. 1) C. Kodori tree template
- 如何看待智慧城市建设中的改变和机遇?
- Pytorch(四) —— 可视化工具 Visdom
猜你喜欢

Cmake selecting compilers and setting compiler options

Question bank and online simulation examination for special operation certificate of G1 industrial boiler stoker in 2022

Simple implementation of slf4j

Strategic suggestions and future development trend of global and Chinese vibration isolator market investment report 2022 Edition

OdeInt與GPU

The index is invalid

Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling

Measurement of quadrature axis and direct axis inductance of three-phase permanent magnet synchronous motor

Dual contractual learning: text classification via label aware data augmentation reading notes

【LeetCode】100. Same tree
随机推荐
Leetcode learning - day 36
LM小型可编程控制器软件(基于CoDeSys)笔记二十:plc通过驱动器控制步进电机
JS rotation chart
Basic usage, principle and details of session
What are permissions? What are roles? What are users?
Knowledge supplement: redis' basic data types and corresponding commands
Introduction of Spock unit test framework and its practice in meituan optimization___ Chapter I
Mallbook: how can hotel enterprises break the situation in the post epidemic era?
什么是权限?什么是角色?什么是用户?
2022 hoisting machinery command registration examination and hoisting machinery command examination registration
Dede collection plug-in does not need to write rules
2022危险化学品生产单位安全生产管理人员题库及答案
After many job hopping, the monthly salary is equal to the annual salary of old colleagues
Pytorch(四) —— 可视化工具 Visdom
Advanced application of ES6 modular and asynchronous programming
Simple implementation of slf4j
TCP/IP 详解(第 2 版) 笔记 / 3 链路层 / 3.4 桥接器与交换机 / 3.4.2 多属性注册协议(Multiple Registration Protocol (MRP))
Use winmtr software to simply analyze, track and detect network routing
OdeInt与GPU
Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling