当前位置:网站首页>PyTorch学习(三)
PyTorch学习(三)
2022-06-30 17:28:00 【马少爷】
1、Z-score 标准化(standardization)
严格来说z-score是标准化的操作,有的地方写的归一化(normalization),是错误的说法。1)标准化是通过变换使得数据符合均值为0,方差为1的分布。2)归一化是通过变换使得数据值变到[0, 1] 这个区间中。两者有本质的区别。
1)标准差计算公式:

2)Z-score 标准化计算公式:
注:Z-score 标准化只能使得数据变换为均值为0,方差为1,不会改变原数据的分布
归一化公式:
2、torch.tensor.permute()函数
Permute算子的作用是变换张量数据维度的顺序,举个例子:
data1=torch.randn((3,2,1))
print('data1的数据类型:',type(data1))
print('data1的数据维度:',data1.shape)
print('data1:',data1)
data2=data1.permute(2,1,0)
print('data2的数据类型:',type(data2))
print('data2的数据维度:',data2.shape)
print('data2:',data2)

3、torch.matmul
pytorch中两个张量的乘法可以分为两种:
两个张量对应元素相乘,在PyTorch中可以通过torch.mul函数(或*运算符)实现;
两个张量矩阵相乘,在PyTorch中可以通过torch.matmul函数实现;
torch.matmul()也是一种类似于矩阵相乘操作的tensor连乘操作。但是它可以利用python中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
若两个tensor都是一维的,则返回两个向量的点积运算结果:
import torch
x = torch.tensor([1,2])
y = torch.tensor([3,4])
print(x,y)
print(torch.matmul(x,y),torch.matmul(x,y).size())

若两个tensor都是二维的,则返回两个矩阵的矩阵相乘结果:
import torch
x = torch.tensor([[1,2],[3,4]])
y = torch.tensor([[5,6,7],[8,9,10]])
print(torch.matmul(x,y),torch.matmul(x,y).size())

4、torch.sum
a = torch.ones((2, 3))
a1 = torch.sum(a)
a2 = torch.sum(a, dim=0)
a3 = torch.sum(a, dim=1)
print(a)
print(a1)
print(a2)
print(a3)

如果加上keepdim=True, 则会保持dim的维度不被squeeze
a1 = torch.sum(a, dim=(0, 1), keepdim=True)
a2 = torch.sum(a, dim=(0, ), keepdim=True)
a3 = torch.sum(a, dim=(1, ), keepdim=True)

5、torch.view
在pytorch中view函数的作用为重构张量的维度,相当于numpy中resize()的功能,但是用法可能不太一样。如下例所示
比如
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])
print(a.view(1,6))
print(b.view(1,6))

a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(-1))
a=torch.Tensor([[[1,2,3],[4,5,6]]])
a=a.view(3,2)
print(a)
a=a.view(2,-1)
print(a)
6、python 中 numpy 模块的 size,shape, len的用法
import numpy as np
X=np.array([[1,2,3,4],
[5,6,7,8],
[9,10,11,12]])
number=X.size # 计算 X 中所有元素的个数
X_row=np.size(X,0) #计算 X 一行元素的个数
X_col=np.size(X,1) #计算 X 一列元素的个数
print("number:",number)
print("X_row:",X_row)
print("X_col:",X_col)

import numpy as np
X=np.array([[1,2,3,4],
[5,6,7,8],
[9,10,11,12]])
X_dim=X.shape # 以元组形式,返回数组的维数
print("X_dim:",X_dim)
print(X.shape[0]) # 输出行的个数
print(X.shape[1]) #输出列的个数

import numpy as np
X=np.array([[1,2,3,4],
[5,6,7,8],
[9,10,11,12]])
length=len(X) #返回对象的长度 不是元素的个数
print("length of X:",length)

7、pytorch之张量的操作:拼接、切分、索引和变换
7.1张量的拼接
torch.cat(tensors, dim=0, out=None)
功能:将张量按维度dim进行拼接
·tensors:张量序列
·dim:要拼接的维度
import torch
t = torch.ones((2,3))
t_0 = torch.cat([t,t], dim=0)#行拼接
t_1 = torch.cat([t,t], dim=1)#列拼接
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_0,t_0.shape,t_1,t_1.shape))

7.2 张量的切分
torch.chunk(input, chunks, dim=0)
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
·input:要切分的张量
·chunks:要切分的份数
·dim:要切分的维度
t = torch.ones((2,5))
list_of_tensors = torch.chunk(t, dim=1, chunks=2)
for idx, mat in enumerate(list_of_tensors):
print('第{}个张量:{}, 维度为{}'.format(idx+1,mat,mat.shape))

7.3 张量索引
torch.index_select(input,dim=0,index=None)
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
·input:要索引的张量
·dim:要索引的维度
·index:要索引数据的序号
t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2], dtype=torch.long) #float
t_select = torch.index_select(t, dim=0, index=idx)#行索引
t_select2 = torch.index_select(t, dim=1, index=idx)#列索引
print('{}\n{}\n{}'.format(t, t_select, t_select2))

torch.masked_select(input, mask, out=None)
功能:按mask中的True进行索引
返回值:一维张量
·input:要索引的张量
·mask:与input同形状的布尔类型张量
t = torch.randint(0,9,size=(3,3))
#返回大小为t的矩阵,其中大于等于5的元素为True,小于5的为False
mask = t.ge(5)
t_select = torch.masked_select(t, mask)
print('t:\n{}\nmask:\n{}\nt_select:\n{}'.format(t,mask,t_select))

7.4 张量变换
torch.reshape(input, shape)
功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
·input:要变换的张量
·shape:新张量的形状
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1,2,2))
print('t:\n{}\nt_reshape:\n{}'.format(t, t_reshape))
print('t内存地址{}'.format(id(t.data)))
print('t_reshape内存地址{}'.format(id(t_reshape.data)))

torch.transpose(input, dim0, dim1)
功能:交换张量的两个维度
·input:要变换的张量
·dim0:要变换的维度
·dim1:要变换的维度
t = torch.rand((2,3,4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)
print('t shape:{} t_transpose shape:{}'.format(t.shape, t_transpose.shape))

torch.t(input)
功能:2维张量转置,对矩阵而言,等价于torch.transpsoe(input,0,1)
torch.squeeze(input, dim=None, out=None)
功能:压缩长度为1的维度(轴)
·dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;
t = torch.rand((1,2,3,1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)#第二个维度是2故无法压缩掉

torch.usqueeze(input, dim, out=None)
功能:依据dim扩展维度
·dim:扩展的维度
t = torch.rand((1,2,3))
t_sq1 = torch.unsqueeze(t,dim=1)
t_sq2= torch.unsqueeze(t,dim=2)
t_sq3 = torch.unsqueeze(t,dim=3)
print(t.shape)
print(t_sq1.shape)
print(t_sq2.shape)
print(t_sq3.shape)

边栏推荐
- OneFlow源码解析:算子签名的自动推断
- MySQL找不到mysql.sock文件的临时解
- php利用队列解决迷宫问题
- Grep output with multiple colors- Grep output with multiple Colors?
- Glacier teacher's book
- 如何做好软件系统的需求调研,七种武器让你轻松搞定
- Sword finger offer 17 Print from 1 to maximum n digits
- Talk about the SQL server version of DTM sub transaction barrier function
- If you want to learn software testing, you must see series, 2022 software testing engineer's career development
- Geoffrey Hinton: my 50 years of in-depth study and Research on mental skills
猜你喜欢

Do you write API documents or code first?

Solution of enterprise supply chain system in medical industry: realize collaborative visualization of medical digital intelligent supply chain

OneFlow源码解析:算子签名的自动推断

程序员女友给我做了一个疲劳驾驶检测

MySQL advanced - index optimization (super detailed)

Another CVPR 2022 paper was accused of plagiarism, and Ping An insurance researchers sued IBM Zurich team

How to use AI technology to optimize the independent station customer service system? Listen to the experts!

Sword finger offer 17 Print from 1 to maximum n digits

4个技巧告诉你,如何使用SMS促进业务销售?

Communication network electronic billing system based on SSH
随机推荐
Oneortwo bugs in "software testing" are small things, but security vulnerabilities are big things. We must pay attention to them
Talk about the SQL server version of DTM sub transaction barrier function
Openlayers 卷帘地图
【TiDB】TiCDC canal_ Practical application of JSON
Development and construction of NFT mining tour gamefi chain tour system
Flink series: checkpoint tuning
The online procurement system of the electronic components industry accurately matches the procurement demand and leverages the digital development of the electronic industry
100 examples of bug records of unity development (the first example) -- shader failure or bug after packaging
uni-app进阶之自定义【day13】
VScode 状态条 StatusBar
PHP uses queues to solve maze problems
Research on the principle of Tencent persistence framework mmkv
In distributed scenarios, do you know how to generate unique IDs?
Summary of methods for offline installation of chrome extensions in China
【TiDB】TiCDC canal_json的实际应用
LRN local response normalization
iCloud照片无法上传或同步怎么办?
Helping the ultimate experience, best practice of volcano engine edge computing
英飞凌--GTM架构-Generic Timer Module
Solution of enterprise supply chain system in medical industry: realize collaborative visualization of medical digital intelligent supply chain