当前位置:网站首页>PyTorch⑨---卷积神经网络_线性层
PyTorch⑨---卷积神经网络_线性层
2022-08-02 14:07:00 【伏月三十】
线性层
在卷积神经网络里最后几层,会把卷积层摊开平放到全连接层里计算,然后进入sofmax进行分类。线性层相当于全连接层。
例如在vgg16里:
77512----torch.flatten(imgs)—>114096—Linear(4096,1000)—>111000
import torch.nn
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
dataset=torchvision.datasets.CIFAR10("dataset_CIFAR10",
train=False,
transform=torchvision.transforms.ToTensor())
dataloader=DataLoader(dataset,batch_size=64,drop_last=True)
class Demo(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1=Linear(196608,10) #把全部展开的196608变成10,最后分类的结果是10类
def forward(self,input):
output=self.linear1(input)
return output
demo=Demo()
for data in dataloader:
imgs,targets=data
print(imgs.shape)
output=torch.flatten(imgs)
print(output.shape)
output=demo(output)
print(output.shape)
结果:
边栏推荐
猜你喜欢
随机推荐
使用预训练语言模型进行文本生成的常用微调策略
MapReduce流程
无人驾驶综述:摘要
MySQL知识总结 (五) 锁
vscode编译keil工程,烧录程序
IllegalStateException: Room cannot verify the data integrity. Looks like you've changed schema but
标签加id 和 加号 两个文本框 和一个var 赋值
Ffmpeg交叉编译
MySQL知识总结 (三) 索引
预训练模型 Bert
5.使用RecyclerView优雅的实现瀑布流效果
In the Visual studio code solutions have red wavy lines
Cannot figure out how to save this field into database. You can consider adding a type converter for
Pytorch(16)---搭建一个完整的模型
【目标检测】YOLO v5 吸烟行为识别检测
LLVM系列第二十章:写一个简单的Function Pass
自定义圆形seekBar,超简单
两个surfaceview的重叠效果类似直播效果中的视频和讲义实践
Redis-01-Nosql概述
UIWindow的makeKeyAndVisible不调用rootviewController 的viewDidLoad的问题