conv2d explained: usage with arrays and images
2022-06-30 23:05:00 【Philo`】
1、Environment requirements
1、PyTorch must be installed
2、Official documentation: conv2d
3、The image example requires the CIFAR10 dataset
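2、How it works
A convolution kernel is applied to the original two-dimensional data to produce the result. The steps are:

Place the kernel over the input data, multiply each covered input value by the kernel value at the same position, and add the products; the sum is one output value.
Slide the kernel and repeat until the whole input has been processed; the collected sums form the output.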
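As a minimal sketch of a single step, the plain-Python snippet below computes the output value for one window and then repeats it for every position where the kernel fits. It uses the 5×5 input and 3×3 kernel from the array example in section 3.1; the helper name `window_sum` is made up for illustration.

```python
# 5x5 input and 3x3 kernel, the same values as in the array example (section 3.1).
image = [[1, 2, 0, 3, 1],
         [0, 1, 2, 3, 1],
         [1, 2, 1, 0, 0],
         [5, 2, 3, 1, 1],
         [2, 1, 0, 1, 1]]
kernel = [[1, 2, 1],
          [0, 1, 0],
          [2, 1, 0]]


def window_sum(image, kernel, top, left):
    """Multiply the kernel with the window whose top-left corner is (top, left), then sum."""
    total = 0
    for i, kernel_row in enumerate(kernel):
        for j, weight in enumerate(kernel_row):
            total += image[top + i][left + j] * weight
    return total


# One step: the window in the top-left corner.
first = window_sum(image, kernel, 0, 0)
print(first)  # 10

# Repeating the step for every position where the kernel fits gives the 3x3 output.
output = [[window_sum(image, kernel, r, c) for c in range(3)] for r in range(3)]
print(output)  # [[10, 12, 12], [18, 16, 16], [13, 9, 3]]
```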
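So far, a value is computed only where the kernel lies completely over the original image. The kernel may also move past the edges so that it only partly overlaps the image; any overlap is enough to compute a value, and the positions that fall outside the image are filled with 0.
Likewise, the kernel above moves one cell at a time horizontally and vertically; it can also move two cells at a time.
These two cases correspond to non-default values of the padding and stride parameters of conv2d.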
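The effect of these two parameters on the output size can be sketched with the standard formula floor((n + 2p - k) / s) + 1, which holds when dilation is left at its default of 1. The function name `conv_output_size` is hypothetical and not part of PyTorch.

```python
def conv_output_size(n, k, padding=0, stride=1):
    """Output length along one axis for input length n and kernel length k (dilation 1)."""
    return (n + 2 * padding - k) // stride + 1


print(conv_output_size(5, 3))                       # 3: default padding=0, stride=1
print(conv_output_size(5, 3, padding=1))            # 5: one ring of zeros keeps the size
print(conv_output_size(5, 3, padding=1, stride=2))  # 3: moving two cells at a time
print(conv_output_size(32, 3))                      # 30: a 32x32 image with a 3x3 kernel
```

The same formula is applied to the height and the width independently.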
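3、Function requirements
Function prototype (as given in the official documentation): torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
Parameter requirements: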
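According to the current official documentation, the nn.Conv2d layer used for image data is configured with int arguments (such as in_channels and out_channels), while the functional conv2d used for array data takes tensors for both the input and the kernel; the examples below show the difference in detail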
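- The input must be a tensor with a minibatch dimension and an input-channel dimension; a raw two-dimensional array has neither, so it has to be converted with reshape
- The kernel has the same requirement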
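A minimal sketch of these requirements, assuming a recent PyTorch install. The tensor values are arbitrary placeholders; the second half illustrates how the layer form and the functional form relate, using the layer's own randomly initialised weight.

```python
import torch
import torch.nn.functional as F
from torch import nn

# A raw 2-D array has shape (H, W); conv2d wants (minibatch, channels, H, W).
x2d = torch.arange(25, dtype=torch.float32).reshape(5, 5)
x4d = torch.reshape(x2d, (1, 1, 5, 5))
print(x2d.shape, "->", x4d.shape)

# Calling unsqueeze twice adds the same two leading dimensions without spelling out the sizes.
same = x2d.unsqueeze(0).unsqueeze(0)
print(torch.equal(x4d, same))

# The nn.Conv2d layer is configured with ints and creates its own weight tensor,
# laid out as (out_channels, in_channels, kernel height, kernel width).
layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
print(layer.weight.shape)

# The functional form gives the same result when handed that weight and bias explicitly.
out_layer = layer(x4d)
out_func = F.conv2d(x4d, layer.weight, layer.bias)
print(torch.allclose(out_layer, out_func))
```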
3、Examples
3.1、Arrays
Code:
import torch
import torch.nn.functional as F

# Input data (a float dtype is used because convolution on integer tensors is not supported on every backend)
input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)
print("original input shape", input.shape)  # torch.Size([5, 5])
# Add the two leading dimensions: batch size = 1, channels = 1, data is 5*5 -> torch.Size([1, 1, 5, 5])
input = torch.reshape(input, (1, 1, 5, 5))
print("shape after torch.reshape", input.shape)

# Convolution kernel, reshaped in the same way
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]], dtype=torch.float32)
kernel = torch.reshape(kernel, (1, 1, 3, 3))

# Default convolution: padding=0, stride=1
output1 = F.conv2d(input, kernel)
print("default convolution", output1)

# padding=1, stride=1
output2 = F.conv2d(input, kernel, padding=1, stride=1)
print("padding=1, stride=1", output2)

# padding=1, stride=2
output3 = F.conv2d(input, kernel, padding=1, stride=2)
print("padding=1, stride=2", output3)
Output:
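The values the three calls should return can be cross-checked without PyTorch. The sketch below is a plain-Python re-implementation written only for that purpose (`conv2d_plain` is a hypothetical helper, not a PyTorch function); it assumes a single channel, zero padding and no dilation. F.conv2d should return the same numbers wrapped in a tensor of shape (1, 1, H, W).

```python
def conv2d_plain(image, kernel, padding=0, stride=1):
    """Zero-pad the image, then slide the kernel with the given stride (multiply, then sum)."""
    width = len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    # Zero padding on all four sides.
    blank = [0] * (width + 2 * padding)
    padded = ([blank[:] for _ in range(padding)]
              + [[0] * padding + list(row) + [0] * padding for row in image]
              + [blank[:] for _ in range(padding)])
    out_h = (len(padded) - kh) // stride + 1
    out_w = (len(padded[0]) - kw) // stride + 1
    return [[sum(padded[r * stride + i][c * stride + j] * kernel[i][j]
                 for i in range(kh) for j in range(kw))
             for c in range(out_w)]
            for r in range(out_h)]


image = [[1, 2, 0, 3, 1],
         [0, 1, 2, 3, 1],
         [1, 2, 1, 0, 0],
         [5, 2, 3, 1, 1],
         [2, 1, 0, 1, 1]]
kernel = [[1, 2, 1],
          [0, 1, 0],
          [2, 1, 0]]

out1 = conv2d_plain(image, kernel)
out2 = conv2d_plain(image, kernel, padding=1)
out3 = conv2d_plain(image, kernel, padding=1, stride=2)
print(out1)  # [[10, 12, 12], [18, 16, 16], [13, 9, 3]]
print(out2)  # [[1, 3, 4, 10, 8], [5, 10, 12, 12, 6], [7, 18, 16, 16, 8], [11, 13, 9, 3, 4], [14, 13, 9, 7, 4]]
print(out3)  # [[1, 4, 8], [7, 16, 8], [14, 9, 4]]
```

With padding=1 the output keeps the 5×5 size of the input, and with stride=2 on top of that only every second row and column of that result remains.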
3.2、Images
Code:
import torch
from torch import nn
from torch.nn import Conv2d
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download=False because the dataset is already on disk here, which avoids downloading it on every run; set it to True for the first run
# "./datasetvision" is the directory the dataset is stored in
# transform=torchvision.transforms.ToTensor() converts each image to a tensor
dataset = torchvision.datasets.CIFAR10("./datasetvision", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=False)
# Load the data in batches; batch_size=64 means 64 images are fetched at a time
dataloader = DataLoader(dataset, batch_size=64)


# A simple neural network with a single convolution layer
class ConNet(nn.Module):
    def __init__(self):
        super().__init__()
        # The images are RGB, so there are 3 input channels; 6 output channels; the kernel is 3*3
        self.conv2d = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    # Forward pass
    def forward(self, x):
        result = self.conv2d(x)
        return result


Work = ConNet()
print(Work)  # Print the network structure: ConNet((conv2d): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

# Name of the log folder used by tensorboard
write = SummaryWriter("logsConv2d")
# Each item yielded by the dataloader is an (images, targets) pair
step = 0
for data in dataloader:
    imgs, target = data
    write.add_images("input", imgs, step)  # log the original images to tensorboard for comparison
    output = Work(imgs)  # convolve the images
    # print(imgs.shape, output.shape)  # uncomment to compare the shapes before and after
    # The convolution outputs 6 channels, which tensorboard cannot display as an image, so reshape to 3 channels
    output = torch.reshape(output, (-1, 3, 30, 30))
    write.add_images("output", output, step)
    step = step + 1
write.close()
Result:
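The tensor shapes involved can be checked without downloading CIFAR10. The sketch below is a minimal illustration that uses a random batch as a stand-in for one batch of images (an assumption; the pixel values are meaningless) and the same layer settings as the network above.

```python
import torch
from torch import nn

# A random batch stands in for one CIFAR10 batch: 64 RGB images of 32x32 pixels.
imgs = torch.rand(64, 3, 32, 32)

conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
output = conv(imgs)
print(output.shape)   # torch.Size([64, 6, 30, 30])

# Folding the 6 channels into groups of 3 doubles the batch dimension.
folded = torch.reshape(output, (-1, 3, 30, 30))
print(folded.shape)   # torch.Size([128, 3, 30, 30])
```

Each 32×32 image shrinks to 30×30 because a 3×3 kernel with no padding loses one pixel on every side. Passing -1 instead of a fixed 128 matters for the final batch: the CIFAR10 test set has 10000 images, so the last batch of 64 holds only 16.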
