当前位置:网站首页>【Numpy和Pytorch的数据处理】
【Numpy和Pytorch的数据处理】
2022-07-06 09:22:00 【喜欢库里的强化小白】
Numpy和Pytorch数据处理
最近在看多智能体强化学习代码,代码内有很多关于数据拼接、维度调整的操作,搞的整个人头晕转向。这里主要是numpy和torch两个工具包,为了读懂代码,查了一些博主的有关介绍,在此进行系统的总结,方便使用时查找,这里只总结代码里使用的语法和一些不太熟悉的语法。
一、Numpy的有关操作
1.np.nonzero()
np.nonzero函数是numpy中用于得到数组array中非零元素的位置(数组索引)的函数。
a = [0,1,1,1,1,0,0,0,0]
a1 = [[0,1,1],[1,0,1],[1,1,0]]
b = np.nonzero(a)
b1 = np.nonzero(a1)
print(b,b[0])
print(b1,b1[0])
结果:
(array([1, 2, 3, 4], dtype=int64),) [1 2 3 4]
(array([0, 0, 1, 1, 2, 2], dtype=int64), array([1, 2, 0, 2, 0, 1], dtype=int64)) [0 0 1 1 2 2]
说明:
(1)一维列表,直接返回对应非零元素的索引,返回类型是元组tuple;b[0]则返回一个ndarray(4,)的数据,即[1,2,3,4].
(2)二维列表,返回两个元组tuple,前边对应行,后边对应列.
行:第一行有两个非零元素,故返回0,0;同理,第二行和第三行也有两个非零元素,返回1,1和2,2.
列:第一行的1、2号索引为非零元素,第二行0、2号元素为非零元素,第三行0、1号元素,故返回array([1,2,0,2,0,1]),b1[0]则返回行的结果
2.np.hstack()和np.vstack()
np.hstack()和np.vstack()是numpy中拼接数组的方法.
(1)np.hstack():水平方向上平铺,变长
(2)np.vstack():垂直方向上堆叠,变多
arr1= np.array([1,2,3])
arr2 = np.array([4,5,6])
arr_h = np.hstack([arr1,arr2])
arr_v = np.vstack([arr1,arr2])
print(arr_h)
print(arr_v)
arr3 = np.array([[1,2],[3,4]])
arr4 = np.array([[5,6],[7,8]])
arr1_h = np.hstack([arr3,arr4])
arr1_v = np.vstack([arr3,arr4])
print(arr1_h)
print(arr1_v)
结果:
[1 2 3 4 5 6]
[[1 2 3]
[4 5 6]]
[[1 2 5 6]
[3 4 7 8]]
[[1 2]
[3 4]
[5 6]
[7 8]]
说明:
(1)np.hstack()将[1,2,3]和[4,5,6]水平平铺成[1,2,3,4,5,6]
(2)np.vstack()将[1,2,3]和[4,5,6]垂直堆叠成[[1,2,3],[4,5,6]]
二维数组:
(1)np.hstack()将每个数组分别在水平方向平铺,[[1,2],[3,4]]和[[5,6],[7,8]]水平平铺成[[1,2,5,6],[3,4,7,8]]
(2)np.vstack()则在垂直方向堆叠,堆叠成[[1,2],[3,4],[5,6],[7,8]]
报错情况分析:
(1)无法水平平铺
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6],[7,8]])
想要水平平铺,第一维度的行数应该一致,这种情况arr3的行数为3,arr4的行数为2,无法进行水平拼接,但可以垂直堆叠,结果为:
[[1,2]
[3,4]
[5,6]
[5,6]
[7,8]]
(2)无法垂直堆叠
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6,7],[6,7,8],[7,8,9]])
想要垂直堆叠,第二维度的列数应该一致,这种情况arr3的列数为2,arr4的列数为3,无法进行垂直堆叠,但可以水平平铺,结果为:
[[1,2,5,6,7]
[3,4,6,7,8]
[5,6,7,8,9]]
3.np.concatenate()
np.concatenate 是numpy中对array进行拼接的函数.
a = np.random.random((2,3))
b = np.random.random((4,3))
c = np.concatenate((a,b),axis=0)
d = np.random.random((2,3))
e = np.random.random((2,5))
f = np.concatenate((d,e),axis=1)
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)
结果:
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799]]
[[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799],
[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.71757919 0.50921365 0.87668144],
[0.69402294 0.25790403 0.2260284 ]]
[[0.27418893 0.68954413 0.63393597 0.18947664 0.11354748],
[0.38312499 0.68509913 0.22171717 0.1272298 0.10387313]]
[[0.71757919 0.50921365 0.87668144 0.27418893 0.68954413 0.63393597, 0.18947664 0.11354748],
[0.69402294 0.25790403 0.2260284 0.38312499 0.68509913 0.22171717, 0.1272298 0.10387313]]
说明:
(1)axis=0,表示按行拼接,要求列的维度一致
(2)axis=1,表示按列拼接,要求行的维度一致
二、Pytorch的有关操作
1.torch.tensor()
torch.tensor()是为了将数据转为会tensor数据类型,对numpy和list都可以进行转换,方便pytorch进行后续的训练。
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(list1)
tensor2 = torch.tensor(arr1)
结果:
tensor1:tensor([1, 2, 3, 4, 5, 6])
tensor2:tensor([1, 2, 3, 4, 5, 6])
2.torch.unsqueeze()和torch.squeeze()
torch.unsqueeze()和torch.squeeze()是进行调整尺寸的函数.
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(arr1)
tensor2 = tensor1.unsqueeze(0)
tensor3 = tensor2.squeeze(0)
print(tensor2)
print(tensor3)
结果:
tensor([[1, 2, 3, 4, 5, 6]], dtype=torch.int32)
tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)
说明:
torch.unsqueeze(0)表示在既定位置插入维度1,扩展维度,将维度变为torch.size([1,6])
torch.squeeze(0)表示在既定位置中的维度除去,缩减维度.
3.torch.sum()
torch.tensor(a,dim)为求和元素,a为操作对象,dim为求和的维度.
a = torch.ones((2,3))
a1 = torch.sum(a,dim=0)
a2 = torch.sum(a,dim=0,keepdim=True)
b1 = torch.sum(a,dim=1)
b2 = torch.sum(a,dim=1,keepdim=True)
c1 = torch.sum(a)
print(a)
print(a1)
print(b1)
print(c1)
print(a2)
print(b2)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor(6.)
tensor([[2., 2., 2.]])
tensor([[3.],
[3.]])
说明:
(1)如果不加dim,即表示所有数求和,为6
(2)dim = 0,表示按行求和,[2,2,2],dim=0这个维度被去掉
(3)dim = 1,表示按列求和,[3,3],dim=1这个维度被去掉
(4)keepdim = True,表示dim的维度不会被squeeze
例如:a2数据变为了torch.size([1,3])
例如:b2数据变为了torch.size([2,1])
4.torch.repeat()
torch.repeat()用于对张量进行重复扩充.
(1)当参数只有两个时:(行的重复倍数,列的重复倍数),1表示不重复
(2)当参数有三个时:(通道数的重复倍数,行的重复倍数,列重复倍数)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = a.repeat(2,2)
c = a.repeat(2,2,2)
print(b)
print(c)
结果:
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]])
tensor([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]],
[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]]])
5.torch.ones_like()和torch.zeros_like()
torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1或全0张量,这里以torch.ones_like为例.
a = torch.ones((2,3))
print(a)
b = torch.ones_like(a)
print(b)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
说明:torch.ones()如果只传一个参数,生成的张量不是方阵.
6.torch.reshape()
变换张量tensor的形状,注意两个数据类型都是张量.
a = torch.rand(12)
b= torch.reshape(a,(3,4))
c= torch.reshape(a,(-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([0.4659, 0.8868, 0.9371, 0.0049, 0.6908, 0.2511, 0.2538, 0.2651, 0.0320, 0.0593, 0.6062, 0.4983])
tensor([[0.4659, 0.8868, 0.9371, 0.0049],
[0.6908, 0.2511, 0.2538, 0.2651],
[0.0320, 0.0593, 0.6062, 0.4983]])
tensor([[[0.4659, 0.8868, 0.9371],
[0.0049, 0.6908, 0.2511]],
[[0.2538, 0.2651, 0.0320],
[0.0593, 0.6062, 0.4983]]])
说明:
(1)torch.reshape()可以用于处理同样维度的数据,比如把torch.size([3,4])处理为torch.size([2,6])
(2)torch.reshape()可以用于处理不同维度的数据,比如可以把二维变成三维,-1代表根据二者的关系自动匹配该位置的维度.
7.view()
变换张量tensor的形状,类似于torch.reshape(),但没有torch.view()这一函数,用法为在变量后加.view(),个人感觉没有reshape好用
a = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12])
b= a.view((2,6))
c= a.view((-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
8.expand()
函数返回张量在某一个维度扩展之后的张量,用法为在变量后加.expand(),个人感觉没有repeat好用
如果为.expand(-1,-1)则表示当前维度不进行扩充
a = torch.tensor([1,2,3,4])
b= a.expand((2,4))
print(b)
c = torch.tensor([[1],[2]])
d = c.expand((2,4))
print(d)
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
9.torch.cat()
torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接.
a = torch.ones((2,3))
b = 2*torch.ones((4,3))
c = 3*torch.ones((2,5))
d = torch.cat((a,b),dim=0)
e = torch.cat((a,c),dim=1)
print(a)
print(b)
print(c)
print(d)
print(e)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.]])
tensor([[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[1., 1., 1., 3., 3., 3., 3., 3.],
[1., 1., 1., 3., 3., 3., 3., 3.]])
说明:
(1)dim=0表示按行拼接,要保证列数的维度一致
(2)dim=1表示按列拼接,要保证行数的维度一致
暂且用到这些函数,其他的还没有遇到,先总结这么多,有一些写错的地方,会在后续的学习中进行更正;后续如果有其他用到的函数,也会进行更新。
边栏推荐
猜你喜欢
Matlab opens M file garbled solution
ABA问题遇到过吗,详细说以下,如何避免ABA问题
关于双亲委派机制和类加载的过程
强化学习系列(一):基本原理和概念
1. First knowledge of C language (1)
(original) make an electronic clock with LCD1602 display to display the current time on the LCD. The display format is "hour: minute: Second: second". There are four function keys K1 ~ K4, and the fun
7-7 7003 组合锁(PTA程序设计)
A piece of music composed by buzzer (Chengdu)
实验六 继承和多态
Mode 1 two-way serial communication is adopted between machine a and machine B, and the specific requirements are as follows: (1) the K1 key of machine a can control the ledi of machine B to turn on a
随机推荐
5. Download and use of MSDN
Poker game program - man machine confrontation
Mode 1 two-way serial communication is adopted between machine a and machine B, and the specific requirements are as follows: (1) the K1 key of machine a can control the ledi of machine B to turn on a
Beautified table style
Wei Pai: the product is applauded, but why is the sales volume still frustrated
渗透测试学习与实战阶段分析
[中国近代史] 第五章测验
7-14 错误票据(PTA程序设计)
为什么要使用Redis
A piece of music composed by buzzer (Chengdu)
The difference between cookies and sessions
5月27日杂谈
撲克牌遊戲程序——人機對抗
The latest tank battle 2022 - Notes on the whole development -2
Detailed explanation of redis' distributed lock principle
仿牛客技术博客项目常见问题及解答(二)
Leetcode. 3. Longest substring without repeated characters - more than 100% solution
canvas基础2 - arc - 画弧线
受检异常和非受检异常的区别和理解
7-6 矩阵的局部极小值(PTA程序设计)