当前位置:网站首页>【Numpy和Pytorch的数据处理】
【Numpy和Pytorch的数据处理】
2022-07-06 09:22:00 【喜欢库里的强化小白】
Numpy和Pytorch数据处理
最近在看多智能体强化学习代码,代码内有很多关于数据拼接、维度调整的操作,搞的整个人头晕转向。这里主要是numpy和torch两个工具包,为了读懂代码,查了一些博主的有关介绍,在此进行系统的总结,方便使用时查找,这里只总结代码里使用的语法和一些不太熟悉的语法。
一、Numpy的有关操作
1.np.nonzero()
np.nonzero函数是numpy中用于得到数组array中非零元素的位置(数组索引)的函数。
a = [0,1,1,1,1,0,0,0,0]
a1 = [[0,1,1],[1,0,1],[1,1,0]]
b = np.nonzero(a)
b1 = np.nonzero(a1)
print(b,b[0])
print(b1,b1[0])
结果:
(array([1, 2, 3, 4], dtype=int64),) [1 2 3 4]
(array([0, 0, 1, 1, 2, 2], dtype=int64), array([1, 2, 0, 2, 0, 1], dtype=int64)) [0 0 1 1 2 2]
说明:
(1)一维列表,直接返回对应非零元素的索引,返回类型是元组tuple;b[0]则返回一个ndarray(4,)的数据,即[1,2,3,4].
(2)二维列表,返回两个元组tuple,前边对应行,后边对应列.
行:第一行有两个非零元素,故返回0,0;同理,第二行和第三行也有两个非零元素,返回1,1和2,2.
列:第一行的1、2号索引为非零元素,第二行0、2号元素为非零元素,第三行0、1号元素,故返回array([1,2,0,2,0,1]),b1[0]则返回行的结果
2.np.hstack()和np.vstack()
np.hstack()和np.vstack()是numpy中拼接数组的方法.
(1)np.hstack():水平方向上平铺,变长
(2)np.vstack():垂直方向上堆叠,变多
arr1= np.array([1,2,3])
arr2 = np.array([4,5,6])
arr_h = np.hstack([arr1,arr2])
arr_v = np.vstack([arr1,arr2])
print(arr_h)
print(arr_v)
arr3 = np.array([[1,2],[3,4]])
arr4 = np.array([[5,6],[7,8]])
arr1_h = np.hstack([arr3,arr4])
arr1_v = np.vstack([arr3,arr4])
print(arr1_h)
print(arr1_v)
结果:
[1 2 3 4 5 6]
[[1 2 3]
[4 5 6]]
[[1 2 5 6]
[3 4 7 8]]
[[1 2]
[3 4]
[5 6]
[7 8]]
说明:
(1)np.hstack()将[1,2,3]和[4,5,6]水平平铺成[1,2,3,4,5,6]
(2)np.vstack()将[1,2,3]和[4,5,6]垂直堆叠成[[1,2,3],[4,5,6]]
二维数组:
(1)np.hstack()将每个数组分别在水平方向平铺,[[1,2],[3,4]]和[[5,6],[7,8]]水平平铺成[[1,2,5,6],[3,4,7,8]]
(2)np.vstack()则在垂直方向堆叠,堆叠成[[1,2],[3,4],[5,6],[7,8]]
报错情况分析:
(1)无法水平平铺
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6],[7,8]])
想要水平平铺,第一维度的行数应该一致,这种情况arr3的行数为3,arr4的行数为2,无法进行水平拼接,但可以垂直堆叠,结果为:
[[1,2]
[3,4]
[5,6]
[5,6]
[7,8]]
(2)无法垂直堆叠
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6,7],[6,7,8],[7,8,9]])
想要垂直堆叠,第二维度的列数应该一致,这种情况arr3的列数为2,arr4的列数为3,无法进行垂直堆叠,但可以水平平铺,结果为:
[[1,2,5,6,7]
[3,4,6,7,8]
[5,6,7,8,9]]
3.np.concatenate()
np.concatenate 是numpy中对array进行拼接的函数.
a = np.random.random((2,3))
b = np.random.random((4,3))
c = np.concatenate((a,b),axis=0)
d = np.random.random((2,3))
e = np.random.random((2,5))
f = np.concatenate((d,e),axis=1)
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)
结果:
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799]]
[[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799],
[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.71757919 0.50921365 0.87668144],
[0.69402294 0.25790403 0.2260284 ]]
[[0.27418893 0.68954413 0.63393597 0.18947664 0.11354748],
[0.38312499 0.68509913 0.22171717 0.1272298 0.10387313]]
[[0.71757919 0.50921365 0.87668144 0.27418893 0.68954413 0.63393597, 0.18947664 0.11354748],
[0.69402294 0.25790403 0.2260284 0.38312499 0.68509913 0.22171717, 0.1272298 0.10387313]]
说明:
(1)axis=0,表示按行拼接,要求列的维度一致
(2)axis=1,表示按列拼接,要求行的维度一致
二、Pytorch的有关操作
1.torch.tensor()
torch.tensor()是为了将数据转为会tensor数据类型,对numpy和list都可以进行转换,方便pytorch进行后续的训练。
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(list1)
tensor2 = torch.tensor(arr1)
结果:
tensor1:tensor([1, 2, 3, 4, 5, 6])
tensor2:tensor([1, 2, 3, 4, 5, 6])
2.torch.unsqueeze()和torch.squeeze()
torch.unsqueeze()和torch.squeeze()是进行调整尺寸的函数.
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(arr1)
tensor2 = tensor1.unsqueeze(0)
tensor3 = tensor2.squeeze(0)
print(tensor2)
print(tensor3)
结果:
tensor([[1, 2, 3, 4, 5, 6]], dtype=torch.int32)
tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)
说明:
torch.unsqueeze(0)表示在既定位置插入维度1,扩展维度,将维度变为torch.size([1,6])
torch.squeeze(0)表示在既定位置中的维度除去,缩减维度.
3.torch.sum()
torch.tensor(a,dim)为求和元素,a为操作对象,dim为求和的维度.
a = torch.ones((2,3))
a1 = torch.sum(a,dim=0)
a2 = torch.sum(a,dim=0,keepdim=True)
b1 = torch.sum(a,dim=1)
b2 = torch.sum(a,dim=1,keepdim=True)
c1 = torch.sum(a)
print(a)
print(a1)
print(b1)
print(c1)
print(a2)
print(b2)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor(6.)
tensor([[2., 2., 2.]])
tensor([[3.],
[3.]])
说明:
(1)如果不加dim,即表示所有数求和,为6
(2)dim = 0,表示按行求和,[2,2,2],dim=0这个维度被去掉
(3)dim = 1,表示按列求和,[3,3],dim=1这个维度被去掉
(4)keepdim = True,表示dim的维度不会被squeeze
例如:a2数据变为了torch.size([1,3])
例如:b2数据变为了torch.size([2,1])
4.torch.repeat()
torch.repeat()用于对张量进行重复扩充.
(1)当参数只有两个时:(行的重复倍数,列的重复倍数),1表示不重复
(2)当参数有三个时:(通道数的重复倍数,行的重复倍数,列重复倍数)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = a.repeat(2,2)
c = a.repeat(2,2,2)
print(b)
print(c)
结果:
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]])
tensor([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]],
[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]]])
5.torch.ones_like()和torch.zeros_like()
torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1或全0张量,这里以torch.ones_like为例.
a = torch.ones((2,3))
print(a)
b = torch.ones_like(a)
print(b)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
说明:torch.ones()如果只传一个参数,生成的张量不是方阵.
6.torch.reshape()
变换张量tensor的形状,注意两个数据类型都是张量.
a = torch.rand(12)
b= torch.reshape(a,(3,4))
c= torch.reshape(a,(-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([0.4659, 0.8868, 0.9371, 0.0049, 0.6908, 0.2511, 0.2538, 0.2651, 0.0320, 0.0593, 0.6062, 0.4983])
tensor([[0.4659, 0.8868, 0.9371, 0.0049],
[0.6908, 0.2511, 0.2538, 0.2651],
[0.0320, 0.0593, 0.6062, 0.4983]])
tensor([[[0.4659, 0.8868, 0.9371],
[0.0049, 0.6908, 0.2511]],
[[0.2538, 0.2651, 0.0320],
[0.0593, 0.6062, 0.4983]]])
说明:
(1)torch.reshape()可以用于处理同样维度的数据,比如把torch.size([3,4])处理为torch.size([2,6])
(2)torch.reshape()可以用于处理不同维度的数据,比如可以把二维变成三维,-1代表根据二者的关系自动匹配该位置的维度.
7.view()
变换张量tensor的形状,类似于torch.reshape(),但没有torch.view()这一函数,用法为在变量后加.view(),个人感觉没有reshape好用
a = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12])
b= a.view((2,6))
c= a.view((-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
8.expand()
函数返回张量在某一个维度扩展之后的张量,用法为在变量后加.expand(),个人感觉没有repeat好用
如果为.expand(-1,-1)则表示当前维度不进行扩充
a = torch.tensor([1,2,3,4])
b= a.expand((2,4))
print(b)
c = torch.tensor([[1],[2]])
d = c.expand((2,4))
print(d)
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
9.torch.cat()
torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接.
a = torch.ones((2,3))
b = 2*torch.ones((4,3))
c = 3*torch.ones((2,5))
d = torch.cat((a,b),dim=0)
e = torch.cat((a,c),dim=1)
print(a)
print(b)
print(c)
print(d)
print(e)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.]])
tensor([[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[1., 1., 1., 3., 3., 3., 3., 3.],
[1., 1., 1., 3., 3., 3., 3., 3.]])
说明:
(1)dim=0表示按行拼接,要保证列数的维度一致
(2)dim=1表示按列拼接,要保证行数的维度一致
暂且用到这些函数,其他的还没有遇到,先总结这么多,有一些写错的地方,会在后续的学习中进行更正;后续如果有其他用到的函数,也会进行更新。
边栏推荐
- C language Getting Started Guide
- Nuxtjs quick start (nuxt2)
- 8. C language - bit operator and displacement operator
- 【毕业季·进击的技术er】再见了,我的学生时代
- [during the interview] - how can I explain the mechanism of TCP to achieve reliable transmission
- JS several ways to judge whether an object is an array
- 7-14 错误票据(PTA程序设计)
- Simply understand the promise of ES6
- 实验九 输入输出流(节选)
- Zatan 0516
猜你喜欢
[hand tearing code] single case mode and producer / consumer mode
C语言入门指南
2. Preliminary exercises of C language (2)
[au cours de l'entrevue] - Comment expliquer le mécanisme de transmission fiable de TCP
Thoroughly understand LRU algorithm - explain 146 questions in detail and eliminate LRU cache in redis
QT meta object qmetaobject indexofslot and other functions to obtain class methods attention
Difference and understanding between detected and non detected anomalies
仿牛客技术博客项目常见问题及解答(二)
A piece of music composed by buzzer (Chengdu)
PriorityQueue (large root heap / small root heap /topk problem)
随机推荐
编写程序,模拟现实生活中的交通信号灯。
【数据库 三大范式】一看就懂
Poker game program - man machine confrontation
Redis的两种持久化机制RDB和AOF的原理和优缺点
Miscellaneous talk on May 14
7-14 错误票据(PTA程序设计)
一段用蜂鸣器编的音乐(成都)
5. Download and use of MSDN
Analysis of penetration test learning and actual combat stage
Miscellaneous talk on May 27
2. First knowledge of C language (2)
Detailed explanation of redis' distributed lock principle
5. Function recursion exercise
【VMware异常问题】问题分析&解决办法
3. C language uses algebraic cofactor to calculate determinant
[the Nine Yang Manual] 2018 Fudan University Applied Statistics real problem + analysis
ABA问题遇到过吗,详细说以下,如何避免ABA问题
Mortal immortal cultivation pointer-1
(original) make an electronic clock with LCD1602 display to display the current time on the LCD. The display format is "hour: minute: Second: second". There are four function keys K1 ~ K4, and the fun
7-3 构造散列表(PTA程序设计)