conv2d explained: using it with arrays and images
2022-06-30 23:05:00 【Philo`】
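1. Environment requirements
1. PyTorch must be installed
2. Reference: the official conv2d documentation
3. The image example requires the CIFAR10 dataset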
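2. How it works
The original two-dimensional data is processed with a convolution kernel to produce the result. The steps are:

Place the kernel over the input data, multiply each covered input value by the corresponding kernel value, and sum the products to obtain one output value.
Repeat this at every position until the whole output has been computed.
Here a value is computed only when the kernel lies completely inside the original image. The kernel can also be moved further out toward the edges so that it only partly overlaps the image; any overlap is enough to compute a value, and the positions that fall outside the image are filled with 0.
Likewise, the kernel here moves one cell at a time horizontally and vertically, but it can also move two cells at a time.
These two cases correspond to giving the padding and stride parameters of conv2d non-default values.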
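The multiply-and-sum procedure described above can be sketched in plain Python, without PyTorch. This is only an illustrative sketch: the helper name `conv2d_naive` is hypothetical (it is not a PyTorch function), and the input and kernel are the ones used in the array example of this article.

```python
def conv2d_naive(data, kernel, padding=0, stride=1):
    """Slide `kernel` over `data`, multiplying and summing at each position."""
    kh, kw = len(kernel), len(kernel[0])
    # Surround the input with `padding` rings of zeros.
    width = len(data[0]) + 2 * padding
    padded = [[0] * width for _ in range(padding)]
    padded += [[0] * padding + list(row) + [0] * padding for row in data]
    padded += [[0] * width for _ in range(padding)]

    out = []
    # Move the window `stride` cells at a time, keeping it inside `padded`.
    for top in range(0, len(padded) - kh + 1, stride):
        out_row = []
        for left in range(0, width - kw + 1, stride):
            total = sum(padded[top + i][left + j] * kernel[i][j]
                        for i in range(kh) for j in range(kw))
            out_row.append(total)
        out.append(out_row)
    return out


data = [[1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1]]
kernel = [[1, 2, 1],
          [0, 1, 0],
          [2, 1, 0]]

print(conv2d_naive(data, kernel))                       # [[10, 12, 12], [18, 16, 16], [13, 9, 3]]
print(conv2d_naive(data, kernel, padding=1, stride=2))  # [[1, 4, 8], [7, 16, 8], [14, 9, 4]]
```

With padding=1 and stride=1 the result is 5x5, and its inner 3x3 block equals the unpadded result, because those are exactly the positions where the kernel lies completely inside the original data.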
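3. Function requirements
Function prototype:
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
Parameter requirements:
In the official documentation, the nn.Conv2d layer (used for the image example) is configured with plain int values such as in_channels, out_channels, and kernel_size, while the functional form F.conv2d (used for the array example) takes the input and the kernel as tensors. The examples below show the difference in detail.
- The input must be a tensor of shape (minibatch, in_channels, iH, iW). A raw two-dimensional array has no minibatch or channel dimension, so it has to be converted with reshape.
- The kernel has the same requirement: it must be a 4-D tensor of shape (out_channels, in_channels/groups, kH, kW).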
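A minimal sketch of the shape requirement, assuming PyTorch is installed. The values are made up for illustration (a 5x5 ramp and an all-ones 3x3 kernel), and float32 is used because that is the usual dtype for convolution.

```python
import torch
import torch.nn.functional as F

# A plain 2-D array: no minibatch or channel dimension yet.
x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
w = torch.ones(3, 3)

# Add the two leading dimensions: (minibatch=1, in_channels=1, H, W).
x4 = x.reshape(1, 1, 5, 5)
# Kernel layout is (out_channels=1, in_channels=1, kH, kW).
w4 = w.reshape(1, 1, 3, 3)

y = F.conv2d(x4, w4)
print(x.shape, "->", x4.shape)  # torch.Size([5, 5]) -> torch.Size([1, 1, 5, 5])
print(y.shape)                  # torch.Size([1, 1, 3, 3])
```

The output keeps the same 4-D layout, so the result of one convolution can be passed straight into the next.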
3. Examples
3.1 Arrays
Code:
import torch
import torch.nn.functional as F

# Input data
input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
print("original input shape", input.shape)  # torch.Size([5, 5])
# Add the two leading dimensions that F.conv2d expects: batch size 1, 1 channel, 5x5 data
input = torch.reshape(input, (1, 1, 5, 5))
print("shape after torch.reshape", input.shape)  # torch.Size([1, 1, 5, 5])

# Convolution kernel, reshaped the same way: 1 output channel, 1 input channel, 3x3
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])
kernel = torch.reshape(kernel, (1, 1, 3, 3))

# Default convolution: padding=0, stride=1
output1 = F.conv2d(input, kernel)
print("default convolution", output1)

# padding=1, stride=1
output2 = F.conv2d(input, kernel, padding=1, stride=1)
print("padding=1, stride=1", output2)

# padding=1, stride=2
output3 = F.conv2d(input, kernel, padding=1, stride=2)
print("padding=1, stride=2", output3)
Output:
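The spatial size of each of the three results can be checked with the usual output-size formula, floor((n + 2*padding - k) / stride) + 1. The sketch below assumes dilation=1 (the default), and `conv_output_size` is a hypothetical helper name, not part of PyTorch.

```python
def conv_output_size(n, k, padding=0, stride=1):
    """Output length along one axis for input length n and kernel length k."""
    return (n + 2 * padding - k) // stride + 1


print(conv_output_size(5, 3))                       # 3 -> output1 is 3x3
print(conv_output_size(5, 3, padding=1, stride=1))  # 5 -> output2 is 5x5
print(conv_output_size(5, 3, padding=1, stride=2))  # 3 -> output3 is 3x3
```

Including the batch and channel dimensions, the three tensors therefore have shapes (1, 1, 3, 3), (1, 1, 5, 5), and (1, 1, 3, 3).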
3.2 Images
Code:
import torch
from torch import nn
from torch.nn import Conv2d
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download=False because the dataset has already been downloaded here, which avoids
# downloading it again on every run; set it to True for the first run.
# "./datasetvision" is the storage path.
# transform=torchvision.transforms.ToTensor() converts each image to a tensor.
dataset = torchvision.datasets.CIFAR10("./datasetvision", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=False)
# Load the data in batches; batch_size=64 means 64 images are fetched at a time
dataloader = DataLoader(dataset, batch_size=64)


# A simple neural network
class ConNet(nn.Module):
    def __init__(self):
        super(ConNet, self).__init__()
        # The images are color (RGB), so there are 3 input channels;
        # 6 output channels, 3x3 kernel
        self.conv2d = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    # Forward pass
    def forward(self, x):
        result = self.conv2d(x)
        return result


Work = ConNet()
print(Work)  # network structure: ConNet((conv2d): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))

# Name of the TensorBoard log directory
write = SummaryWriter("logsConv2d")
step = 0
# Each item from the dataloader is a tuple of (images, targets)
for data in dataloader:
    imgs, target = data
    # print("original images", imgs.shape)  # uncomment to compare shapes before and after
    write.add_images("input", imgs, step)  # log the original images for comparison
    output = Work(imgs)  # convolve the images
    # print(output.shape)
    # The output has 6 channels, which TensorBoard cannot display as images,
    # so reshape it to 3 channels; -1 lets the batch dimension be inferred
    output = torch.reshape(output, (-1, 3, 30, 30))
    write.add_images("output", output, step)
    step = step + 1
write.close()
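Result: the logged input and output images can be viewed in TensorBoard by running tensorboard --logdir=logsConv2d.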
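The shape changes inside the loop can be checked without downloading CIFAR10. The sketch below is an assumption-laden stand-in: it replaces one real batch with a random tensor of the same shape (64 RGB images of 32x32 pixels) and applies the same Conv2d settings and reshape as the code above.

```python
import torch
from torch.nn import Conv2d

# Stand-in for one CIFAR10 batch: 64 RGB images of 32x32 pixels (random values).
imgs = torch.rand(64, 3, 32, 32)

conv = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
output = conv(imgs)
print(output.shape)    # torch.Size([64, 6, 30, 30])

# Folding 6 channels into 3 doubles the batch dimension: 64*6 == 128*3.
reshaped = torch.reshape(output, (-1, 3, 30, 30))
print(reshaped.shape)  # torch.Size([128, 3, 30, 30])
```

The reshape is only a display trick: each 6-channel result is split into two 3-channel "images" (channels 0 to 2 and channels 3 to 5), so the output grid in TensorBoard contains twice as many tiles as the input grid.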
