当前位置:网站首页>Pytorch学习笔记--Pytorch常用函数总结1
Pytorch学习笔记--Pytorch常用函数总结1
2022-07-25 15:28:00 【whut_L】
目录
5-最大池化(max_pool2d)和平均池化(avg_pool2d)函数
1-torch.randn()函数
import torch
batch_size = 1
seq_len = 3
input_size = 4
inputs = torch.randn(seq_len, batch_size, input_size) torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数。示例如下:
import torch
print(torch.randn(3, 2, 3, 3))
torch.randn(seq_len, batch_size, input_size):第一个参数seq_len表示序列长度,示例中序列长度为3;第二个参数batch_size表示批大小,示例中批大小为2;第三个参数input_size为输入向量的维度,示例中为(3, 3)。(在RNN中可理解成:示例中,共有3个序列,每个序列分为2批,每批的维度为3*3。)
#####################################
#####################################
2-set()函数和sorted()函数
self.country_list = list(sorted(set(self.countries))) # set()去重,删除重复的数据; sorted()排序set()函数用于删除重复的数据元素;sorted()用于元素的排序,示例如下:
a = ['china', 'china', 'japan']
print(list(set(a)))
print(list(sorted(set(a))))
由于‘c’ < 'j',所以‘china’排在‘japan’前面。
#####################################
#####################################
3-DataLoader()函数和Dataset类
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(), #将shape为(H, W, C)的img转为shape为(C, H, W)的tensor,将每一个数值归一化到[0,1]
transforms.Normalize((0.1307, ), (0.3081, )) #按通道进行数据标准化
])
train_dataset = datasets.MNIST(root = '../dataset/mnist/', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST(root = '../dataset/mnist/', train = False, download = True, transform = transform)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)DataLoader()函数导入的数据集为Dataset类型,shuffle表示是否打乱数据集。
#####################################
#####################################
4-.t()函数
.t()函数的作用是将Tensor转置,示例如下:
import torch
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(input)
print(input.t())
#####################################
#####################################
5-最大池化(max_pool2d)和平均池化(avg_pool2d)函数
import torch
import torch.nn.functional as F
input = torch.tensor([[[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]]]).unsqueeze(0).float() # unsqueeze(0)在第0个维度前增加一个维度
print(input.size())
output = F.max_pool2d(input, kernel_size = (1, 4))
print(output)
max_pool2d():最大池化操作。根据设定的核大小,选取最大的元素值。示例中,核大小是(1,4),可理解为挑选出每行最大的元素值。
需要说明的是:unsqueeze(0)的作用是在第0维度前扩展一个维度,所以input的size为(1, 1, 3, 4)。
###
import torch
import torch.nn.functional as F
input = torch.randn(1, 1, 4, 4)
print(input.size())
print(input)
output = F.avg_pool2d(input, kernel_size = (2, 2))
print(output)
avg_pool2d():平均池化操作。根据设定的核大小,计算得到核内元素的平均值。
池化的作用:降维;抑制噪声,降低信息冗余;提升模型的尺度不变性、旋转不变形;降低模型计算量;防止过拟合。
边栏推荐
- MySQL transactions and mvcc
- Yan required executor memory is above the max threshold (8192mb) of this cluster!
- JVM-参数配置详解
- ML - 语音 - 深度神经网络模型
- 使用cpolar建立一个商业网站(如何购买域名)
- Geogle Colab笔记1--运行Geogle云端硬盘上的.py文件
- 2021HNCPC-E-差分,思维
- 解决vender-base.66c6fc1c0b393478adf7.js:6 TypeError: Cannot read property ‘validate‘ of undefined问题
- Instance tunnel use
- The difference between Apple buy in and apple pay
猜你喜欢

Solve the timeout of dbeaver SQL client connection Phoenix query

SVD奇异值分解推导及应用与信号恢复

p4552-差分

Implementation of asynchronous FIFO

你准备好脱离“内卷化怪圈”了吗?

Remember that spark foreachpartition once led to oom

How to solve the problem of scanf compilation error in Visual Studio

Distributed principle - what is a distributed system

Spark partition operators partitionby, coalesce, repartition

Spark AQE
随机推荐
ML - 语音 - 语音处理介绍
死锁杂谈
User defined annotation verification API parameter phone number
Submarine cable detector tss350 (I)
Take you to learn more about JS basic grammar (recommended Collection)
matlab---错误使用 var 数据类型无效。第一个输入参数必须为单精度值或双精度值
Implementation of asynchronous FIFO
Understanding the difference between wait() and sleep()
Brain racking CPU context switching
p4552-差分
解决vender-base.66c6fc1c0b393478adf7.js:6 TypeError: Cannot read property ‘validate‘ of undefined问题
PAT甲级1153 Decode Registration Card of PAT (25 分)
JVM-动态字节码技术详解
MySQL优化总结二
Take you to create your first C program (recommended Collection)
CGO is realy Cool!
4PAM在高斯信道与瑞利信道下的基带仿真系统实验
Remember that spark foreachpartition once led to oom
Node learning
HBCK fix problem