Convolutional Neural Network Models: LeNet Network Structure and Code Implementation
2022-07-04 21:39:00 【1 + 1=王】
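Introduction to LeNet
Original LeNet paper: https://ieeexplore.ieee.org/abstract/document/726791
LeNet can be called the "Hello World" of convolutional neural networks. It extracts features with convolution and pooling operations and then classifies them with a fully connected neural network.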
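LeNet is a 7-layer neural network (not counting the input layer) made up of 3 convolutional layers, 2 pooling layers, and 2 fully connected layers. Each layer is described below.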
LeNet's 7-Layer Structure
Layer 0: Input
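The input is a 3-channel image of 32×32 pixels. (The original LeNet-5 paper uses a single-channel 32×32 input; the layer sizes in this article follow the 3-channel variant implemented below.)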
C1: First Convolutional Layer
C1 is a convolutional layer with a 5×5 kernel. Its input is 3×32×32 and its output feature map is 16×28×28.
self.conv1 = nn.Conv2d(3, 16, 5)
The parameters of torch.nn.Conv2d() are as follows:
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
- in_channels (int) – Number of channels in the input image
- out_channels (int) – Number of channels produced by the convolution
- kernel_size (int or tuple) – Size of the convolving kernel
- stride (int or tuple, optional) – Stride of the convolution. Default: 1
- padding (int, tuple or str, optional) – Padding added to all four sides of the input. Default: 0
- padding_mode (string, optional) – ‘zeros’, ‘reflect’, ‘replicate’ or ‘circular’. Default: ‘zeros’
- dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
- groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
- bias (bool, optional) – If True, adds a learnable bias to the output. Default: True
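The output size after convolution is computed per spatial dimension as follows:
out = floor((in + 2 × padding − dilation × (kernel_size − 1) − 1) / stride) + 1
For C1, with the default stride, padding, and dilation: (32 + 0 − 4 − 1) / 1 + 1 = 28.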
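The sizes quoted for C1 can be checked with a few lines of plain Python. The sketch below applies the output-shape rule from the PyTorch Conv2d documentation to a single spatial dimension; the helper name conv_out_size is an assumption made for illustration and is not part of PyTorch.

```python
def conv_out_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of a convolution.

    Hypothetical helper implementing
    floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1.
    """
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# C1: 32x32 input, 5x5 kernel, default stride and padding
c1 = conv_out_size(32, 5)
print(c1)  # 28
```

With padding=2 the same 5×5 kernel would keep a 32×32 input at 32×32, which is why unpadded LeNet-style layers shrink the feature map by 4 pixels per convolution.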
S2: First Downsampling Layer
S2 is a pooling layer with kernel_size 2 and stride 2. Its input is 16×28×28 and its output feature map is 16×14×14.
self.pool1 = nn.MaxPool2d(2, 2)
The parameters of torch.nn.MaxPool2d are as follows:
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
- kernel_size – the size of the window to take a max over
- stride – the stride of the window. Default value is kernel_size
- padding – implicit zero padding to be added on both sides
- dilation – a parameter that controls the stride of elements in the window
- return_indices – if True, will return the max indices along with the outputs. Useful for torch.nn.MaxUnpool2d later
- ceil_mode – when True, will use ceil instead of floor to compute the output shape
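The output size after pooling is computed per spatial dimension as follows:
out = floor((in + 2 × padding − dilation × (kernel_size − 1) − 1) / stride) + 1
(ceil replaces floor when ceil_mode=True). For S2: (28 + 0 − 1 − 1) / 2 + 1 = 14.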
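The pooling sizes can be checked in the same spirit. This is a minimal sketch of the MaxPool2d output-shape rule for one spatial dimension, including the stride=None default and the ceil_mode switch described in the parameter list above. The helper name pool_out_size is hypothetical, and the sketch ignores PyTorch's extra rule that drops a last window starting entirely in the right padding.

```python
import math


def pool_out_size(in_size, kernel_size, stride=None, padding=0,
                  dilation=1, ceil_mode=False):
    """Output length along one spatial dimension of a max-pooling layer.

    Hypothetical helper; stride defaults to kernel_size, as in MaxPool2d.
    """
    if stride is None:
        stride = kernel_size
    span = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
    rounding = math.ceil if ceil_mode else math.floor
    return rounding(span / stride) + 1


s2 = pool_out_size(28, 2, 2)  # S2: 28 -> 14
s4 = pool_out_size(10, 2, 2)  # S4: 10 -> 5
print(s2, s4)  # 14 5
```

For even input sizes such as 28 and 10 the two rounding modes agree; they differ only when the window does not tile the input exactly, for example an input of size 5 with a 2×2 window.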
C3: Second Convolutional Layer
C3 is a convolutional layer with a 5×5 kernel. Its input is 16×14×14 and its output feature map is 32×10×10.
self.conv2 = nn.Conv2d(16, 32, 5)
S4: Second Downsampling Layer
S4 is a pooling layer with kernel_size 2 and stride 2. Its input is 32×10×10 and its output feature map is 32×5×5.
self.pool2 = nn.MaxPool2d(2, 2)
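C5: Third Convolutional Layer
C5 is a convolutional layer with a 5×5 kernel. Its input is 32×5×5 and its output feature map is 120×1×1.
It is implemented here as a fully connected layer instead: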
self.fc1 = nn.Linear(32*5*5, 120)
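One way to see why a fully connected layer can stand in for C5: a 5×5 kernel has exactly one valid position on a 5×5 input, so each of the 120 outputs is a weighted sum of all 32×5×5 input values plus a bias, which is what a linear layer computes. The sketch below compares the parameter counts of the two formulations; both function names are hypothetical helpers, not PyTorch APIs.

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Learnable parameters of a square-kernel 2D convolution (hypothetical helper)."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def linear_param_count(in_features, out_features, bias=True):
    """Learnable parameters of a fully connected layer (hypothetical helper)."""
    return in_features * out_features + (out_features if bias else 0)


conv_params = conv2d_param_count(32, 120, 5)     # C5 written as a convolution
fc_params = linear_param_count(32 * 5 * 5, 120)  # C5 written as nn.Linear
print(conv_params, fc_params)  # 96120 96120
```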
F6: First Fully Connected Layer
F6 is a fully connected layer with input size 120 and output size 84.
self.fc2 = nn.Linear(120, 84)
F7: Second Fully Connected Layer
F7 is a fully connected layer with input size 84 and output size 10 (one output per class, for 10 classes).
self.fc3 = nn.Linear(84, 10)
Building LeNet with PyTorch
Building a network model takes at least two steps:
1. Create a class that inherits from nn.Module
import torch.nn as nn
# PyTorch tensor dimension order: [N, Channel, Height, Width]
class LeNet(nn.Module):
2. Implement two methods in the class
def __init__(self): defines the layers the network needs
def forward(self, x): defines the forward pass
PyTorch implementation of the LeNet network structure:
import torch.nn as nn
import torch.nn.functional as F  # provides F.relu used in forward()


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)       # C1
        self.pool1 = nn.MaxPool2d(2, 2)        # S2
        self.conv2 = nn.Conv2d(16, 32, 5)      # C3
        self.pool2 = nn.MaxPool2d(2, 2)        # S4
        self.fc1 = nn.Linear(32*5*5, 120)      # C5 (implemented as a fully connected layer)
        self.fc2 = nn.Linear(120, 84)          # F6
        self.fc3 = nn.Linear(84, 10)           # F7

    def forward(self, x):
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # flatten to (N, 32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x
Print the LeNet structure as follows:
model = LeNet()
print(model)
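Beyond printing the structure, a quick sanity check is to push a dummy batch through the network and confirm the output shape. The sketch below assumes PyTorch is installed and rebuilds the same layer stack with nn.Sequential purely so that the snippet runs on its own; the batch size of 4 and the random input are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Same layer stack as the LeNet class in this article, written with
# nn.Sequential so that this snippet is self-contained.
model = nn.Sequential(
    nn.Conv2d(3, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),   # C1, S2
    nn.Conv2d(16, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # C3, S4
    nn.Flatten(),                                         # (N, 32*5*5)
    nn.Linear(32 * 5 * 5, 120), nn.ReLU(),                # C5 as a linear layer
    nn.Linear(120, 84), nn.ReLU(),                        # F6
    nn.Linear(84, 10),                                    # F7
)

x = torch.randn(4, 3, 32, 32)  # dummy batch in [N, C, H, W] order
with torch.no_grad():          # no gradients needed for a shape check
    logits = model(x)
print(logits.shape)  # torch.Size([4, 10])

# Total number of learnable parameters across all layers
num_params = sum(p.numel() for p in model.parameters())
print(num_params)
```

Each row of the output holds 10 unnormalized class scores, one per class. Feeding an input that is not 32×32 would change the flattened size and make the first linear layer fail, so the input resolution is effectively fixed by the 32*5*5 in fc1.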