当前位置:网站首页>Yolov3 network model building
Building the YOLOv3 network model
2022-07-03 19:03:00 【The little girl is so cute】
The structure diagram is borrowed from another author.

import torch
import torch.nn as nn
from collections import OrderedDict
class CBL(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block.

    Args:
        channel_in: Number of input channels of the convolution.
        channel_out: Number of output channels of the convolution.
        ks: Convolution kernel size.
        p: Convolution padding. Defaults to 1.
        strides: Convolution stride. Defaults to ``(1, 1)``.
    """

    def __init__(self, channel_in, channel_out, ks, p=1, strides=(1, 1)):
        super(CBL, self).__init__()
        self.block = nn.Sequential(
            # ``strides`` must be forwarded explicitly; otherwise Conv2d
            # falls back to its own default stride and the argument is
            # silently ignored.
            nn.Conv2d(channel_in, channel_out, ks, stride=strides, padding=p),
            nn.BatchNorm2d(channel_out),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x):
        """Apply convolution, batch normalisation and activation to ``x``."""
        return self.block(x)
class ResidualBlock(nn.Module):
    """Three conv stages followed by an identity skip connection.

    The stages are a 1x1 conv (``inp`` -> ``planes[0]``), a 3x3 conv
    (``planes[0]`` -> ``planes[1]``) and a 1x1 conv that keeps
    ``planes[1]`` channels. Each conv is followed by BatchNorm2d and
    LeakyReLU(0.1). All convs use stride 1 with size-preserving padding,
    and the result is added to the block input, so ``planes[1]`` has to
    equal ``inp`` for the addition to be valid.

    Args:
        inp: Number of channels of the incoming tensor.
        planes: Two-element sequence ``[mid_channels, out_channels]``.
    """

    @staticmethod
    def _conv_bn_leaky(c_in, c_out, kernel, pad):
        # One conv/BN/activation stage; stride is always 1 in this block.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=1, padding=pad),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.1),
        )

    def __init__(self, inp, planes):
        super(ResidualBlock, self).__init__()
        mid_channels, out_channels = planes[0], planes[1]
        self.conv1 = self._conv_bn_leaky(inp, mid_channels, 1, 0)
        self.conv2 = self._conv_bn_leaky(mid_channels, out_channels, 3, 1)
        self.conv3 = self._conv_bn_leaky(out_channels, out_channels, 1, 0)

    def forward(self, inputs):
        """Run the three stages and add the untouched input back on."""
        branch = self.conv3(self.conv2(self.conv1(inputs)))
        return branch + inputs
class DarkNet(nn.Module):
    """Backbone that yields three feature maps at decreasing resolution.

    A stride-1 stem conv (3 -> 32 channels) is followed by five stages,
    each starting with a stride-2 3x3 conv. The first three stages are
    grouped as ``block1`` (ending at 256 channels); ``block2`` ends at
    512 channels and ``block3`` at 1024 channels.

    Args:
        blocks: Sequence of five ints, one repeat count per stage.
    """

    def __init__(self, blocks):
        super(DarkNet, self).__init__()
        # Running channel count; updated by _make_residual_layer as the
        # stages are built, so the construction order below matters.
        self.inp = 32
        stem_layers = [
            nn.Conv2d(in_channels=3, out_channels=self.inp, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.inp),
            nn.LeakyReLU(0.1),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        stage_a = self._make_residual_layer(planes=[32, 64], block=blocks[0])
        stage_b = self._make_residual_layer(planes=[64, 128], block=blocks[1])
        stage_c = self._make_residual_layer(planes=[128, 256], block=blocks[2])
        self.block1 = nn.Sequential(stage_a, stage_b, stage_c)
        self.block2 = self._make_residual_layer(planes=[256, 512], block=blocks[3])
        self.block3 = self._make_residual_layer(planes=[512, 1024], block=blocks[4])

    def forward(self, inputs):
        """Return ``(feat1, feat2, feat3)``: outputs of block1, block2, block3."""
        stem_out = self.conv1(inputs)
        feat1 = self.block1(stem_out)  # covers the first three stages
        feat2 = self.block2(feat1)
        feat3 = self.block3(feat2)
        return feat1, feat2, feat3

    def _make_residual_layer(self, planes, block):
        """Build one stage as a named ``nn.Sequential``.

        Args:
            planes: ``[mid_channels, out_channels]`` for the stage.
            block: Number of loop iterations used to build the stage.

        NOTE(review): iteration 0 emits the stride-2 downsample conv, so
        only ``block - 1`` ResidualBlocks are created. Confirm this is the
        intended meaning of the repeat count.
        """
        named_layers = OrderedDict()
        for idx in range(block):
            if idx >= 1:
                named_layers["residual_{}".format(idx)] = ResidualBlock(self.inp, planes)
                continue
            # First iteration: halve the resolution and switch to the
            # stage's output width.
            named_layers["ds_conv"] = nn.Conv2d(
                self.inp, planes[1], kernel_size=3, stride=2, padding=1, bias=False)
            named_layers["ds_bn"] = nn.BatchNorm2d(planes[1])
            named_layers["ds_relu"] = nn.LeakyReLU(0.1)
            self.inp = planes[1]
        return nn.Sequential(named_layers)
class YoloBody(nn.Module):
    """YOLOv3-style detector: DarkNet backbone plus three prediction heads.

    The backbone supplies feature maps with 256, 512 and 1024 channels.
    Each head runs a five-conv stack followed by an output conv that
    produces ``(num_classes + 5) * 3`` channels. The two coarser heads are
    also reduced, upsampled by 2 and concatenated with the next finer
    backbone feature map.

    All head layers are created once in ``__init__`` and stored in
    ``nn.ModuleDict`` containers so that they are registered sub-modules:
    their weights appear in ``parameters()`` / ``state_dict()``, are
    trained by the optimiser, and follow ``.to(device)``. Building them
    inside ``forward`` would produce fresh, randomly initialised and
    untracked layers on every call.

    Args:
        num_classes: Number of object classes to predict.
    """

    # (channel_in, channel_out) of the five-conv stack for the big, middle
    # and small heads respectively.
    _HEAD_SPECS = ((1024, 512), (768, 256), (384, 128))

    def __init__(self, num_classes):
        super(YoloBody, self).__init__()
        self.darknet = DarkNet([1, 2, 8, 8, 4])
        self.num_classes = num_classes
        self.up1 = nn.Sequential(
            CBL(512, 256, 1, 0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )
        self.up2 = nn.Sequential(
            CBL(256, 128, 1, 0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )

        # Registered containers for the head layers, keyed by channel sizes.
        self.five_convs = nn.ModuleDict()
        self.output_convs = nn.ModuleDict()
        pred_channels = (num_classes + 5) * 3
        for channel_in, channel_out in self._HEAD_SPECS:
            self.five_convs[self._layer_key(channel_in, channel_out)] = \
                self._build_five_conv(channel_in, channel_out)
            self.output_convs[self._layer_key(channel_out, pred_channels)] = \
                self._build_output_conv(channel_out, pred_channels)

    def forward(self, inputs):
        """Return ``(big_output, middel_output, small_output)`` predictions."""
        #########################
        # 1. Backbone features
        #########################
        feat1, feat2, feat3 = self.darknet(inputs)
        pred_channels = (self.num_classes + 5) * 3

        #########################
        # 2. Strengthen feature extraction
        #########################
        big_map, x = self.make_five_conv(feat3, 1024, 512)
        big_output = self.output_conv(big_map, 512, pred_channels)

        x = self.up1(x)
        x = torch.cat([feat2, x], dim=1)  # 512 + 256 = 768 channels
        middel_map, x = self.make_five_conv(x, 768, 256)
        middel_output = self.output_conv(middel_map, 256, pred_channels)

        x = self.up2(x)
        x = torch.cat([feat1, x], dim=1)  # 256 + 128 = 384 channels
        small_map, _ = self.make_five_conv(x, 384, 128)
        small_output = self.output_conv(small_map, 128, pred_channels)

        return (big_output, middel_output, small_output)

    def make_five_conv(self, x, channel_in, channel_out):
        """Apply the five-conv stack for ``(channel_in, channel_out)`` to ``x``.

        The stack is looked up in ``self.five_convs``. A combination that
        was not pre-built in ``__init__`` is created and registered on
        first use, so later calls reuse the same weights.

        Returns:
            The stack output twice, as ``(x, x)``.
        """
        key = self._layer_key(channel_in, channel_out)
        if key not in self.five_convs:
            self.five_convs[key] = self._build_five_conv(channel_in, channel_out).to(x.device)
        x = self.five_convs[key](x)
        return x, x

    def output_conv(self, x, channel_in, channel_out):
        """Apply the prediction head for ``(channel_in, channel_out)`` to ``x``.

        The head is looked up in ``self.output_convs`` and, like the
        five-conv stacks, is created and registered on first use if the
        combination was not pre-built.
        """
        key = self._layer_key(channel_in, channel_out)
        if key not in self.output_convs:
            self.output_convs[key] = self._build_output_conv(channel_in, channel_out).to(x.device)
        return self.output_convs[key](x)

    @staticmethod
    def _layer_key(channel_in, channel_out):
        # ModuleDict keys must be strings.
        return "{}_{}".format(channel_in, channel_out)

    @staticmethod
    def _build_five_conv(channel_in, channel_out):
        # Alternating 1x1 / 3x3 CBL layers, five in total.
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
        )

    @staticmethod
    def _build_output_conv(channel_in, channel_out):
        # 3x3 CBL followed by a plain 1x1 convolution (no BN / activation).
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=3),
            nn.Conv2d(channel_out, channel_out, kernel_size=1, padding=0),
        )
if __name__ == '__main__':
    # Smoke test: push one blank 416x416 RGB image through the network.
    inputs = torch.zeros(size=(1, 3, 416, 416))
    print(inputs.shape)
    model = YoloBody(num_classes=20)
    # Call the module itself rather than ``model.forward`` so that
    # nn.Module.__call__ (and any registered hooks) is used; no gradients
    # are needed for a shape check.
    with torch.no_grad():
        out = model(inputs)
    print([tuple(o.shape) for o in out])
边栏推荐
- Nous avons fait une plateforme intelligente de règlement de détail
- Boost. Asio Library
- math_泰勒公式
- Ego planner code parsing Bspline_ Optimizer section (1)
- 记录在模拟器中运行flutter时报的错
- Implementation of cqrs architecture mode under Kratos microservice framework
- User identity used by startup script and login script in group policy
- application
- 组策略中开机脚本与登录脚本所使用的用户身份
- leetcode:556. 下一个更大元素 III【模拟 + 尽可能少变更】
猜你喜欢

FBI warning: some people use AI to disguise themselves as others for remote interview

Driveseg: dynamic driving scene segmentation data set

【光学】基于matlab介电常数计算【含Matlab源码 1926期】

【水质预测】基于matlab模糊神经网络水质预测【含Matlab源码 1923期】

Does SQL always report foreign key errors when creating tables?

Real time split network (continuous update)

The installation path cannot be selected when installing MySQL 8.0.23

FBI警告:有人利用AI换脸冒充他人身份进行远程面试

KINGS

Getting started with JDBC
随机推荐
Pytorch introduction to deep learning practice notes 13- advanced chapter of cyclic neural network - Classification
NFT new opportunity, multimedia NFT aggregation platform okaleido will be launched soon
Getting started with JDBC
Database creation, addition, deletion, modification and query
22.2.14 -- station B login with code -for circular list form - 'no attribute' - 'needs to be in path selenium screenshot deviation -crop clipping error -bytesio(), etc
leetcode:11. 盛最多水的容器【雙指針 + 貪心 + 去除最短板】
變化是永恒的主題
[leetcode周赛]第300场——6110. 网格图中递增路径的数目-较难
Ego planner code parsing Bspline_ Optimizer section (2)
How to disable the clear button of ie10 insert text box- How can I disable the clear button that IE10 inserts into textboxes?
cipher
leetcode:11. Container with the most water [double pointer + greed + remove the shortest board]
User identity used by startup script and login script in group policy
知其然,而知其所以然,JS 对象创建与继承【汇总梳理】
my. INI file not found
[new year job hopping season] test the technical summary of interviewers' favorite questions (with video tutorials and interview questions)
flask 生成swagger文档
【光学】基于matlab涡旋光产生【含Matlab源码 1927期】
Record: solve the problem that MySQL is not an internal or external command environment variable
Why should the gradient be manually cleared before back propagation in pytorch?