当前位置:网站首页>AssertionError assert I.ndim == 4 and I.shape[1] == 3
AssertionError assert I.ndim == 4 and I.shape[1] == 3
2022-07-01 04:35:00 【booze-J】
运行代码:
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10("CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 注意dataset中transform参数接收的是个对象,所以要加上括号,还有就是之后使用神经网络进行运算的时候需要的数据类型是tensor类型,所以transforms参数要加上。
dataloader = DataLoader(dataset,batch_size=64)
# 搭建一个简单的网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super().__init__()
# 注意这里是创建一个全局变量所以要加上一个self 当out_channels远大于in_channels时需要对原图像进行扩充,也就是padding的值不能设为0了,需要根据公式
self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=(3),stride=1,padding=0)
# 重写forward函数
def forward(self,x):
x = self.conv1(x)
return x
# 初始化网络
obj = Booze()
# 查看网络
print(obj)
''' Booze( (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)) ) '''
writer = SummaryWriter("logs")
step = 0
for data in dataloader:
imgs,targets = data
output = obj(imgs)
# torch.Size([64, 3, 32, 32]) 64张3通道32X32的图片
print(imgs.shape)
# torch.Size([64, 6, 30, 30]) 64张6通道30X30的图片
print(output.shape)
# 使用tensorboard可视化 注意多张图片是要使用add_images而不是add_image
writer.add_images("input",imgs,step)
# 由于output是6通道数的无法显示,直接可视化会报错,所以我们需要对output进行reshape reshape的第二参数中当一个数未知时,你可以填入-1,他会自动帮你计算,为什么会未知呢?因为就是不知道填多少,填64的话肯定不行吧,然后改变通道数相当于把多余的像素给切出来了
writer.add_images("output",output,step)
step+=1
writer.close()
运行代码报错如下:
为什么会出错呢?
原因是我们搭建的神经网络中self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=(3),stride=1,padding=0)
其中out_channels=6就是输出图片的通道数是6通道的,6通道数的图片无法显示,直接使用tensorboard可视化会报错,报错的就是上述代码中的writer.add_images("output",output,step)这一行代码,所以在执行这行代码前需要对output进行reshape,reshape成3通道数的图片。
解决方案output = torch.reshape(output,(-1,3,30,30))这一行代码添加到writer.add_images("output",output,step)之前,reshape的第二参数中当一个数未知时,你可以填入-1,他会自动帮你计算,为什么会未知呢?因为就是不知道填多少,填64的话肯定不行吧,然后改变通道数相当于把多余的像素给切出来了,放到了batch_size中。
# (-1,3,30,30) = (batch_size,channels,H,W)
output = torch.reshape(output,(-1,3,30,30))
writer.add_images("output",output,step)
边栏推荐
- 做网站数据采集,怎么选择合适的服务器呢?
- How do I sort a list of strings in dart- How can I sort a list of strings in Dart?
- 软件研发的十大浪费:研发效能的另一面
- What is uid? What is auth? What is a verifier?
- PgSQL failed to start after installation
- [difficult] sqlserver2008r2, can you recover only some files when recovering the database?
- 总结全了,低代码还需要解决这4点问题
- 2022 question bank and answers for safety production management personnel of hazardous chemical production units
- Maixll dock quick start
- Knowledge supplement: basic usage of redis based on docker
猜你喜欢

slf4j 简单实现

Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation 阅读笔记

Daily algorithm & interview questions, 28 days of special training in large factories - the 13th day (array)

OdeInt與GPU

How to use maixll dock

VIM简易使用教程

About the transmission pipeline of stage in spark

TASK04|數理統計

Threejs opening

Possible problems and solutions of using scroll view to implement slider view
随机推荐
Question bank and online simulation examination for special operation certificate of G1 industrial boiler stoker in 2022
Software testing needs more and more talents. Why do you still not want to take this path?
Seven crimes of counting software R & D Efficiency
Pytorch(二) —— 激活函数、损失函数及其梯度
Common UNIX Operation and maintenance commands of shell
Difference between cookie and session
[pat (basic level) practice] - [simple simulation] 1064 friends
Custom components in applets
Basic exercise of test questions hexadecimal to decimal
How to do the performance pressure test of "Health Code"
Openresty rewrites the location of 302
(12) Somersault cloud case (navigation bar highlights follow)
Caijing 365 stock internal reference | the first IPO of Beijing stock exchange; the subsidiary of the recommended securities firm for gambling and gambling, with a 40% discount
Pytest automated testing - compare robotframework framework
How to choose the right server for website data collection?
Threejs opening
JS rotation chart
LM小型可编程控制器软件(基于CoDeSys)笔记十九:报错does not match the profile of the target
【LeetCode】100. Same tree
TASK04|數理統計