Convolutional Neural Network Models: VGG-16 Network Architecture and Code Implementation
2022-07-25 12:43:00 【1 + 1=王】
Introduction to VGGNet
Original VGG paper: Very Deep Convolutional Networks for Large-Scale Image Recognition: https://arxiv.org/pdf/1409.1556.pdf

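VGG was proposed in 2014 by the Visual Geometry Group at the University of Oxford. In that year's ImageNet competition it took first place in the Localization Task and second place in the Classification Task.
Compared with AlexNet, VGG replaces AlexNet's larger convolution kernels with several consecutive 3x3 kernels.
In VGG, three stacked 3x3 kernels take the place of one 7x7 kernel, and two stacked 3x3 kernels take the place of one 5x5 kernel. The receptive field stays the same while the network becomes deeper, which improves its performance to some extent.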
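The receptive-field claim can be checked with a little arithmetic. The sketch below is illustrative only: the helper names and the channel width C = 64 are assumptions, not taken from the article. It also shows a side benefit noted in the VGG paper: for the same receptive field, the stacked 3x3 layers need fewer weights than one large kernel.

```python
def stacked_receptive_field(num_layers: int, kernel_size: int = 3) -> int:
    """Receptive field of `num_layers` stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)


def conv_weights(kernel_size: int, channels: int, num_layers: int = 1) -> int:
    """Weight count (biases ignored) when every layer has `channels` inputs and outputs."""
    return num_layers * kernel_size * kernel_size * channels * channels


C = 64  # hypothetical channel width, chosen only for illustration

print(stacked_receptive_field(2))  # 5, same as one 5x5 kernel
print(stacked_receptive_field(3))  # 7, same as one 7x7 kernel
print(conv_weights(3, C, num_layers=3), "vs", conv_weights(7, C))  # 110592 vs 200704
```

With equal input and output channels, three 3x3 layers use 27·C² weights against 49·C² for a single 7x7 layer.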
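In the paper, the authors tried six network configurations (A, A-LRN, B, C, D and E) with depths of 11, 11, 13, 16, 16 and 19 weight layers respectively; they are listed in Table 1 of the paper.
The most commonly used are VGG16 and VGG19. Below we take VGG16 as the example and analyze its network structure.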
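VGG16 network structure
The 16 in VGG16 refers to its 16 weight layers (13 convolutional layers + 3 fully connected layers; pooling layers are not counted).
The input to VGG is a 224x224x3 three-channel color image, and the output covers 1000 classes.
Every convolutional layer uses a kernel size of 3, a padding of 1 and a stride of 1; every pooling layer uses a kernel_size of 2 and a stride of 2.
Therefore:
- A convolutional layer changes only the number of channels of the feature map, not its size: (W - 3 + 2*1) / 1 + 1 = W
- A pooling layer leaves the number of channels unchanged and halves the height and width.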
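The two rules above both follow from the usual output-size formula floor((W - K + 2P) / S) + 1. A minimal sketch, where the function name `output_size` is a hypothetical helper introduced only for this illustration:

```python
def output_size(w: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a conv or pooling layer: floor((W - K + 2P) / S) + 1."""
    return (w - kernel_size + 2 * padding) // stride + 1


# 3x3 convolution, padding 1, stride 1: the size is preserved
conv_out = output_size(224, kernel_size=3, padding=1, stride=1)
# 2x2 max pooling, stride 2, no padding: the size is halved
pool_out = output_size(224, kernel_size=2, padding=0, stride=2)

print(conv_out, pool_out)  # 224 112
```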
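VGG has an obvious block structure and can be divided into the following six blocks:
- Two convolutions + one pooling: conv3-64 + conv3-64 + maxpool
- Two convolutions + one pooling: conv3-128 + conv3-128 + maxpool
- Three convolutions + one pooling: conv3-256 + conv3-256 + conv3-256 + maxpool
- Three convolutions + one pooling: conv3-512 + conv3-512 + conv3-512 + maxpool
- Three convolutions + one pooling: conv3-512 + conv3-512 + conv3-512 + maxpool
- Three fully connected layers: fc-4096 + fc-4096 + fc-1000 (one output for each of the 1000 classes)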
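To see how a 224x224x3 input moves through the five convolutional blocks, the shapes can be traced by hand. The sketch below assumes only the rules stated earlier (convolutions change the channel count, each max pool halves height and width); the variable names are illustrative.

```python
# (number of conv layers, output channels) for the five convolutional blocks
blocks = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]

size, channels = 224, 3  # input: 224x224 RGB image
shapes = []
for num_convs, out_channels in blocks:
    channels = out_channels  # 3x3 conv with padding 1: only the channels change
    size //= 2               # 2x2 max pool with stride 2: halves H and W
    shapes.append((channels, size, size))
    print(f"{num_convs} x conv3-{out_channels} + maxpool -> {shapes[-1]}")

flattened = channels * size * size
print("flattened:", flattened)  # flattened: 25088
```

The final feature map is 512x7x7, which is why the first fully connected layer takes 512 * 7 * 7 = 25088 inputs.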
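Building VGG16 with PyTorch
To make it easier to follow, we split the forward pass into two parts:
- the feature extraction layers (features), containing the 13 convolutional layers together with the pooling layers;
- the classification layers (classifier), containing the 3 fully connected layers.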
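As a rough sanity check on this split, the learnable parameters of each part can be counted in plain Python. This is a sketch under the layer sizes described in this article (3x3 kernels with biases, a 512x7x7 feature map feeding the first fully connected layer); the variable names are illustrative.

```python
cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool',
        512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']

conv_params, conv_layers, in_channel = 0, 0, 3
for cfg in cfgs:
    if cfg == 'MaxPool':
        continue  # pooling layers have no learnable parameters
    conv_params += in_channel * cfg * 3 * 3 + cfg  # 3x3 weights + biases
    conv_layers += 1
    in_channel = cfg

fc_sizes = [(512 * 7 * 7, 4096), (4096, 4096), (4096, 1000)]
fc_params = sum(n_in * n_out + n_out for n_in, n_out in fc_sizes)  # weights + biases

print(conv_layers, conv_params)  # 13 14714688
print(len(fc_sizes), fc_params)  # 3 123642856
print(conv_params + fc_params)   # 138357544
```

Most of the roughly 138 million parameters sit in the three fully connected layers, not in the 13 convolutional layers.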
features
def make_features(self):
    # Integers are conv output channels; 'MaxPool' marks a pooling layer
    cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool', 512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']
    layers = []
    in_channel = 3
    for cfg in cfgs:
        if cfg == "MaxPool":  # pooling layer
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:  # convolutional layer followed by ReLU
            layers += [nn.Conv2d(in_channels=in_channel, out_channels=cfg, kernel_size=3, padding=1)]
            layers += [nn.ReLU(True)]
            in_channel = cfg
    return nn.Sequential(*layers)
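A quick way to check the feature extractor is to push a dummy image through it. The sketch below rewrites the method above as a standalone function (an adaptation made only so the example runs by itself) and assumes a 224x224 RGB input.

```python
import torch
import torch.nn as nn


def make_features() -> nn.Sequential:
    """Standalone variant of the make_features method, for a quick shape check."""
    cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool',
            512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']
    layers = []
    in_channel = 3
    for cfg in cfgs:
        if cfg == "MaxPool":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channel, cfg, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channel = cfg
    return nn.Sequential(*layers)


features = make_features()
x = torch.zeros(1, 3, 224, 224)  # dummy batch: one 224x224 RGB image
with torch.no_grad():            # no gradients needed for a shape check
    out = features(x)

num_convs = sum(isinstance(m, nn.Conv2d) for m in features)
print(out.shape)   # torch.Size([1, 512, 7, 7])
print(num_convs)   # 13
```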
classifier
[Note]: Before the fully connected layers, the three-dimensional feature map output by the convolutional layers must first be flattened to one dimension.
x = torch.flatten(x, start_dim=1)
self.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.ReLU(True),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096),
    nn.ReLU(True),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 1000)
)
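The effect of start_dim=1 is easiest to see on a dummy tensor. In this small sketch the batch size of 2 is an arbitrary assumption: dimension 0 (the batch) is kept, and the channel, height and width dimensions are merged into one.

```python
import torch

x = torch.zeros(2, 512, 7, 7)         # dummy feature maps for a batch of 2 images
flat = torch.flatten(x, start_dim=1)  # keep dim 0 (batch), merge C, H and W

print(flat.shape)  # torch.Size([2, 25088])
```

Each row of the result has 512 * 7 * 7 = 25088 values, matching the input size of the first nn.Linear layer.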
Complete code
"""
# -*- coding: utf-8 -*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2022/7/1 15:01
"""
import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        # Feature extraction: 13 convolutional layers plus 5 max-pooling layers
        self.features = self.make_features()
        # Classification: 3 fully connected layers
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 1000)
        )

    def forward(self, x):
        x = self.features(x)
        # Flatten everything except the batch dimension
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def make_features(self):
        # Integers are conv output channels; 'MaxPool' marks a pooling layer
        cfgs = [64, 64, 'MaxPool', 128, 128, 'MaxPool', 256, 256, 256, 'MaxPool', 512, 512, 512, 'MaxPool', 512, 512, 512, 'MaxPool']
        layers = []
        in_channel = 3
        for cfg in cfgs:
            if cfg == "MaxPool":  # pooling layer
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:  # convolutional layer followed by ReLU
                layers += [nn.Conv2d(in_channels=in_channel, out_channels=cfg, kernel_size=3, padding=1)]
                layers += [nn.ReLU(True)]
                in_channel = cfg
        return nn.Sequential(*layers)


net = VGG()
print(net)
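One practical point when using this network: the classifier contains nn.Dropout layers, which behave differently in training and evaluation mode, so net.eval() is normally called before inference. The sketch below isolates that behavior on a single Dropout layer; the input of ones is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()         # training mode: elements are zeroed at random
y_train = drop(x)    # survivors are scaled by 1 / (1 - p) = 2.0

drop.eval()          # evaluation mode: dropout is a no-op
y_eval = drop(x)

print(y_train)       # random mix of 0.0 and 2.0
print(y_eval)        # all ones, identical to the input
```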
