当前位置:网站首页>Yolov3 network model building
Yolov3 network model building
2022-07-03 19:03:00 【The little girl is so cute】
The structure diagram was copied from another author's post.
import torch
import torch.nn as nn
from collections import OrderedDict
class CBL(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

    Args:
        channel_in: number of input channels.
        channel_out: number of output channels.
        ks: convolution kernel size.
        p: zero-padding passed to ``nn.Conv2d`` (0 for a 1x1 kernel, 1 for a
            3x3 kernel keeps the spatial size at stride 1).
        strides: convolution stride, forwarded to ``nn.Conv2d``. The default
            ``(1, 1)`` leaves the spatial size governed only by ``ks``/``p``.
    """

    def __init__(self, channel_in, channel_out, ks, p=1, strides=(1, 1)):
        super(CBL, self).__init__()
        # Layer order (block.0 / block.1 / block.2) is kept so state_dict keys
        # stay the same.
        self.block = nn.Sequential(
            nn.Conv2d(channel_in, channel_out, ks, stride=strides, padding=p),
            nn.BatchNorm2d(channel_out),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        """Apply convolution, batch normalization and leaky ReLU to ``x``."""
        return self.block(x)
class ResidualBlock(nn.Module):
    """Residual block: three conv-BN-LeakyReLU layers plus an identity shortcut.

    Layers:
        1x1 conv ``inp -> planes[0]`` (reduce the channel count),
        3x3 conv ``planes[0] -> planes[1]`` (feature extraction),
        1x1 conv ``planes[1] -> planes[1]``.
    The block input is then added element-wise to the result.

    Args:
        inp: number of input channels.
        planes: ``[bottleneck_channels, output_channels]``. Because the input
            is added to the output, ``planes[1]`` must equal ``inp``.

    Raises:
        ValueError: if ``planes[1] != inp``. The shortcut add could never
            succeed, so this fails at construction rather than on the first
            forward pass.
    """

    def __init__(self, inp, planes):
        super(ResidualBlock, self).__init__()
        if planes[1] != inp:
            raise ValueError(
                "ResidualBlock needs planes[1] == inp for the shortcut add, "
                "got inp={} and planes={}".format(inp, planes))
        self.conv1 = nn.Sequential(
            nn.Conv2d(inp, planes[0], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(planes[0]),
            nn.LeakyReLU(0.1))
        self.conv2 = nn.Sequential(
            nn.Conv2d(planes[0], planes[1], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(planes[1]),
            nn.LeakyReLU(0.1))
        # NOTE(review): the Darknet-53 residual unit in the YOLOv3 paper is a
        # 1x1 + 3x3 pair. This extra 1x1 conv is kept to preserve the model's
        # structure and state_dict keys; confirm it is intended.
        self.conv3 = nn.Sequential(
            nn.Conv2d(planes[1], planes[1], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(planes[1]),
            nn.LeakyReLU(0.1))

    def forward(self, inputs):
        """Return ``conv3(conv2(conv1(inputs))) + inputs``.

        The output has the same shape as ``inputs``: every conv is stride 1
        with "same" padding, and ``planes[1] == inp``.
        """
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        return x + inputs
class DarkNet(nn.Module):
    """Darknet-53-style backbone that returns three multi-scale feature maps.

    Structure:
        - A 3x3 stem conv (3 -> 32 channels).
        - Five stages. Each stage is a stride-2 3x3 downsampling conv with
          BatchNorm and LeakyReLU, followed by ``blocks[k]`` ``ResidualBlock``s.
        - Stage widths are 64, 128, 256, 512 and 1024 channels.

    Args:
        blocks: five residual repeat counts, one per stage
            (``[1, 2, 8, 8, 4]`` gives the Darknet-53 layout).
    """

    def __init__(self, blocks):
        super(DarkNet, self).__init__()
        # Running channel count; _make_residual_layer reads and advances it.
        self.inp = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=self.inp, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.inp),
            nn.LeakyReLU(0.1))
        # Stages 1-3 are grouped because only the output of stage 3 is used.
        self.block1 = nn.Sequential(self._make_residual_layer(planes=[32, 64], block=blocks[0]),
                                    self._make_residual_layer(planes=[64, 128], block=blocks[1]),
                                    self._make_residual_layer(planes=[128, 256], block=blocks[2]))
        self.block2 = self._make_residual_layer(planes=[256, 512], block=blocks[3])
        self.block3 = self._make_residual_layer(planes=[512, 1024], block=blocks[4])

    def forward(self, inputs):
        """Return ``(feat1, feat2, feat3)``.

        Channel counts are 256, 512 and 1024. Spatial sizes are 1/8, 1/16 and
        1/32 of the input when its sides are divisible by 32 (e.g. 52, 26 and
        13 for a 416x416 input).
        """
        x = self.conv1(inputs)
        feat1 = self.block1(x)  # output of stages 1-3
        feat2 = self.block2(feat1)
        feat3 = self.block3(feat2)
        return feat1, feat2, feat3

    def _make_residual_layer(self, planes, block):
        """Build one stage: a stride-2 downsample, then ``block`` residual blocks.

        Args:
            planes: ``[bottleneck_channels, output_channels]`` for the stage.
            block: number of ``ResidualBlock`` repeats after the downsample.

        Returns:
            ``nn.Sequential`` with named children ``ds_conv``, ``ds_bn``,
            ``ds_relu``, ``residual_1`` ... ``residual_{block}``.

        Raises:
            ValueError: if ``block`` is negative.

        Side effect: sets ``self.inp`` to ``planes[1]``, so the next stage
        starts from this stage's output width.
        """
        if block < 0:
            raise ValueError("block must be >= 0, got {}".format(block))
        # The downsampling conv halves H/W and changes the channel count. It
        # must come first because the residual adds need matching channels.
        layers = [
            ("ds_conv", nn.Conv2d(self.inp, planes[1], kernel_size=3, stride=2, padding=1, bias=False)),
            ("ds_bn", nn.BatchNorm2d(planes[1])),
            ("ds_relu", nn.LeakyReLU(0.1)),
        ]
        self.inp = planes[1]
        # The downsample is not counted as a repeat: exactly `block` residual
        # blocks follow it.
        layers.extend(
            ("residual_{}".format(i), ResidualBlock(self.inp, planes))
            for i in range(1, block + 1)
        )
        return nn.Sequential(OrderedDict(layers))
class YoloBody(nn.Module):
    """YOLOv3-style detector: ``DarkNet`` backbone plus a three-scale neck and heads.

    Args:
        num_classes: number of object classes. Every head outputs
            ``(num_classes + 5) * 3`` channels. Presumably this is 3 anchors
            per scale, each with 4 box terms, 1 objectness score and
            ``num_classes`` class scores; confirm against the loss/decoder.

    ``forward`` returns ``(big_output, middel_output, small_output)``. They
    are computed from the 1024-, 512- and 256-channel backbone features
    (13x13, 26x26 and 52x52 for a 416x416 input).
    """

    # (channel_in, channel_out) of the three five-conv stacks, deepest first.
    _FIVE_CONV_SPECS = ((1024, 512), (768, 256), (384, 128))

    def __init__(self, num_classes):
        super(YoloBody, self).__init__()
        self.darknet = DarkNet([1, 2, 8, 8, 4])
        self.num_classes = num_classes
        head_out = (num_classes + 5) * 3
        # Every layer used by forward() is built here so it is a registered
        # submodule. That means it:
        #   - appears in model.parameters(), so the optimizer trains it;
        #   - is moved by .to()/.cuda();
        #   - is saved in state_dict();
        #   - is switched by train()/eval().
        # Creating layers inside forward() would instead produce fresh,
        # untrainable, randomly initialized layers on every call.
        self.five_convs = nn.ModuleDict()
        self.output_convs = nn.ModuleDict()
        for c_in, c_out in self._FIVE_CONV_SPECS:
            self.five_convs[self._key(c_in, c_out)] = nn.Sequential(
                CBL(c_in, c_out, ks=1, p=0),
                CBL(c_out, c_out, ks=3, p=1),
                CBL(c_out, c_out, ks=1, p=0),
                CBL(c_out, c_out, ks=3, p=1),
                CBL(c_out, c_out, ks=1, p=0),
            )
            # NOTE(review): the reference YOLOv3 neck alternates c / 2c widths
            # and its head widens to 2c in the 3x3 conv. The original widths of
            # this model are kept here; confirm they are intended.
            self.output_convs[self._key(c_out, head_out)] = nn.Sequential(
                CBL(c_out, head_out, ks=3),
                nn.Conv2d(head_out, head_out, kernel_size=1, padding=0),
            )
        # 1x1 channel reduction, then 2x upsampling to the next backbone scale.
        # nn.Upsample(mode="bilinear", align_corners=True) is the documented
        # equivalent of the deprecated nn.UpsamplingBilinear2d.
        self.up1 = nn.Sequential(
            CBL(512, 256, 1, 0),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )
        self.up2 = nn.Sequential(
            CBL(256, 128, 1, 0),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    @staticmethod
    def _key(channel_in, channel_out):
        """Return the ModuleDict key for a layer stack mapping channel_in -> channel_out."""
        return "{}_{}".format(channel_in, channel_out)

    def forward(self, inputs):
        """Run backbone, neck and heads; return the three head outputs."""
        feat1, feat2, feat3 = self.darknet(inputs)
        head_out = (self.num_classes + 5) * 3

        # Deepest scale (1024-channel backbone feature).
        big_map, x = self.make_five_conv(feat3, 1024, 512)
        big_output = self.output_conv(big_map, 512, head_out)

        # Middle scale: 512 backbone channels + 256 upsampled = 768.
        x = torch.cat([feat2, self.up1(x)], dim=1)
        middel_map, x = self.make_five_conv(x, 768, 256)
        middel_output = self.output_conv(middel_map, 256, head_out)

        # Shallowest scale: 256 backbone channels + 128 upsampled = 384.
        x = torch.cat([feat1, self.up2(x)], dim=1)
        small_map, _ = self.make_five_conv(x, 384, 128)
        small_output = self.output_conv(small_map, 128, head_out)

        return (big_output, middel_output, small_output)

    def make_five_conv(self, x, channel_in, channel_out):
        """Apply the registered five-conv stack for ``channel_in -> channel_out``.

        Returns:
            ``(y, y)``: the same tensor twice. One copy feeds the detection
            head, the other feeds the upsampling branch.

        Raises:
            KeyError: if no stack with these channel sizes was built in
                ``__init__``.
        """
        y = self.five_convs[self._key(channel_in, channel_out)](x)
        return y, y

    def output_conv(self, x, channel_in, channel_out):
        """Apply the registered head (3x3 CBL, then 1x1 conv) for these channel sizes.

        Raises:
            KeyError: if no head with these channel sizes was built in
                ``__init__``.
        """
        return self.output_convs[self._key(channel_in, channel_out)](x)
if __name__ == '__main__':
    # Smoke test: push one all-zero 1x3x416x416 tensor through the model and
    # report the output shapes.
    inputs = torch.zeros(size=(1, 3, 416, 416))
    print(inputs.shape)
    model = YoloBody(num_classes=20)
    model.eval()  # BatchNorm uses running stats; no training-mode updates
    with torch.no_grad():  # inference only: skip building the autograd graph
        out = model(inputs)  # call the module (not .forward) so hooks run
    print(*(o.shape for o in out))
边栏推荐
- Typescript official website tutorial
- 041. (2.10) talk about manpower outsourcing
- High concurrency architecture cache
- EGO Planner代碼解析bspline_optimizer部分(1)
- Day-27 database
- Inheritance in Flutter
- [leetcode] [SQL] notes
- Multifunctional web file manager filestash
- [combinatorics] Derangement problem (recurrence | closed-form formula | derivation)
- Smart wax therapy machine based on STM32 and smart cloud
猜你喜欢
Does SQL always report foreign key errors when creating tables?
[new year job hopping season] test the technical summary of interviewers' favorite questions (with video tutorials and interview questions)
Simulation scheduling problem of SystemVerilog (1)
Why should the gradient be manually cleared before back propagation in pytorch?
Record: writing MySQL commands
[LeetCode Weekly Contest] Contest 300 – 6110. Number of Increasing Paths in a Grid – Hard
EGO Planner代码解析bspline_optimizer部分(3)
FBI警告:有人利用AI换脸冒充他人身份进行远程面试
[mathematical modeling] ship three degree of freedom MMG model based on MATLAB [including Matlab source code 1925]
my.ini file not found
随机推荐
Foundation of ActiveMQ
flask 生成swagger文档
Some misunderstandings about PHP-FPM's pm.max_children
Web3 credential network project galaxy is better than nym?
We have built an intelligent retail settlement platform
Does SQL always report foreign key errors when creating tables?
硬盘监控和分析工具:Smartctl
“google is not defined” when using Google Maps V3 in Firefox remotely
【光学】基于matlab涡旋光产生【含Matlab源码 1927期】
Unity webgl optimization
JS_Array_sort
【数学建模】基于matlab船舶三自由度MMG模型【含Matlab源码 1925期】
The more you talk, the more your stupidity will be exposed.
我们做了一个智能零售结算平台
Streaming media server (16) -- figure out the difference between live broadcast and on-demand
Help change the socket position of PCB part
Le changement est un thème éternel
DriveSeg:动态驾驶场景分割数据集
Record: writing MySQL commands
Introduction to SSH Remote execution command