当前位置:网站首页>Xception for DeepLab v3+ (with detailed code comments and the original figures from the paper)
Xception for DeepLab v3+ (with detailed code comments and the original figures from the paper)
2022-07-03 18:21:00 【ZRX_ GIS】
import torch
import torchvision.models as models
from torch import nn
# Depthwise separable convolution
class SeparableConv2d_same(nn.Module):
    """Depthwise-separable 2-D convolution with TensorFlow-style "same" padding.

    The module is a depthwise ``kernel_size`` x ``kernel_size`` convolution
    (one filter per input channel, ``groups=inplanes``) followed by a 1x1
    pointwise convolution that mixes channels and maps ``inplanes`` to
    ``planes``.

    Both convolutions are built with ``padding=0``; ``forward`` pads the
    input explicitly instead. As a result, a stride-1 call keeps the spatial
    size for any ``dilation``. A strided call produces the same spatial size
    as a 1x1 convolution with that stride. Both properties are required by
    the residual addition in ``Block``.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: spatial size of the depthwise kernel.
        stride: stride of the depthwise convolution.
        dilation: dilation (atrous rate) of the depthwise convolution.
        bias: whether both convolutions use a bias term.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(SeparableConv2d_same, self).__init__()
        # Depthwise convolution: one kernel per input channel.
        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
                               groups=inplanes, bias=bias)
        # Pointwise (1x1) convolution: mixes channels, inplanes -> planes.
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    @staticmethod
    def _fixed_padding(x, kernel_size, dilation):
        """Pad the last two dims of ``x`` so a conv with this kernel/dilation acts as "same"."""
        # A dilated kernel covers kernel_size + (kernel_size - 1) * (dilation - 1) pixels.
        effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = effective_kernel - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        # pad tuple order is (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(x, (pad_beg, pad_end, pad_beg, pad_end))

    def forward(self, x):
        # Pad first so that the unpadded depthwise conv does not shrink the feature map.
        x = self._fixed_padding(x, self.conv1.kernel_size[0], self.conv1.dilation[0])
        # Depthwise convolution
        x = self.conv1(x)
        # Pointwise convolution
        x = self.pointwise(x)
        return x
# Xception backbone: 20 residual blocks plus 5 standalone conv layers (2 at the start, 3 at the end)
class Block(nn.Module):
    """Xception residual block: a stack of ReLU -> separable conv -> BN units plus a shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        reps: number of ReLU / separable-conv / BN units in the main branch.
        stride: 1 keeps the resolution. Any other value appends a final
            separable conv with that stride to the main branch and gives
            the shortcut a 1x1 conv with the same stride.
        dilation: atrous rate of the separable convs in the repeated units.
        start_with_relu: if False, the leading ReLU of the main branch is
            dropped (used when the input has just passed through a ReLU).
        grow_first: if True, the first unit changes the channel count
            (inplanes -> planes); otherwise the last repeated unit does.
        is_last: if True and ``stride == 1``, append one more stride-1
            separable conv at the end of the main branch.
        start_with_rule: deprecated, misspelled alias of ``start_with_relu``.
            It is kept so existing keyword callers keep working and, when
            given, it overrides ``start_with_relu``.
    """

    def __init__(self, inplanes, planes, reps, stride=1, dilation=1,
                 start_with_relu=True, grow_first=True, is_last=False,
                 start_with_rule=None):
        super(Block, self).__init__()
        if start_with_rule is not None:
            start_with_relu = start_with_rule

        # Shortcut branch: a 1x1 conv + BN whenever the channel count or the
        # resolution changes; identity otherwise.
        if planes != inplanes or stride != 1:
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm2d(planes)
        else:
            self.skip = None

        # One ReLU module reused at every position of the main branch.
        # inplace=True overwrites its input instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = inplanes  # channel count flowing through the main branch
        if grow_first:
            # First unit widens inplanes -> planes.
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(filters, planes, 3, stride=1, dilation=dilation))
            rep.append(nn.BatchNorm2d(planes))
            filters = planes

        # reps - 1 units that keep the channel count unchanged.
        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            # Last unit widens inplanes -> planes.
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
            rep.append(nn.BatchNorm2d(planes))

        # Optionally drop the leading ReLU of the main branch.
        if not start_with_relu:
            rep = rep[1:]

        if stride != 1:
            # Downsampling separable conv. It uses the same stride as the
            # shortcut so both branches end at the same spatial size.
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=stride))

        if stride == 1 and is_last:
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))

        self.rep = nn.Sequential(*rep)

    def forward(self, input):
        """Return ``rep(input) + shortcut(input)``.

        NOTE(review): when the main branch starts with the shared
        ``ReLU(inplace=True)``, that ReLU overwrites ``input`` before the
        shortcut reads it. The shortcut therefore sees ``relu(input)``, and
        the caller's tensor is modified. For the same reason, ``input``
        cannot be a leaf tensor that requires grad. This is kept as-is so
        the network's outputs do not change.
        """
        x = self.rep(input)
        # Shortcut: projected (1x1 conv + BN) when shapes change, identity otherwise.
        if self.skip is not None:
            skip = self.skipbn(self.skip(input))
        else:
            skip = input
        x += skip
        return x
class Xception(nn.Module):
    """Xception-style backbone used as the DeepLab v3+ encoder.

    Layout:
      - Entry flow: two plain convs, then block1-block3.
      - Middle flow: block4-block19, sixteen identical 728-channel blocks.
      - Exit flow: block20 plus three separable convs (conv3-conv5).

    Args:
        inplanes: number of channels of the input image.
        output_stride: ratio of the input resolution to the resolution of the
            returned high-level feature map. Accepted values:
            - 16 (default): the original hard-coded configuration.
            - 8: block3 no longer downsamples, and the middle and exit blocks
              use larger atrous rates instead.

    Raises:
        ValueError: if ``output_stride`` is not 8 or 16.
    """

    # Indices of the middle-flow blocks; registered as attributes block4 ... block19.
    _middle_block_ids = tuple(range(4, 20))

    def __init__(self, inplanes=3, output_stride=16):
        super(Xception, self).__init__()
        if output_stride == 16:
            entry_block3_stride = 2        # block3 halves the resolution
            middle_block_dilation = 1      # atrous rate of block4-19
            exit_block_dilations = (1, 2)  # atrous rates of block20 / conv3-5
        elif output_stride == 8:
            entry_block3_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            raise ValueError(f"output_stride must be 8 or 16, got {output_stride}")

        # Entry flow: two plain convs, then block1-3.
        self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # block1 starts without ReLU because conv2's output has just been ReLU'd.
        self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True,
                            grow_first=True, is_last=True)

        # Middle flow: 16 identical residual blocks. They keep their original
        # attribute names (block4 ... block19) so state_dict keys are unchanged.
        for idx in self._middle_block_ids:
            self.add_module(
                f"block{idx}",
                Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                      start_with_relu=True, grow_first=True),
            )

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
                             start_with_relu=True, grow_first=False, is_last=True)
        self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1])
        self.bn3 = nn.BatchNorm2d(1536)
        self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1])
        self.bn4 = nn.BatchNorm2d(1536)
        self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1])
        self.bn5 = nn.BatchNorm2d(2048)

    def forward(self, x):
        """Run the backbone.

        Returns:
            A tuple ``(x, low_level_feat)``:
            - ``x``: the 2048-channel output of conv5/bn5/ReLU.
            - ``low_level_feat``: the 128-channel output of block1, after
              two stride-2 stages (conv1 and block1), i.e. roughly 1/4 of
              the input resolution.

            NOTE(review): ``low_level_feat`` is the same tensor object that
            is fed to block2, whose main branch begins with an in-place ReLU.
            The returned tensor is therefore ReLU-activated.
        """
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.block1(x)
        low_level_feat = x  # roughly 1/4 of the input resolution
        x = self.block2(x)
        x = self.block3(x)
        # Middle flow
        for idx in self._middle_block_ids:
            x = getattr(self, f"block{idx}")(x)
        # Exit flow
        x = self.block20(x)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        return x, low_level_feat

边栏推荐
- Redis core technology and practice - learning notes (VI) how to achieve data consistency between master and slave Libraries
- Use of unsafe class
- Lesson 13 of the Blue Bridge Cup -- tree array and line segment tree [exercise]
- [combinatorics] generating function (shift property)
- The number of incremental paths in the grid graph [dfs reverse path + memory dfs]
- What problems can cross-border e-commerce sellers solve with multi platform ERP management system
- Image 24 bits de profondeur à 8 bits de profondeur
- 模块九作业
- What kind of experience is it when the Institute earns 20000 yuan a month?
- Line by line explanation of yolox source code of anchor free series network (6) -- mixup data enhancement
猜你喜欢

Analysis of the reasons why enterprises build their own software development teams to use software manpower outsourcing services at the same time

BFS - topology sort

2022-2028 global scar care product industry research and trend analysis report

Sensor debugging process

On Data Mining

How to deploy applications on kubernetes cluster
![[combinatorics] generating function (generating function application scenario | using generating function to solve recursive equation)](/img/e6/9880e708aed42dc82c94aea002020c.jpg)
[combinatorics] generating function (generating function application scenario | using generating function to solve recursive equation)

Valentine's day, send you a little red flower~

Investigation on the operation prospect of the global and Chinese Anti enkephalinase market and analysis report on the investment strategy of the 14th five year plan 2022-2028

Line by line explanation of yolox source code of anchor free series network (5) -- mosaic data enhancement and mathematical understanding
随机推荐
WebView module manages the application window interface to realize the logical control and management operation of multiple windows (Part 1)
Redis cache avalanche, penetration, breakdown
PHP MySQL inserts data
win32:堆破坏的dump文件分析
[combinatorics] generating function (use generating function to solve the combination number of multiple sets R)
How to deploy applications on kubernetes cluster
Postfix tips and troubleshooting commands
Bloom filter [proposed by Burton Bloom in 1970; a solution to Redis cache penetration]
模块九作业
2022-2028 global scar care product industry research and trend analysis report
Administrative division code acquisition
Data analysis is popular on the Internet, and the full version of "Introduction to data science" is free to download
English grammar_ Noun classification
Kotlin's collaboration: Context
win32:堆破壞的dump文件分析
PHP MySQL order by keyword
English grammar_ Adjective / adverb Level 3 - multiple expression
Unsafe类的使用
The number of incremental paths in the grid graph [dfs reverse path + memory dfs]
SDNUOJ1015