当前位置:网站首页>PyTorch基础知识(可入门)
PyTorch基础知识(可入门)
2022-07-29 05:21:00 【Quinn-ntmy】
PyTorch基础
- 核心是张量Tensor,一种多维数据的数学对象。
- 使用torch包创建张量,具体的很简单,不多说了。
(1)如果想用特定的值填充张量,可以使用 fill_() 方法。【任何带有下划线(_)的PyTorch方法都是指原位操作,即不用创建新对象就地修改内容】
(2)当使用torch.Tensor构造函数时,默认张量类型是torch.FloatTensor。可以在使用时进行类型转换 或 利用torch.tensor()中的dtype参数。 - 加减乘除操作都和其他类似。
- torch.transpose():转置操作,很简单,可参考https://blog.csdn.net/qq_50001789/article/details/120451717
- 普通索引和切片:基本上同numpy。
- 复杂索引:
(1) 张量的非连续索引
def describe(x):
print("Type: {}".format(x.type()))
print("Shape/size: {}".format(x.shape))
print("Values: \n{}".format(x))
import torch
# 创建一个2D张量
x = torch.arange(6).view(2, 3)
describe(x)
# 1、获取2D张量的第2个维度且索引号为0和2的张量子集
indices = torch.LongTensor([0, 2])
describe(torch.index_select(x, dim=1, index=indices))
# 2、获取2D张量的第1个维度且索引号为0和0的张量子集
indices = torch.LongTensor([0, 0])
describe(torch.index_select(x, dim=0, index=indices))
# 3、row_indices结果是个tensor([0., 1.]),然后col_indices是tensor([1, 2])
row_indices = torch.arange(2).long()
col_indices = torch.LongTensor([1, 2])
describe(x[row_indices, col_indices])
输出结果:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[0, 2],
[3, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[0, 1, 2]])
Type: torch.LongTensor
Shape/size: torch.Size([2])
Values:
tensor([1, 5])
Process finished with exit code 0
解释一下最后一个操作(3),
(2)连接张量
import torch
x = torch.arange(6).view(2, 3)
describe(x)
describe(torch.cat([x, x], dim=0))
describe(torch.cat([x, x], dim=1))
describe(torch.stack([x, x]))
结果:
Type: torch.LongTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([4, 3])
Values:
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 6])
Values:
tensor([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]])
Type: torch.LongTensor
Shape/size: torch.Size([2, 2, 3])
Values:
tensor([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]])
Process finished with exit code 0
- 张量的requires_grad=True时,可以追踪梯度张量以及梯度函数。这两个东西需要基于梯度的学习“监督学习范式”。
创建张量并计算梯度:
import torch
x = torch.ones(2, 2, requires_grad=True)
describe(x)
print(x.grad is None)
y = (x+2)*(x+5)+3
describe(y)
print(x.grad is None)
z = y.mean()
describe(z)
z.backward()
print(x.grad is None)
结果:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
True
Type: torch.FloatTensor
Shape/size: torch.Size([2, 2])
Values:
tensor([[21., 21.],
[21., 21.]], grad_fn=<AddBackward0>)
True
Type: torch.FloatTensor
Shape/size: torch.Size([])
Values:
21.0
False
Process finished with exit code 0
梯度是一个值(函数输出对于函数输入的斜率)。
- CUDA张量
(1)创建cuda张量
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.rand(3, 3).to(device)
describe(x)
结果:
cuda
Type: torch.cuda.FloatTensor
Shape/size: torch.Size([3, 3])
Values:
tensor([[0.2369, 0.9929, 0.6972],
[0.1366, 0.0594, 0.0726],
[0.4803, 0.1209, 0.6055]], device='cuda:0')
(2)CUDA张量与CPU绑定张量混合
# 将CUDA张量和CPU绑定张量混合(CUDA和非CUDA对象),y在CPU上,x在GPU上。这种情况就会报错
y = torch.rand(3, 3)
x + y
结果报错:
Traceback (most recent call last):
File "C:/Users/27692/PycharmProjects/Quinn/test/test.py", line 60, in <module>
x + y
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
要想对cuda和非cuda对象进行操作,必须确保它们在同一个设备上:
cpu_device = torch.device("cpu")
y = y.to(cpu_device)
x = x.to(cpu_device)
print(x + y)
结果:
tensor([[0.9204, 0.8360, 0.6289],
[1.2421, 1.4353, 1.2174],
[0.9342, 1.0427, 1.1978]])
Process finished with exit code 0
【练习扩展】
- unsqueeze()函数和squeeze()函数:
a = torch.rand(2, 3)
print(describe(a))
# 在第0维增加1个维度
b = a.unsqueeze(0)
print(describe(b))
# 在第1维增加1个维度
b = a.unsqueeze(1)
print(describe(b))
# 在倒数第2个维度增加1个维度
b = a.unsqueeze(-2)
print(describe(b))
注意看输出结果:!!!张量形状(torch.size)
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([1, 2, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562],
[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.2643, 0.1925, 0.2562]],
[[0.7674, 0.9930, 0.2341]]])
None
Process finished with exit code 0
# ======squeeze()======
# 上面是插入/增加,这个是删除/去掉
a = torch.rand(2, 1, 3)
print(describe(a))
# 将第1维去掉
b = a.squeeze(1) # a.squeeze(-2)有同样作用 倒数第二个
print(describe(b))
# 试试-3
b = a.squeeze(-3) # 这块结果不变,因为只有维度为1时才会去掉
print(describe(b))
输出结果:
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 3])
Values:
tensor([[0.4349, 0.0568, 0.0304],
[0.2712, 0.3612, 0.3238]])
None
Type: torch.FloatTensor
Shape/size: torch.Size([2, 1, 3])
Values:
tensor([[[0.4349, 0.0568, 0.0304]],
[[0.2712, 0.3612, 0.3238]]])
None
Process finished with exit code 0
- 创建一个具有正态分布的张量:
a = torch.rand(3,3)
a.normal_()
边栏推荐
- mysql 的show profiles 使用。
- 【Transformer】AdaViT: Adaptive Tokens for Efficient Vision Transformer
- The third week of postgraduate freshman training: resnet+resnext
- [semantic segmentation] setr_ Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformer
- anaconda中移除旧环境、增加新环境、查看环境、安装库、清理缓存等操作命令
- 【Clustrmaps】访客统计
- 这些你一定要知道的进程知识
- 电脑视频暂停再继续,声音突然变大
- Super simple integration of HMS ml kit to realize parent control
- The difference between asyncawait and promise
猜你喜欢

Windos下安装pyspider报错:Please specify --curl-dir=/path/to/built/libcurl解决办法

Process management of day02 operation

Simple optimization of interesting apps for deep learning (suitable for novices)

并发编程学习笔记 之 Lock锁及其实现类ReentrantLock、ReentrantReadWriteLock和StampedLock的基本用法

Reporting Services- Web Service

【数据库】数据库课程设计一一疫苗接种数据库

Spring, summer, autumn and winter with Miss Zhang (3)

SSM integration

MySql统计函数COUNT详解

研究生新生培训第三周:ResNet+ResNeXt
随机推荐
day02 作业之文件权限
30 knowledge points that must be mastered in quantitative development [what is individual data]?
ASM piling: after learning ASM tree API, you don't have to be afraid of hook anymore
[ml] PMML of machine learning model -- Overview
Super simple integration HMS ml kit face detection to achieve cute stickers
【ML】机器学习模型之PMML--概述
[semantic segmentation] full attention network for semantic segmentation
[DL] introduction and understanding of tensor
有价值的博客、面经收集(持续更新)
yum本地源制作
Rsync+inotyfy realize real-time synchronization of single data monitoring
Thinkphp6 output QR code image format to solve the conflict with debug
Semaphore (semaphore) for learning notes of concurrent programming
并发编程学习笔记 之 工具类Semaphore(信号量)
关于Flow的原理解析
【语义分割】Fully Attentional Network for Semantic Segmentation
【pycharm】pycharm远程连接服务器
【Clustrmaps】访客统计
Lock lock of concurrent programming learning notes and its implementation basic usage of reentrantlock, reentrantreadwritelock and stampedlock
Anr Optimization: cause oom crash and corresponding solutions