当前位置:网站首页>卷积神经网络模型之——VGG-16网络结构与代码实现
卷积神经网络模型之——VGG-16网络结构与代码实现
2022-07-25 12:43:00 【1 + 1=王】
VGGNet简介
VGG原文:Very deep convolutional networks for large-scale image recognition:https://arxiv.org/pdf/1409.1556.pdf

VGG在2014年由牛津大学Visual GeometryGroup提出,获得该年lmageNet竞赛中Localization Task(定位任务)第一名和 Classification Task (分类任务)第二名。
VGG与AlexNet相比,它采用几个连续的3x3的卷积核代替AlexNet中的较大卷积核。
在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5*5卷积核,从而在保证具有相同感知野的条件下,提升了网络的深度,在一定程度上提升了神经网络的效果。
在论文中,作者尝试了使用5种不同的网络结构,深度分别为11,11,13,16,19,5种结构图如下所示:
其中最常用的是VGG16和VGG19,下面我们就以VGG16为例来分析它的网络结构。
VGG16网络结构
VGG16中的16指的是它由16层组成(13个卷积层 + 3个全连接层,不包括池化层)。
VGG的输入图像大小为224X224X3的三通道彩色图像,共有1000个类别。
其中卷积层的卷积核大小都为3,padding为1;池化层的kernel_size为2,stride为2。
因此
- 卷积层只改变特征图的通道数,不改变大小。(W - 3 + 2*1)/ 1 + 1 = W
- 池化层不改变特征图的通道数,大小变为原来的一半。
VGG具有明显的块结构,VGG可以分为如下六块:
- 两个卷积 + 一个池化:conv3-64+conv3-64 + maxpool
- 两个卷积 + 一个池化:conv3-128+conv3-128+ maxpool
- 三个卷积 + 一个池化:conv3-256+conv3-256+conv3-256+ maxpool
- 三个卷积 + 一个池化:conv3-512+conv3-512+conv3-512+ maxpool
- 三个卷积 + 一个池化:conv3-512+conv3-512+conv3-512+ maxpool
- 三个全连接:fc-4096 + fc-4096 + fc-1000(对应1000个类别)
使用pytorch搭建VGG16
为了便于理解,我们把正向传播过程分为两块,
- 一块为特征提取层(features),包括13个卷积层;
- 另一块为分类层(classify),包括3个全连接层。
features
def make_features(self):
cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool', 512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']
layers = []
in_channel = 3
for cfg in cfgs:
if cfg == "MaxPool": # 池化层
layers += [nn.MaxPool2d(kernel_size=2,stride=2)]
else:
layers += [nn.Conv2d(in_channels=in_channel,out_channels=cfg,kernel_size=3,padding=1)]
layers += [nn.ReLU(True)]
in_channel = cfg
return nn.Sequential(*layers)
classifier
【注意】:在进行全连接之前,需要现将卷积层输出的三维特征图展平为1维。
x = torch.flatten(x,start_dim=1)
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 1000)
)
完整代码
""" #-*-coding:utf-8-*- # @author: wangyu a beginner programmer, striving to be the strongest. # @date: 2022/7/1 15:01 """
import torch
import torch.nn as nn
class VGG(nn.Module):
def __init__(self):
super(VGG, self).__init__()
self.features = self.make_features()
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(p=0.5),
nn.Linear(4096, 1000)
)
def forward(self,x):
x = self.features(x)
x = torch.flatten(x,start_dim=1)
x = self.classifier(x)
return x
def make_features(self):
cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool', 512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']
layers = []
in_channel = 3
for cfg in cfgs:
if cfg == "MaxPool": # 池化层
layers += [nn.MaxPool2d(kernel_size=2,stride=2)]
else:
layers += [nn.Conv2d(in_channels=in_channel,out_channels=cfg,kernel_size=3,padding=1)]
layers += [nn.ReLU(True)]
in_channel = cfg
return nn.Sequential(*layers)
net = VGG()
print(net)

边栏推荐
- 感动中国人物刘盛兰
- Plus SBOM: assembly line BOM pbom
- 全球都热炸了,谷歌服务器已经崩掉了
- web安全入门-UDP测试与防御
- A turbulent life
- Cmake learning notes (II) generation and use of Library
- Eccv2022 | transclassp class level grab posture migration
- 【问题解决】org.apache.ibatis.exceptions.PersistenceException: Error building SqlSession.1 字节的 UTF-8 序列的字
- 2022.07.24 (lc_6124_the first letter that appears twice)
- B tree and b+ tree
猜你喜欢

微软提出CodeT:代码生成新SOTA,20个点的性能提升
![Detailed explanation of switch link aggregation [Huawei ENSP]](/img/34/dff118b52404e35f74a8f06b2517be.png)
Detailed explanation of switch link aggregation [Huawei ENSP]
What does the software testing process include? What are the test methods?

程序的内存布局

Kyligence was selected into Gartner 2022 data management technology maturity curve report

web安全入门-UDP测试与防御

Clickhouse notes 03-- grafana accesses Clickhouse

【运维、实施精品】月薪10k+的技术岗位面试技巧

2022.07.24 (lc_6126_design food scoring system)

2022.07.24 (lc_6125_equal row and column pairs)
随机推荐
【AI4Code】《GraphCodeBERT: Pre-Training Code Representations With DataFlow》 ICLR 2021
If you want to do a good job in software testing, you can first understand ast, SCA and penetration testing
MySQL implements inserting data from one table into another table
The larger the convolution kernel, the stronger the performance? An interpretation of replknet model
SSTI 模板注入漏洞总结之[BJDCTF2020]Cookie is so stable
【AI4Code】《Contrastive Code Representation Learning》 (EMNLP 2021)
7行代码让B站崩溃3小时,竟因“一个诡计多端的0”
软件测试流程包括哪些内容?测试方法有哪些?
How to understand metrics in keras
Seven lines of code made station B crash for three hours, but "a scheming 0"
[shutter -- layout] stacked layout (stack and positioned)
使用vsftpd服务传输文件(匿名用户认证、本地用户认证、虚拟用户认证)
AtCoder Beginner Contest 261 F // 树状数组
Eccv2022 | transclassp class level grab posture migration
ORAN专题系列-21:主要的玩家(设备商)以及他们各自的态度、擅长领域
B tree and b+ tree
Leetcode 0133. clone diagram
Use of hystrix
yum和vim须掌握的常用操作
【OpenCV 例程 300篇】239. Harris 角点检测之精确定位(cornerSubPix)