PyTorch Basics (1)
Basic data types and tensors
import torch
import numpy as np

# Converting between numpy arrays and tensors
array = np.array([1.1, 2, 3])
tensorArray = torch.from_numpy(array)   # numpy array -> tensor
array1 = tensorArray.numpy()            # tensor -> numpy array
print(array, '\t', tensorArray, '\t', array1)
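One detail worth knowing: torch.from_numpy() and Tensor.numpy() do not copy the data; the tensor and the array share the same memory. A minimal sketch (variable names chosen here just for illustration):

shared = np.array([1.0, 2.0, 3.0])
t = torch.from_numpy(shared)
shared[0] = 99.0          # modify the numpy array in place
print(t)                  # the tensor sees the change: tensor([99., 2., 3.], dtype=torch.float64)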
# torch offers numpy-like operations on tensors
print(torch.sin(tensorArray))
print(np.ones([2, 5]))        # 2 rows, 5 columns
print(np.ones(2))             # a single row of 2 ones
a = torch.randn(2, 3)         # 2x3 tensor drawn from a standard normal distribution
print(a)
print(a.size(0), a.size(1), a.shape[1])  # 2, 3, 3 -- dim 0 is the rows, dim 1 the columns
print(a.shape)                # torch.Size([2, 3])
print(a.type())               # torch.FloatTensor
isinstance(a, torch.DoubleTensor)        # False
isinstance(a, torch.FloatTensor)         # True
a1 = a.cuda()                 # requires a CUDA-capable GPU
print(isinstance(a1, torch.FloatTensor))       # False
print(isinstance(a1, torch.cuda.FloatTensor))  # True -- CPU tensors and CUDA tensors are distinct types
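Calling .cuda() fails on a machine without a GPU. A more defensive sketch (device handling shown here as an illustration) guards the transfer with torch.cuda.is_available() and uses the more general .to():

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a1 = a.to(device)             # stays on the CPU when CUDA is unavailable
print(a1.device)              # e.g. cuda:0 or cpu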
# torch.tensor objects
tensor1 = torch.tensor(1)
print(tensor1)                # tensor(1)
tensor2 = torch.tensor(1.2)
print(tensor2)                # tensor(1.2000)
print(tensor2.shape)          # torch.Size([])
print(len(tensor2.shape))     # 0 -- a scalar tensor has dimension 0, so its shape is the empty size []
tensor3 = torch.tensor([1.1]) # built from a 1-element list, so it is 1-dimensional
print(tensor3, tensor3.shape) # tensor([1.1000]) torch.Size([1])
tensor4 = torch.FloatTensor(1)  # note: here 1 is a *size* -- an uninitialized FloatTensor with one element
print(tensor4)                # arbitrary uninitialized value
tensor5 = torch.FloatTensor(3)  # compare torch.tensor (takes data) with torch.FloatTensor (takes a shape)
print(tensor5)                # e.g. tensor([0.0000e+00, 0.0000e+00, 6.8645e+36]) -- uninitialized memory
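The last two lines show the difference mentioned in the comments: torch.tensor interprets its argument as data, while torch.FloatTensor(n) interprets a bare integer as a shape and returns uninitialized values. A small sketch of the more explicit alternatives:

print(torch.tensor([3]))              # data: a 1-element tensor holding the value 3
print(torch.empty(3))                 # shape: 3 uninitialized values (same idea as torch.FloatTensor(3))
print(torch.zeros(3), torch.ones(3))  # explicit initialization is usually preferable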
Slicing
import torch
import numpy as np

# Tensors and random initialization
a = torch.rand(2, 3, 28, 28)
print(a, a.shape)             # a random 2x3x28x28 4-D tensor, torch.Size([2, 3, 28, 28])
# 4-D tensors suit CNN inputs, 3-D suits RNN sequences, 2-D suits plain batches of feature vectors
print(a.numel())              # 4704 -- total number of elements (2*3*28*28)
print(a.dim())                # 4
print(torch.tensor(1).dim())  # 0
print(torch.empty(1))         # a 1-element uninitialized tensor, e.g. tensor([0.])
print(torch.Tensor(2, 3).type())      # torch.FloatTensor is the default type
print(torch.IntTensor(2, 3))          # uninitialized 2x3 int tensor
print(torch.tensor([1, 1]).type())    # torch.LongTensor -- inferred from integer data
print(torch.tensor([1.2, 1]).type())  # torch.FloatTensor -- inferred from float data
print(torch.rand(3, 3))               # uniform values in [0, 1)
print(torch.rand_like(torch.rand(3, 3)))  # rand_like reuses the argument's shape: another random 3x3
print(torch.randint(1, 10, (3, 3)))   # integers in [1, 10) (right end excluded), shape 3x3
print(torch.randn(3, 3))              # 3x3, standard normal (mean 0, std 1)
print(torch.normal(mean=torch.full([10], 0.), std=torch.arange(1, 0, -0.1)))  # 10 values with mean 0 and decreasing std (mean must be a float tensor)
print(torch.full([2, 3], 7))          # 2x3 tensor filled with 7
print(torch.full([], 7))              # the scalar 7, dimension 0
print(torch.full([1], 7))             # a 1-element tensor holding 7
print(torch.logspace(0, 1, steps=10)) # 10 values from 10^0 to 10^1, evenly spaced on a log scale
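The outputs above change on every run. If reproducible numbers are wanted, the global random seed can be fixed first; a minimal sketch:

torch.manual_seed(0)          # fix the CPU RNG so rand/randn/randint repeat across runs
print(torch.rand(2, 2))
torch.manual_seed(0)
print(torch.rand(2, 2))       # identical to the previous print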
# Slicing
a = torch.rand(4, 3, 28, 28)
print(a[0].shape)                 # torch.Size([3, 28, 28])
print(a[0, 0].shape)              # torch.Size([28, 28])
print(a[0, 0, 2, 4])              # a single element, e.g. tensor(0.6186)
print(a[:2].shape)                # torch.Size([2, 3, 28, 28])
print(a[:2, 1:, :, :].shape)      # torch.Size([2, 2, 28, 28])
print(a[:2, -1:, :, :].shape)     # torch.Size([2, 1, 28, 28])

print(a[:, :, 0:28:2, 0:28:2].shape)  # torch.Size([4, 3, 14, 14]) -- step 2 along the last two dims
print(a[:, :, ::2, ::2].shape)        # torch.Size([4, 3, 14, 14]) -- same thing, shorter form
print(a.index_select(1, torch.arange(1)).shape)  # torch.Size([4, 1, 28, 28]) -- keep only channel 0
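When only the trailing dimensions need slicing, the ellipsis saves writing out the leading colons; a small sketch equivalent to the strided slice above:

print(a[..., ::2, ::2].shape)     # torch.Size([4, 3, 14, 14]) -- ... stands for ':' on every leading dim
print((a[..., ::2, ::2] == a[:, :, ::2, ::2]).all())  # tensor(True)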
x = torch.randn(3, 4)
mask = x.ge(0.5)                  # True where the element is >= 0.5
print(mask)
print(torch.masked_select(x, mask))        # pick out the elements where the mask is True
print(torch.masked_select(x, mask).shape)  # always 1-D; its length depends on how many entries are True

src = torch.tensor([[4, 3, 5], [6, 7, 8]])
taa = torch.take(src, torch.tensor([0, 3, 5]))  # flatten src, then pick elements by flat position
print(taa)                        # tensor([4, 6, 8])
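Boolean indexing with the mask gives the same 1-D result as masked_select and is often the more readable form; a quick sketch:

print(x[mask])                                              # same elements as torch.masked_select(x, mask)
print(torch.equal(x[mask], torch.masked_select(x, mask)))   # True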
Dimension transformations
import torch
import numpy as np

# Dimension transformations
# view / reshape
a = torch.rand(4, 1, 28, 28)      # 4 images, 1 channel, 28x28 pixels
print(a.shape)                    # torch.Size([4, 1, 28, 28])
print(a.view(4, 1*28*28).shape)   # torch.Size([4, 784]) -- flatten the last three dims into one row per image
print(a.view(4*28, 28).shape)     # torch.Size([112, 28]) -- 112 rows of 28 values
print(a.view(4*1, 28, 28).shape)  # torch.Size([4, 28, 28]) -- keep the physical meaning of each dim in mind
b = a.view(4, 784)
print(b.view(4, 28, 28, 1).shape) # torch.Size([4, 28, 28, 1]) -- note this is NOT a's original [4, 1, 28, 28] layout
# print(a.view(4, 783))           # the element count must match, otherwise view raises an error
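view only works on memory-contiguous tensors, whereas reshape falls back to copying when it has to; a small sketch of the difference (example values chosen here for illustration):

t = torch.rand(4, 1, 28, 28).transpose(1, 3)   # transposing makes the tensor non-contiguous
print(t.is_contiguous())                       # False
print(t.reshape(4, -1).shape)                  # torch.Size([4, 784]) -- reshape copies if needed
# t.view(4, -1) would raise a RuntimeError here; call t.contiguous().view(4, -1) instead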
# unsqueeze inserts a dimension of size 1; the data itself is unchanged
# the valid range of the dim argument is [-a.dim()-1, a.dim()+1)
print()                                 # for the 4-D tensor below that range is [-5, 5)
print(a.unsqueeze(0).shape)             # torch.Size([1, 4, 1, 28, 28])
print(a.unsqueeze(-1).shape)            # torch.Size([4, 1, 28, 28, 1])
print(a.unsqueeze(4).shape)             # torch.Size([4, 1, 28, 28, 1])
print(a.unsqueeze(-4).shape)            # torch.Size([4, 1, 1, 28, 28])
print(a.unsqueeze(-5).shape)            # torch.Size([1, 4, 1, 28, 28])
# print(a.unsqueeze(5).shape)           # 5 is out of range and raises an error
a = torch.tensor([1.2, 2.3])            # a's shape is [2]
print(a.unsqueeze(-1))                  # tensor([[1.2000], [2.3000]]) -- 2 rows, 1 column
print(a.unsqueeze(0))                   # tensor([[1.2000, 2.3000]]) -- shape becomes [1, 2]
b = torch.rand(32)                      # e.g. a per-channel bias
f = torch.rand(4, 32, 14, 14)           # the feature map the bias should be added to (32 channels to match b)
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)                          # torch.Size([1, 32, 1, 1])
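With b reshaped to [1, 32, 1, 1] it lines up with the 32-channel feature map f, and the two can be added directly through broadcasting; a short sketch of the payoff:

out = f + b                             # broadcasting expands b to [4, 32, 14, 14] on the fly
print(out.shape)                        # torch.Size([4, 32, 14, 14])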
# squeeze removes dimensions of size 1
print()
print(b.shape)              # torch.Size([1, 32, 1, 1])
print(b.squeeze().shape)    # torch.Size([32]) -- with no argument, every size-1 dim is removed
print(b.squeeze(-1).shape)  # torch.Size([1, 32, 1])
print(b.squeeze(0).shape)   # torch.Size([32, 1, 1])
print(b.squeeze(1).shape)   # torch.Size([1, 32, 1, 1]) -- dim 1 has size 32, not 1, so nothing is removed
print(b.squeeze(-4).shape)  # torch.Size([32, 1, 1])
# expand broadcasts the data to a larger shape; it returns a view and only copies when needed, so it is the preferred choice
print()
print(b.shape)                        # torch.Size([1, 32, 1, 1])
print(b.expand(4, 32, 14, 14).shape)  # torch.Size([4, 32, 14, 14]) -- only size-1 dims can be expanded
print(b.expand(-1, 32, -1, -1).shape) # torch.Size([1, 32, 1, 1]) -- -1 means "keep this dimension as it is"
# b.expand(-1, 32, -1, -4)            # -4 is meaningless and should be avoided (older versions report a bogus size -4)
# repeat tiles the data: each argument is how many times to copy along that dim (not a target size), and it allocates new memory
print()
print(b.shape)                       # torch.Size([1, 32, 1, 1])
print(b.repeat(4, 32, 1, 1).shape)   # torch.Size([4, 1024, 1, 1]) -- dim 1 is copied 32 times over its 32 entries
print(b.repeat(4, 1, 1, 1).shape)    # torch.Size([4, 32, 1, 1])
print(b.repeat(4, 1, 32, 32).shape)  # torch.Size([4, 32, 32, 32])
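The practical difference between the two: expand returns a view of the same storage, while repeat materializes a new tensor. A quick illustration of the memory behaviour:

e = b.expand(4, 32, 14, 14)
r = b.repeat(4, 1, 14, 14)
print(e.data_ptr() == b.data_ptr())  # True  -- expand shares b's storage
print(r.data_ptr() == b.data_ptr())  # False -- repeat allocated its own copy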
# transpose swaps two given dimensions
a = torch.rand(4, 3, 32, 32)
print(a.shape)                      # torch.Size([4, 3, 32, 32])
a1 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3)
print(a1.shape)                     # torch.Size([4, 3, 32, 32])
print(torch.all(torch.eq(a, a1)))   # tensor(True) -- unflattening in the transposed order and transposing back recovers a
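The .contiguous() call in the chain above is not optional: after transpose the tensor is no longer laid out contiguously in memory, and view refuses to work on it. A minimal sketch of what goes wrong without it:

try:
    a.transpose(1, 3).view(4, 3*32*32)       # view on a non-contiguous tensor
except RuntimeError as err:
    print('view failed:', err)               # contiguous() (or reshape) is required first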
# permute reorders all dimensions at once, to any target positions
print(a.permute(0, 2, 3, 1).shape)  # torch.Size([4, 32, 32, 3])
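A typical use of permute is converting an image from channels-first [C, H, W] to channels-last [H, W, C], for example before handing it to numpy-based tooling; a short sketch under that assumption:

img = a[0].permute(1, 2, 0)           # [3, 32, 32] -> [32, 32, 3], i.e. H x W x C
img_np = img.contiguous().numpy()     # contiguous() is optional here but gives a standard C-ordered layout
print(img_np.shape)                   # (32, 32, 3)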
Author: 你的雷哥
The copyright of this article belongs to the author. Reposting is welcome, but unless the author agrees otherwise, a link to the original article must be provided on the page; the author reserves the right to pursue legal liability.