PyTorch Basics (A Beginner-Friendly Introduction)
2022-07-29 05:21:00 【Quinn-ntmy】
PyTorch Basics
- The core data structure is the tensor (Tensor), a mathematical object for multi-dimensional data.
- Tensors are created with the torch package; the basic constructors are straightforward, so they are not covered in detail here.
(1) To fill a tensor with a specific value, use the fill_() method. [Any PyTorch method whose name ends with an underscore (_) is an in-place operation: it modifies the contents directly instead of creating a new object.]
(2) When the torch.Tensor constructor is used, the default tensor type is torch.FloatTensor. You can cast the result afterwards, or pass the dtype argument to torch.tensor(). (A short sketch of these basics follows below.)
- Elementwise addition, subtraction, multiplication, and division work the way you would expect.
- torch.transpose(): swaps two dimensions; see https://blog.csdn.net/qq_50001789/article/details/120451717 for details.
- Basic indexing and slicing: essentially the same as in NumPy.
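A minimal sketch of the points above (my own illustration; the values and variable names are arbitrary and not from the original post):
import torch

t = torch.Tensor(2, 3)            # uninitialized, defaults to torch.FloatTensor
t.fill_(5)                        # in-place fill; trailing underscore = in-place
print(t.type())                   # torch.FloatTensor
i = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int64)  # choose the dtype explicitly
print(i.type())                   # torch.LongTensor
print(torch.transpose(i, 0, 1).shape)  # swap dims 0 and 1 -> torch.Size([3, 2])
print(i[0, :2])                   # NumPy-style slicing -> tensor([1, 2])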
- Advanced indexing:
(1) Non-contiguous indexing of a tensor
import torch

def describe(x):
    print("Type: {}".format(x.type()))
    print("Shape/size: {}".format(x.shape))
    print("Values: \n{}".format(x))

# Create a 2D tensor
x = torch.arange(6).view(2, 3)
describe(x)

# 1. Along the second dimension (dim=1), select the columns with indices 0 and 2
indices = torch.LongTensor([0, 2])
describe(torch.index_select(x, dim=1, index=indices))

# 2. Along the first dimension (dim=0), select the rows with indices 0 and 0
indices = torch.LongTensor([0, 0])
describe(torch.index_select(x, dim=0, index=indices))

# 3. row_indices is tensor([0, 1]) and col_indices is tensor([1, 2])
row_indices = torch.arange(2).long()
col_indices = torch.LongTensor([1, 2])
describe(x[row_indices, col_indices])
Output:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[0, 2],
[3, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[0, 1, 2]])
Type: torch.LongTensor
Shape/size: torch.Size([2])
Values:
tensor([1, 5])
Process finished with exit code 0
To explain the last operation (3): row_indices is tensor([0, 1]) and col_indices is tensor([1, 2]). The two index tensors are paired elementwise, so x[row_indices, col_indices] picks the elements at positions (0, 1) and (1, 2), which is why the result is tensor([1, 5]).
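For comparison, the same two elements can also be collected with torch.gather (my own aside, not something the original post covers):
import torch

x = torch.arange(6).view(2, 3)
index = torch.tensor([[1], [2]])                       # for each row, the column to take
print(torch.gather(x, dim=1, index=index))             # tensor([[1], [5]])
print(torch.gather(x, dim=1, index=index).squeeze(1))  # tensor([1, 5])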
(2) Concatenating tensors
import torch

x = torch.arange(6).view(2, 3)
describe(x)                            # the original 2x3 tensor
describe(torch.cat([x, x], dim=0))     # concatenate along rows    -> shape [4, 3]
describe(torch.cat([x, x], dim=1))     # concatenate along columns -> shape [2, 6]
describe(torch.stack([x, x]))          # stack along a new dim 0   -> shape [2, 2, 3]
Output:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([4, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 6])
Values:
tensor([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2, 3])
Values:
tensor([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]])
Process finished with exit code 0
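The difference between the last two calls is worth spelling out (my own summary): torch.cat joins tensors along an existing dimension, while torch.stack first inserts a new dimension and then joins along it, which is why the stacked result gains an extra axis. A minimal sketch:
import torch

x = torch.arange(6).view(2, 3)
print(torch.cat([x, x], dim=0).shape)    # torch.Size([4, 3])    -- an existing dim grows
print(torch.stack([x, x], dim=0).shape)  # torch.Size([2, 2, 3]) -- a new dim is inserted
# stack(dim=0) behaves like unsqueezing each tensor and then concatenating:
print(torch.cat([x.unsqueeze(0), x.unsqueeze(0)], dim=0).shape)  # torch.Size([2, 2, 3])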
- When a tensor has requires_grad=True, PyTorch tracks its gradient tensor (.grad) and its gradient function (grad_fn). Both are what gradient-based learning, i.e. the supervised learning paradigm, relies on.
Create a tensor and compute gradients:
import torch

x = torch.ones(2, 2, requires_grad=True)
describe(x)
print(x.grad is None)   # True: no backward pass has run yet

y = (x + 2) * (x + 5) + 3
describe(y)
print(x.grad is None)   # still True: the forward pass alone does not populate .grad

z = y.mean()
describe(z)
z.backward()
print(x.grad is None)   # False: backward() has filled in x.grad
Output:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
True
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[21., 21.],
[21., 21.]], grad_fn=<AddBackward0>)
True
Type: torch.FloatTensor
Shape/size: torch.Size([])
Values:
21.0
False
Process finished with exit code 0
The gradient is a value: the slope of the function output with respect to the function input. Here it tells us how much z changes per unit change in each entry of x.
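As a quick sanity check (my own worked example, not part of the original post): z is the mean of y = (x + 2)(x + 5) + 3 over four entries, so dz/dx_i = (2 * x_i + 7) / 4, which is 9/4 = 2.25 at x_i = 1. Printing x.grad after backward() should therefore show a 2x2 tensor filled with 2.25:
import torch

x = torch.ones(2, 2, requires_grad=True)
z = ((x + 2) * (x + 5) + 3).mean()
z.backward()
print(x.grad)   # tensor([[2.2500, 2.2500],
                #         [2.2500, 2.2500]])  -- (2*x + 7) / 4 evaluated at x = 1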
- CUDA tensors
(1) Creating a CUDA tensor
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.rand(3, 3).to(device)
describe(x)
Output:
cuda
Type: torch.cuda.FloatTensor
Shape/size: torch.Size([3, 3])
Values:
tensor([[0.2369, 0.9929, 0.6972],
[0.1366, 0.0594, 0.0726],
[0.4803, 0.1209, 0.6055]], device='cuda:0')
(2) Mixing CUDA tensors with CPU-bound tensors
# Mixing a CUDA tensor with a CPU-bound tensor (a CUDA and a non-CUDA object): y lives on the CPU while x is on the GPU, so this raises an error
y = torch.rand(3, 3)
x + y
The result is an error:
Traceback (most recent call last):
File "C:/Users/27692/PycharmProjects/Quinn/test/test.py", line 60, in <module>
x + y
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
To operate on CUDA and non-CUDA objects together, make sure they are on the same device:
cpu_device = torch.device("cpu")
y = y.to(cpu_device)
x = x.to(cpu_device)
print(x + y)
Output:
tensor([[0.9204, 0.8360, 0.6289],
[1.2421, 1.4353, 1.2174],
[0.9342, 1.0427, 1.1978]])
Process finished with exit code 0
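The fix above moves everything onto the CPU; the reverse direction works just as well and keeps the computation on the GPU (a minimal sketch, assuming a CUDA device is available):
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(3, 3).to(device)
y = torch.rand(3, 3).to(device)   # move y to the same device as x instead of the other way around
print(x + y)                      # both operands now live on the same device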
[Exercises and extensions]
- The unsqueeze() and squeeze() functions:
a = torch.rand(2, 3)
print(describe(a))   # describe() returns None, which is why the output below contains extra "None" lines
# Insert a new dimension at position 0
b = a.unsqueeze(0)
print(describe(b))
# Insert a new dimension at position 1
b = a.unsqueeze(1)
print(describe(b))
# Insert a new dimension at the second-to-last position
b = a.unsqueeze(-2)
print(describe(b))
Pay close attention to the tensor shapes (torch.Size) in the output:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([1, 2, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Process finished with exit code 0
# ====== squeeze() ======
# unsqueeze() above inserts a dimension; squeeze() removes one
a = torch.rand(2, 1, 3)
print(describe(a))
# Remove dimension 1
b = a.squeeze(1)   # a.squeeze(-2) does the same thing (the second-to-last dimension)
print(describe(b))
# Try -3
b = a.squeeze(-3)  # the result is unchanged here, because only dimensions of size 1 are removed
print(describe(b))
Output:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.4349, 0.0568, 0.0304],
[0.2712, 0.3612, 0.3238]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Process finished with exit code 0
- Create a tensor filled with normally distributed values:
a = torch.rand(3, 3)   # uniform samples, used here only to allocate the tensor
a.normal_()            # overwrite in place with samples from a standard normal distribution
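An alternative worth knowing (my own addition, not from the original post): torch.randn allocates and fills the tensor with standard normal samples in one step, and normal_() also accepts a mean and a standard deviation:
import torch

a = torch.randn(3, 3)                              # standard normal samples directly
b = torch.empty(3, 3).normal_(mean=5.0, std=2.0)   # normal samples with a chosen mean/std
print(a)
print(b)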