PyTorch Basics (Beginner-Friendly)
2022-07-29 05:21:00 【Quinn-ntmy】
PyTorch Basics
- The core data structure is the Tensor, a mathematical object for multi-dimensional data.
- Tensors are created with the torch package; creation itself is straightforward, so only a few notes follow.
(1) To fill a tensor with a specific value, use the fill_() method. [Any PyTorch method whose name ends with an underscore (_) is an in-place operation: it modifies the contents directly without creating a new object.]
(2) When the torch.Tensor constructor is used, the default tensor type is torch.FloatTensor. You can cast the tensor afterwards, or set the dtype argument of torch.tensor() at creation time.
- Addition, subtraction, multiplication, and division work as in other numeric libraries. A short sketch of these points follows.
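A minimal sketch of the points above (the shapes and values here are just illustrative assumptions):
import torch

x = torch.Tensor(2, 3)   # default type: torch.FloatTensor (uninitialized values)
x.fill_(5)               # in-place: every element becomes 5.0

y = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)  # set the type via the dtype argument
z = x.long()             # or cast an existing tensor: FloatTensor -> LongTensor
print(x + x, y - y, z * 2)   # elementwise arithmetic works as expected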
- torch.transpose(): swaps two dimensions of a tensor; see https://blog.csdn.net/qq_50001789/article/details/120451717 for details.
- Basic indexing and slicing: essentially the same as NumPy (see the short sketch below).
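A quick illustration of the two bullets above (a small sketch; the concrete tensor is chosen only for demonstration):
import torch

x = torch.arange(6).view(2, 3)     # tensor([[0, 1, 2], [3, 4, 5]])
print(torch.transpose(x, 0, 1))    # swap dimensions 0 and 1 -> shape (3, 2)
print(x[:1, :2])                   # slicing, as in NumPy -> tensor([[0, 1]])
print(x[0, 1])                     # plain indexing -> tensor(1)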
- Advanced indexing:
(1) Non-contiguous indexing of a tensor
import torch

def describe(x):
    print("Type: {}".format(x.type()))
    print("Shape/size: {}".format(x.shape))
    print("Values: \n{}".format(x))

# create a 2D tensor
x = torch.arange(6).view(2, 3)
describe(x)
# 1. select indices 0 and 2 along the second dimension (dim=1) of the 2D tensor
indices = torch.LongTensor([0, 2])
describe(torch.index_select(x, dim=1, index=indices))
# 2. select indices 0 and 0 along the first dimension (dim=0) of the 2D tensor
indices = torch.LongTensor([0, 0])
describe(torch.index_select(x, dim=0, index=indices))
# 3. row_indices is tensor([0, 1]) and col_indices is tensor([1, 2])
row_indices = torch.arange(2).long()
col_indices = torch.LongTensor([1, 2])
describe(x[row_indices, col_indices])
Output:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[0, 2],
[3, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[0, 1, 2]])
Type: torch.LongTensor
Shape/size: torch.Size([2])
Values:
tensor([1, 5])
Process finished with exit code 0
A note on the last operation (3): x[row_indices, col_indices] pairs the two index tensors element-wise, so it picks x[0, 1] = 1 and x[1, 2] = 5, which gives tensor([1, 5]).
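As a quick sanity check (continuing with x, row_indices and col_indices defined above; this check is not part of the original code):
# element-wise pairing: positions (0, 1) and (1, 2) are selected
assert torch.equal(x[row_indices, col_indices], torch.tensor([1, 5]))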
(2) Concatenating tensors
import torch
x = torch.arange(6).view(2, 3)
describe(x)
describe(torch.cat([x, x], dim=0))
describe(torch.cat([x, x], dim=1))
describe(torch.stack([x, x]))
Output:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([4, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 6])
Values:
tensor([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2, 3])
Values:
tensor([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]])
Process finished with exit code 0
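In short, torch.cat joins tensors along an existing dimension, while torch.stack inserts a new dimension (hence the extra leading dimension of size 2 in the last result). One way to see this, as a small sketch continuing with x from above (unsqueeze() is covered in the exercises later):
a = torch.stack([x, x])                                 # shape (2, 2, 3)
b = torch.cat([x.unsqueeze(0), x.unsqueeze(0)], dim=0)  # same result built with cat
print(torch.equal(a, b))                                # True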
- When a tensor has requires_grad=True, PyTorch tracks both the gradient tensor and the gradient function (grad_fn); both are required for gradient-based learning, i.e., the supervised learning paradigm.
Create a tensor and compute gradients:
import torch
x = torch.ones(2, 2, requires_grad=True)
describe(x)
print(x.grad is None)
y = (x+2)*(x+5)+3
describe(y)
print(x.grad is None)
z = y.mean()
describe(z)
z.backward()
print(x.grad is None)
Output:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
True
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[21., 21.],
[21., 21.]], grad_fn=<AddBackward0>)
True
Type: torch.FloatTensor
Shape/size: torch.Size([])
Values:
21.0
False
Process finished with exit code 0
The gradient is a single value: the slope of the function's output with respect to the function's input.
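A quick hand check of the numbers above (not part of the original output): for y = (x+2)(x+5)+3, dy/dx = (x+5)+(x+2) = 2x+7, which equals 9 at x = 1; since z averages the four elements of y, every entry of x.grad should be 9/4 = 2.25.
print(x.grad)
# expected, following the calculation above:
# tensor([[2.2500, 2.2500],
#         [2.2500, 2.2500]])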
- CUDA tensors
(1) Creating a CUDA tensor
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.rand(3, 3).to(device)
describe(x)
Output:
cuda
Type: torch.cuda.FloatTensor
Shape/size: torch.Size([3, 3])
Values:
tensor([[0.2369, 0.9929, 0.6972],
[0.1366, 0.0594, 0.0726],
[0.4803, 0.1209, 0.6055]], device='cuda:0')
(2) Mixing CUDA tensors with CPU-bound tensors
# Mixing a CUDA tensor with a CPU-bound tensor (a CUDA object and a non-CUDA object): y is on the CPU while x is on the GPU, so the following raises an error
y = torch.rand(3, 3)
x + y
The result is an error:
Traceback (most recent call last):
File "C:/Users/27692/PycharmProjects/Quinn/test/test.py", line 60, in <module>
x + y
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
To operate on CUDA and non-CUDA objects together, you must make sure they are on the same device:
cpu_device = torch.device("cpu")
y = y.to(cpu_device)
x = x.to(cpu_device)
print(x + y)
Output:
tensor([[0.9204, 0.8360, 0.6289],
[1.2421, 1.4353, 1.2174],
[0.9342, 1.0427, 1.1978]])
Process finished with exit code 0
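Moving tensors in the other direction works the same way: a small sketch, assuming a CUDA device is available and x, y, device are as defined above.
x = x.to(device)   # move both operands onto the GPU instead
y = y.to(device)
print(x + y)       # now both live on cuda:0, so the addition succeeds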
【Exercise Extensions】
- The unsqueeze() and squeeze() functions:
a = torch.rand(2, 3)
print(describe(a))
# insert a new dimension at position 0
b = a.unsqueeze(0)
print(describe(b))
# insert a new dimension at position 1
b = a.unsqueeze(1)
print(describe(b))
# insert a new dimension at the second-to-last position
b = a.unsqueeze(-2)
print(describe(b))
Pay close attention to the tensor shapes (torch.Size) in the output! (The extra "None" lines appear because describe() returns None, so print(describe(...)) also prints None.)
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([1, 2, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Process finished with exit code 0
# ======squeeze()======
# unsqueeze() above inserts a dimension; squeeze() removes one
a = torch.rand(2, 1, 3)
print(describe(a))
# remove dimension 1
b = a.squeeze(1)  # a.squeeze(-2) has the same effect (the second-to-last dimension)
print(describe(b))
# try -3
b = a.squeeze(-3)  # the result is unchanged here, because a dimension is removed only when its size is 1
print(describe(b))
Output:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.4349, 0.0568, 0.0304],
[0.2712, 0.3612, 0.3238]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Process finished with exit code 0
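One more detail worth noting (a small addition, not part of the original exercise): calling squeeze() with no argument removes all dimensions of size 1 at once.
c = torch.rand(1, 2, 1, 3)
print(c.squeeze().shape)   # torch.Size([2, 3]) -- both size-1 dimensions removed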
- Create a tensor whose values are drawn from a normal distribution:
a = torch.rand(3,3)
a.normal_()
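Equivalently (a small sketch): torch.randn() creates a standard-normal tensor directly, and normal_() also accepts mean/std arguments.
b = torch.randn(3, 3)          # samples from N(0, 1)
a.normal_(mean=5.0, std=2.0)   # in-place fill with samples from a normal distribution with mean 5 and std 2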