当前位置:网站首页>【深度学习】Pytorch Tensor 张量
【深度学习】Pytorch Tensor 张量
2022-07-27 18:03:00 【爱吃糖的范同学】

目录
现在是凌晨12点,记录一下学习,重新复习一下Pytorch......好吧,其实也不算复习,之前也只是简单的了解了一下,仅此而已。但是!现在不一样,需要仔细的去学习!
Let‘s do it !!!
这只是一篇简单的学习笔记,仅此而已!!!
一、张量概述:
一种特殊的数据结构,使用在深度学习的神经网络中,类似数组(多维度)和矩阵。神经网络的输入输出、网格参数都是使用张量来进行描述!
import torch
import numpy as np
二、初始化张量:
张量的初始化方式有多种,主要是根据数据来源选择不同的初始化方法:
直接使用Python列表转化为张量:
data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)
使用torch库中的函数tensor将一个二维python列表转换为一个二维的张量。
通过Numpy数组(ndarray)转换为张量:
ndarray和张量(tensor)之间是支持相互转换的
np_array = np.array(data)
x_np = torch.from_numpy(np_array)
通过已有的张量生成新的张量:
新的张量将会继承原有张量的数据属性(结构和类型),也可以重新指定新的数据属性。
x_ones = torch.ones_like(x_data) # 保留 x_data 的属性
print(f"Ones Tensor: \n {x_ones} \n")
x_rand = torch.rand_like(x_data, dtype=torch.float) # 重写 x_data 的数据类型int -> float
print(f"Random Tensor: \n {x_rand} \n")
Ones Tensor:
tensor([[1, 1],
[1, 1]])
Random Tensor:
tensor([[0.0381, 0.5780],
[0.3963, 0.0840]])
通过指定数据维度生成张量:
使用shape元组指定生成的张量维度,将元组传递给torch函数创建不同的张量:
shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")
Random Tensor:
tensor([[0.0266, 0.0553, 0.9843],
[0.0398, 0.8964, 0.3457]])
Ones Tensor:
tensor([[1., 1., 1.],
[1., 1., 1.]])
Zeros Tensor:
tensor([[0., 0., 0.],
[0., 0., 0.]])
三、张量属性:
通过张量的不同属性,可以知道张量的维度,张量的数据类型、张量的存储设备(物理设备)
tensor = torch.rand(3,4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")
Shape of tensor: torch.Size([3, 4]) # 维数
Datatype of tensor: torch.float32 # 数据类型
Device tensor is stored on: cpu # 存储设备
四、张量的运算:
检查当前运行环境是否支持Pytorch,检查代码:
# 判断当前环境GPU是否可用, 然后将tensor导入GPU内运行
if torch.cuda.is_available():
tensor = tensor.to('cuda')
1.张量的索引和切片:
Python的切片,第一个参数是行操作,第二个参数是列操作。
tensor = torch.ones(4, 4)
tensor[:,1] = 0 # 将第1列(从0开始)的数据全部赋值为0
print(tensor)
所有的索引位置都是从0开始:
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
2.张量的拼接:
你可以通过torch.cat方法将一组张量按照指定的维度进行拼接, 也可以参考torch.stack方法。
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)
注意这里的dim参数,这里是指定tensor拼接的维度,维度索引同样是从0开始,0表示第一维,1表示第二维,所以拼接在二维的情况是按照列拼接:
tensor([[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.]])
想知道有几个维度,数出有几层中括号就行,有几层中括号就有几维。而且,随着中括号由外向里走,维度依次增加:从 0 变为 1 变为 2。
3.张量的乘法和矩阵乘法:
简单区分一下乘法和矩阵乘法的区别:
- 乘法:在矩阵上是两个shape相同的矩阵(就是需要满足矩阵的形状一致),对应位置上的元素相乘
- 矩阵乘法:要求矩阵内联的维度一致,即(n,m)x (m,z)
乘法(点乘):
# 逐个元素相乘结果
print(f"tensor.mul(tensor): \n {tensor.mul(tensor)} \n")
# 等价写法:
print(f"tensor * tensor: \n {tensor * tensor}")
tensor.mul(tensor):
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
tensor * tensor:
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
矩阵乘法(叉乘):
print(f"tensor.matmul(tensor.T): \n {tensor.matmul(tensor.T)} \n")
# 等价写法:
print(f"tensor @ tensor.T: \n {tensor @ tensor.T}")
tensor.matmul(tensor.T):
tensor([[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.]])
tensor @ tensor.T:
tensor([[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.]])
4.自动赋值运算:
自动赋值运算通常在方法后有 _ 作为后缀, 例如: x.copy_(y), x.t_()操作会改变 x 的取值。即将方法调用执行的结果重新赋值给调用方法的变量。
print(tensor, "\n")
tensor.add_(5)
print(tensor)
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
tensor([[6., 5., 6., 6.],
[6., 5., 6., 6.],
[6., 5., 6., 6.],
[6., 5., 6., 6.]])
注意:自动赋值运算虽然可以节省内存, 但在求导时会因为丢失了中间过程而导致一些问题, 所以我们并不鼓励使用它。
五、Tensor和Numpy的相互转换:
张量和ndarray数组在CPU上可以共用一块内存区域,改变其中一个值,另一个也会发生改变。
1.由tensor转换为ndarray:
tensor直接调用numpy方法:
t = torch.ones(5)
print(f"t: {t}")
n = t.numpy()
print(f"n: {n}")
t: tensor([1., 1., 1., 1., 1.])
n: [1. 1. 1. 1. 1.]
此时,如果修改张量tensor的值,那么对应的ndarray中的值也会发生改变,这里只是变量类型的改变,但是变量指向的内存地址是同一个内存空间:
t.add_(1)
print(f"t: {t}")
print(f"n: {n}")
t: tensor([2., 2., 2., 2., 2.])
n: [2. 2. 2. 2. 2.]
2.由Ndarray转换为Tensor:
n = np.ones(5)
t = torch.from_numpy(n)
修改Numpy array数组的值,则张量值也会随之改变。
np.add(n, 1, out=n)
print(f"t: {t}")
print(f"n: {n}")
t: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
n: [2. 2. 2. 2. 2.]
学习到这里,目前对张量Tensor有了一个简单的认识,了解如何创建Tensor变量,Tensor的属性,以及Tensor的常用操作!
后面将继续学习相关的内容,加油各位!!!
.
顺便提一下我的学习直播间在某站,哈哈哈,没事的时候可以过来逛逛,那就到这里吧,共勉各位!!!
边栏推荐
- How to solve the problem of missing alarm information and synchronization when Haikang equipment is connected to easycvr?
- Under the epidemic, I left my job for a year, and my income increased 10 times
- [deep learning] video classification technology sorting
- In 2019, China's smart machine Market: Huawei won nearly 4 components, firmly ranking first in China
- 国际权威认可!OceanBase入选Forrester Translytical数据平台报告
- Jetpack compose performance optimization guide - compilation metrics
- greedy
- Oracle simple advanced query
- antdv: Each record in table should have a unique `key` prop,or set `rowKey` to an unique primary key
- 如何监控NVIDIA Jetson的的运行状态和使用情况
猜你喜欢

Two years after its release, the price increased by $100, and the reverse growth of meta Quest 2

Nailing development document

How bad can a programmer be?

Some contents related to cmsis-rtos

海康设备接入EasyCVR,出现告警信息缺失且不同步该如何解决?

【深度学习】视频分类技术整理

Babbitt | metauniverse daily must read: Tencent News suspended the sales service of digital collections, users left messages asking for a "refund", and phantom core also fell into the dilemma of "unsa

Injection attack

JVM overview and memory management (to be continued)

Understand the wonderful use of dowanward API, and easily grasp kubernetes environment variables
随机推荐
JVS私有化部署启动失败处理方案
Idea: solve the problem of code without prompt
MySQL log query log
一个程序员的水平能差到什么程度?
Learn about the 12 necessary animation plug-ins of blender
Express: search product API by keyword
[deep learning] video classification technology sorting
Mlx90640 infrared thermal imager temperature sensor module development notes (VII)
我也是醉了,Eureka 延迟注册还有这个坑
JS jump to the page and refresh (jump to this page)
[map set]
MLX90640 红外热成像仪测温传感器模块开发笔记(七)
Get wechat product details API
Redis 事物学习
A recently summarized universal violent cracking method
2022.07.11
greedy
Flask Mdict builds online MDICT Dictionary Service
MySQL log error log
RK3399平台入门到精通系列讲解(导读篇)21天学习挑战介绍
https://live.bilibili.com/22404070?broadcast_type=0&is_room_feed=1?spm_id_from=333.999.space_home.left_live.click