conv2d explained: using it on arrays and on images
2022-06-30 23:05:00 【Philo`】
1. Environment requirements
- PyTorch must be installed
- Reference: the official conv2d documentation
- The image example needs the CIFAR10 dataset
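2. How it works
A 2-D input is combined with a convolution kernel to produce the output. The steps are:
Place the kernel over a patch of the input, multiply each covered input value by the kernel value at the same position, and add the products together; the sum is one output value.
Slide the kernel to the next position and repeat until the whole input has been covered; the collected sums form the output.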
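The multiply-and-sum steps above can be sketched in plain Python. This is only an illustration of the idea, not how PyTorch implements it; the function name `cross_correlate_valid` is made up for this sketch, and the numbers are the 5x5 input and 3x3 kernel used in the array example later in this article.

```python
def cross_correlate_valid(image, kernel):
    """Slide `kernel` over `image`, only where it fits entirely inside it."""
    k_h, k_w = len(kernel), len(kernel[0])
    out_h = len(image) - k_h + 1
    out_w = len(image[0]) - k_w + 1
    output = []
    for top in range(out_h):
        row = []
        for left in range(out_w):
            # Multiply each covered value by the kernel value at the
            # same position, then add the products together.
            total = 0
            for a in range(k_h):
                for b in range(k_w):
                    total += image[top + a][left + b] * kernel[a][b]
            row.append(total)
        output.append(row)
    return output


image = [[1, 2, 0, 3, 1],
         [0, 1, 2, 3, 1],
         [1, 2, 1, 0, 0],
         [5, 2, 3, 1, 1],
         [2, 1, 0, 1, 1]]
kernel = [[1, 2, 1],
          [0, 1, 0],
          [2, 1, 0]]

result = cross_correlate_valid(image, kernel)
print(result)  # [[10, 12, 12], [18, 16, 16], [13, 9, 3]]
```

For instance, the top-left value is 1*1 + 2*2 + 0*1 + 0*0 + 1*1 + 2*0 + 1*2 + 2*1 + 1*0 = 10. The kernel is applied as written, without flipping it, which is also what conv2d does.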
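In the steps above, a value is computed only where the kernel lies completely inside the input. The kernel can also be moved past the edges so that it overlaps the input only partly; the positions that fall outside the input are filled with 0.
The steps above also move the kernel one cell at a time, both horizontally and vertically; it can instead move two cells per step.
These two variations correspond to non-default values of the padding and stride parameters of conv2d (the defaults are padding=0 and stride=1).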
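The effect of padding and stride on the output size can be written as a small formula. The helper below is a sketch with a hypothetical name (`conv_output_size`); it assumes a square input, a square kernel and no dilation.

```python
def conv_output_size(input_size, kernel_size, padding=0, stride=1):
    """Output length along one dimension, assuming dilation is 1."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


# 5x5 input, 3x3 kernel, as in the array example of this article
print(conv_output_size(5, 3))                       # 3
print(conv_output_size(5, 3, padding=1))            # 5
print(conv_output_size(5, 3, padding=1, stride=2))  # 3

# 32x32 CIFAR10 image, 3x3 kernel, as in the image example
print(conv_output_size(32, 3))                      # 30
```

With padding=1 and stride=1 a 3x3 kernel keeps the input size unchanged, which is why this combination is so common.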
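3. Function requirements
Function prototype:
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor
Parameter requirements:
In the current official documentation, the nn.Conv2d layer (used for the image example) only needs int arguments such as in_channels, out_channels and kernel_size, whereas F.conv2d (used for the array example) takes tensors for both the input and the kernel. The examples below show the difference.
- The input must be a tensor of shape (minibatch, in_channels, iH, iW). A raw 2-D array has neither a minibatch nor a channel dimension, so it has to be reshaped first.
- The kernel (weight) has the same requirement: it must be a tensor of shape (out_channels, in_channels/groups, kH, kW).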
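As a small sketch of the shape requirement, the snippet below turns a 2-D tensor into the 4-D layout in three equivalent ways. The variable names are illustrative only; `torch.reshape`, `Tensor.unsqueeze` and indexing with `None` are standard PyTorch operations.

```python
import torch

matrix = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0]])
print(matrix.shape)  # torch.Size([2, 2])

# Three equivalent ways to add the minibatch and channel dimensions
via_reshape = torch.reshape(matrix, (1, 1, 2, 2))
via_unsqueeze = matrix.unsqueeze(0).unsqueeze(0)
via_indexing = matrix[None, None]

print(via_reshape.shape)    # torch.Size([1, 1, 2, 2])
print(via_unsqueeze.shape)  # torch.Size([1, 1, 2, 2])
print(via_indexing.shape)   # torch.Size([1, 1, 2, 2])
```

`reshape` needs the full target shape spelled out, while `unsqueeze` and `None` indexing only add dimensions of size 1 and therefore work for any height and width.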
3. Examples
3.1 Arrays
Code:
import torch
import torch.nn.functional as F

# Input data
input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
print("original input shape", input.shape)  # torch.Size([5, 5])
# Add the two leading dimensions: batch size = 1, channels = 1, data is 5x5
input = torch.reshape(input, (1, 1, 5, 5))
print("shape after torch.reshape", input.shape)  # torch.Size([1, 1, 5, 5])

# Convolution kernel
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])
kernel = torch.reshape(kernel, (1, 1, 3, 3))

# Default convolution: padding=0, stride=1
output1 = F.conv2d(input, kernel)
print("default convolution", output1)

# padding=1, stride=1
output2 = F.conv2d(input, kernel, padding=1, stride=1)
print("padding=1, stride=1", output2)

# padding=1, stride=2
output3 = F.conv2d(input, kernel, padding=1, stride=2)
print("padding=1, stride=2", output3)
Output:
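The values produced by the three calls can be cross-checked against results worked out by hand. The sketch below repeats the same three calls; the tensors are created as float32 here, which differs from the integer tensors in the code above and is done on the assumption that floating-point inputs are the most widely supported case. The `expected_*` tensors are hand-computed reference values, not copied program output.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1, 2, 0, 3, 1],
                  [0, 1, 2, 3, 1],
                  [1, 2, 1, 0, 0],
                  [5, 2, 3, 1, 1],
                  [2, 1, 0, 1, 1]], dtype=torch.float32).reshape(1, 1, 5, 5)
w = torch.tensor([[1, 2, 1],
                  [0, 1, 0],
                  [2, 1, 0]], dtype=torch.float32).reshape(1, 1, 3, 3)

default_out = F.conv2d(x, w)                       # padding=0, stride=1
padded_out = F.conv2d(x, w, padding=1)             # padding=1, stride=1
strided_out = F.conv2d(x, w, padding=1, stride=2)  # padding=1, stride=2

# Reference values computed by hand
expected_default = torch.tensor([[10., 12., 12.],
                                 [18., 16., 16.],
                                 [13., 9., 3.]])
expected_padded = torch.tensor([[1., 3., 4., 10., 8.],
                                [5., 10., 12., 12., 6.],
                                [7., 18., 16., 16., 8.],
                                [11., 13., 9., 3., 4.],
                                [14., 13., 9., 7., 4.]])
expected_strided = torch.tensor([[1., 4., 8.],
                                 [7., 16., 8.],
                                 [14., 9., 4.]])

print(default_out.shape, torch.allclose(default_out[0, 0], expected_default))
print(padded_out.shape, torch.allclose(padded_out[0, 0], expected_padded))
print(strided_out.shape, torch.allclose(strided_out[0, 0], expected_strided))
```

The stride=2 result is the padding=1 result sampled at every second row and column, and the centre 3x3 block of the padding=1 result equals the default result.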
3.2 Images
Code:
import torch
from torch import nn
from torch.nn import Conv2d
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download=False because the dataset is already on disk here, so it is not
# downloaded again on every run; set it to True for the first run.
# "./datasetvision" is the directory where the dataset is stored.
# transform=torchvision.transforms.ToTensor() converts each image to a tensor.
dataset = torchvision.datasets.CIFAR10("./datasetvision", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=False)
# Wrap the dataset in a DataLoader; batch_size=64 yields 64 images per batch
dataloader = DataLoader(dataset, batch_size=64)


# A minimal network with a single convolution layer
class ConNet(nn.Module):
    def __init__(self):
        super().__init__()
        # RGB colour images have 3 input channels; the layer produces
        # 6 output channels with a 3x3 kernel
        self.conv2d = Conv2d(in_channels=3, out_channels=6, kernel_size=3,
                             stride=1, padding=0)

    # Forward pass
    def forward(self, x):
        result = self.conv2d(x)
        return result


Work = ConNet()
print(Work)  # Network structure: ConNet((conv2d): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

# TensorBoard writer; "logsConv2d" is the log directory name
write = SummaryWriter("logsConv2d")

step = 0
# Each item yielded by the dataloader is a tuple (images, targets)
for data in dataloader:
    imgs, target = data
    # print(imgs.shape)    # full batch: torch.Size([64, 3, 32, 32])
    write.add_images("input", imgs, step)  # log the original images for comparison
    output = Work(imgs)  # convolve the images
    # print(output.shape)  # full batch: torch.Size([64, 6, 30, 30])
    # The layer outputs 6 channels, which TensorBoard cannot display as an
    # image, so reshape to 3 channels; -1 lets the batch size be inferred
    output = torch.reshape(output, (-1, 3, 30, 30))
    write.add_images("output", output, step)
    step = step + 1
write.close()
Result:
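The shapes involved can be checked without downloading CIFAR10 or starting TensorBoard. The sketch below is an assumption-laden stand-in: it feeds a random tensor shaped like one full batch (`fake_batch`, a made-up name) through a layer configured like the one above.

```python
import torch
from torch import nn

conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3,
                 stride=1, padding=0)

# Random stand-in for one batch of 64 CIFAR10 images (3 channels, 32x32)
fake_batch = torch.rand(64, 3, 32, 32)

features = conv(fake_batch)
print(features.shape)  # torch.Size([64, 6, 30, 30])

# Same reshape as in the loop above: 6 channels become two 3-channel images
for_tensorboard = torch.reshape(features, (-1, 3, 30, 30))
print(for_tensorboard.shape)  # torch.Size([128, 3, 30, 30])
```

Height and width shrink from 32 to 30 because a 3x3 kernel with padding=0 and stride=1 fits in 30 positions along each side. The reshape splits every 6-channel feature map into two 3-channel images, so the batch dimension doubles; this is only a visualisation trick, and the colours shown in TensorBoard have no direct meaning.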