Pytorch two-dimensional multi-channel convolution operation method
2022-06-29 15:06:00 【Hopi tongzj】
Fully connected convolution
Single-channel convolution is probably already familiar to most readers.
So how does a convolutional neural network change the number of image channels, for example from 3 to 8?
The experiments below use a 5 × 5 convolution kernel:
import torch
c1, c2, k = 3, 8, 5
# Instantiate a 2D convolution without padding
conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
                       kernel_size=k, bias=False)
# 2D convolution weight: [c2, c1, k, k]
weight = conv.weight.data
print('Weight shape:', weight.shape)
Creating a Conv2d in PyTorch shows that the weight parameter has shape [8, 3, 5, 5].
Take a [3, 5, 5] image (call it img), which becomes [8, 1, 1] after the convolution, as an example and make a conjecture:
- img[..., r, c] has shape [3, ]: the values of that pixel in its 3 channels
- weight[..., r, c] has shape [8, 3]: the parameters at row r, column c of the kernel's receptive field
- The matrix product weight[..., r, c] × img[..., r, c] is a tensor of shape [8, ]: that pixel's contribution to the 8 output channels; summing these contributions over all pixels gives the convolution result

Write the following functions to verify it:
# Test image; the convolution maps [c1, k, k] -> [c2, 1, 1]
img = torch.rand([1, c1, k, k])
def torch_conv():
    return conv(img).view(-1)
def guess_for_FCconv():
    '''Fully connected convolution (standard convolution)'''
    result = torch.zeros(c2)
    # Loop over every pixel
    for r in range(k):
        for c in range(k):
            # Channel values of this pixel: [c1, ]
            pixel = img[..., r, c].view(-1)
            # Kernel parameters for this pixel: [c2, c1]
            linear = weight[..., r, c]
            # Contribution of this pixel to each output channel: [c2, c1] × [c1, ] -> [c2, ]
            result += linear @ pixel
    return result
print('PyTorch:', torch_conv())
print('Guess:', guess_for_FCconv())
As you can see, PyTorch's result matches mine, so the conjecture holds.
Weight shape: torch.Size([8, 3, 5, 5])
PyTorch: tensor([-0.2874, -0.4310, 0.1660, -0.0021, 0.6042, -0.0716, 0.0821, 0.0322],
grad_fn=<ViewBackward>)
Guess: tensor([-0.2874, -0.4310, 0.1660, -0.0021, 0.6042, -0.0716, 0.0821, 0.0322])
This way of computing the result is offered to help you understand convolution better.
I do not know whether real deployments actually compute convolution this way.
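The per-pixel accumulation in guess_for_FCconv can also be written as a single tensor contraction. The sketch below is only an illustration of the same conjecture, not a claim about how PyTorch computes convolution internally; the names einsum_result and torch_result are made up for this example, and it assumes the same setting as above (one k × k image, no padding, no bias).

```python
import torch

torch.manual_seed(0)
c1, c2, k = 3, 8, 5
conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
                       kernel_size=k, bias=False)
weight = conv.weight.detach()      # [c2, c1, k, k]
img = torch.rand([1, c1, k, k])    # one k x k image, so the output is [1, c2, 1, 1]

# Sum over the input channel i and the kernel position (r, c) in one contraction
einsum_result = torch.einsum('oirc,irc->o', weight, img[0])   # [c2]
torch_result = conv(img).detach().view(-1)                    # [c2]

print('Match:', torch.allclose(einsum_result, torch_result, atol=1e-5))
```

The index string says the same thing as the loops: for every output channel o, multiply matching entries over i, r and c and add them up.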
Depthwise separable convolution (grouped convolution)
When the number of input channels is 4 and the number of output channels is 8, set the number of convolution groups to 2.
The weight parameter then has shape [8, 2, 5, 5], which can also be written as [8, 4/2, 5, 5].
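To see how the group count changes the weight shape, the short sketch below builds the same layer with several values of groups and records each shape. The dictionary name shapes is made up for this example; groups must divide both the input and the output channel count, and groups equal to the number of input channels is the depthwise case.

```python
import torch

c1, c2, k = 4, 8, 5
shapes = {}
for g in (1, 2, 4):   # each g divides both c1 and c2
    conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
                           kernel_size=k, groups=g, bias=False)
    # Weight shape is [c2, c1/g, k, k]
    shapes[g] = tuple(conv.weight.shape)
    print('groups =', g, 'weight shape:', shapes[g],
          'parameters:', conv.weight.numel())
```

The second dimension shrinks from 4 to 2 to 1 as groups grows, so the parameter count falls in proportion to the number of groups.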
import torch
c1, c2, k, g = 4, 8, 5, 2
# Instantiate a 2D convolution without padding
conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
                       kernel_size=k, groups=g, bias=False)
# 2D convolution weight: [c2, c1/g, k, k]
weight = conv.weight.data
print('Weight shape:', weight.shape)
# Test image; the convolution maps [c1, k, k] -> [c2, 1, 1]
img = torch.rand([1, c1, k, k])
The input channels are grouped, so the image must be split along the channel dimension during the computation.
The weight shape shows no grouping of the output channels, so are the output channels also split during the computation?
Take a [4, 5, 5] image (call it img), which becomes [8, 1, 1] after the convolution, as an example and make a conjecture:
- Group the channels of the image: img can be viewed as [2, 2, 5, 5], i.e. 2 images of shape [2, 5, 5]
- Group the output channels of weight: weight can be viewed as [2, 4, 2, 5, 5], i.e. 2 fully connected convolutions, each with a weight of shape [4, 2, 5, 5]
- Applying each [4, 2, 5, 5] fully connected convolution to its own image group gives 2 results of shape [4, 1, 1]; concatenating them gives [8, 1, 1]

Write the following functions to verify it:
# Test image; the convolution maps [c1, k, k] -> [c2, 1, 1]
img = torch.rand([1, c1, k, k])
def torch_conv():
    return conv(img).view(-1)
def guess_for_DWConv():
    '''Depthwise separable (grouped) convolution'''
    # Represent the c2 output channels as g × c2/g
    result = torch.zeros([g, c2 // g])
    # Group the channels of the image: [1, c1, k, k] -> [g, c1/g, k, k]
    img_ = img.view(g, -1, k, k)
    # Split the kernel weight by group: [c2, c1/g, k, k] -> [g, c2/g, c1/g, k, k]
    for i, w in enumerate(weight.view(g, -1, c1 // g, k, k)):
        # Loop over every pixel
        for r in range(k):
            for c in range(k):
                # Channel values of this pixel within group i: [c1/g, ]
                pixel = img_[i, :, r, c].view(-1)
                # Kernel parameters for this pixel: [c2/g, c1/g]
                linear = w[..., r, c]
                # Contribution of this pixel to each output channel: [c2/g, c1/g] × [c1/g, ] -> [c2/g, ]
                result[i] += linear @ pixel
    return result.view(-1)
print('PyTorch:', torch_conv())
print('Guess:', guess_for_DWConv())
Clearly, the two functions give the same result, so the conjecture holds.
Weight shape: torch.Size([8, 2, 5, 5])
PyTorch: tensor([ 0.1674, -0.1527, 0.4059, -0.3422, -0.2362, -0.4508, 0.3286, 0.3232],
grad_fn=<ViewBackward>)
Guess: tensor([ 0.1674, -0.1527, 0.4059, -0.3422, -0.2362, -0.4508, 0.3286, 0.3232])
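The last step of the conjecture (run one fully connected convolution per group, then concatenate) can also be checked directly with torch.nn.functional.conv2d. This is a minimal sketch under the same assumptions as above (no padding, no bias, one k × k image); the names img_groups, w_groups, parts and split_result are made up for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c1, c2, k, g = 4, 8, 5, 2
conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
                       kernel_size=k, groups=g, bias=False)
weight = conv.weight.detach()        # [c2, c1/g, k, k]
img = torch.rand([1, c1, k, k])

# Split the image along its channels and the weight along its output channels
img_groups = img.chunk(g, dim=1)     # g tensors of shape [1, c1/g, k, k]
w_groups = weight.chunk(g, dim=0)    # g tensors of shape [c2/g, c1/g, k, k]

# One ordinary (groups=1) convolution per group, then concatenate on the channel axis
parts = [F.conv2d(x, w) for x, w in zip(img_groups, w_groups)]
split_result = torch.cat(parts, dim=1).view(-1)   # [c2]
torch_result = conv(img).detach().view(-1)        # [c2]

print('Match:', torch.allclose(split_result, torch_result, atol=1e-5))
```

Each group only ever sees its own c1/g input channels, which is why the weight needs just c1/g entries along its second dimension.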