当前位置:网站首页>【Numpy和Pytorch的数据处理】
【Numpy和Pytorch的数据处理】
2022-07-06 09:22:00 【喜欢库里的强化小白】
Numpy和Pytorch数据处理
最近在看多智能体强化学习代码,代码内有很多关于数据拼接、维度调整的操作,搞的整个人头晕转向。这里主要是numpy和torch两个工具包,为了读懂代码,查了一些博主的有关介绍,在此进行系统的总结,方便使用时查找,这里只总结代码里使用的语法和一些不太熟悉的语法。
一、Numpy的有关操作
1.np.nonzero()
np.nonzero函数是numpy中用于得到数组array中非零元素的位置(数组索引)的函数。
a = [0,1,1,1,1,0,0,0,0]
a1 = [[0,1,1],[1,0,1],[1,1,0]]
b = np.nonzero(a)
b1 = np.nonzero(a1)
print(b,b[0])
print(b1,b1[0])
结果:
(array([1, 2, 3, 4], dtype=int64),) [1 2 3 4]
(array([0, 0, 1, 1, 2, 2], dtype=int64), array([1, 2, 0, 2, 0, 1], dtype=int64)) [0 0 1 1 2 2]
说明:
(1)一维列表,直接返回对应非零元素的索引,返回类型是元组tuple;b[0]则返回一个ndarray(4,)的数据,即[1,2,3,4].
(2)二维列表,返回两个元组tuple,前边对应行,后边对应列.
行:第一行有两个非零元素,故返回0,0;同理,第二行和第三行也有两个非零元素,返回1,1和2,2.
列:第一行的1、2号索引为非零元素,第二行0、2号元素为非零元素,第三行0、1号元素,故返回array([1,2,0,2,0,1]),b1[0]则返回行的结果
2.np.hstack()和np.vstack()
np.hstack()和np.vstack()是numpy中拼接数组的方法.
(1)np.hstack():水平方向上平铺,变长
(2)np.vstack():垂直方向上堆叠,变多
arr1= np.array([1,2,3])
arr2 = np.array([4,5,6])
arr_h = np.hstack([arr1,arr2])
arr_v = np.vstack([arr1,arr2])
print(arr_h)
print(arr_v)
arr3 = np.array([[1,2],[3,4]])
arr4 = np.array([[5,6],[7,8]])
arr1_h = np.hstack([arr3,arr4])
arr1_v = np.vstack([arr3,arr4])
print(arr1_h)
print(arr1_v)
结果:
[1 2 3 4 5 6]
[[1 2 3]
[4 5 6]]
[[1 2 5 6]
[3 4 7 8]]
[[1 2]
[3 4]
[5 6]
[7 8]]
说明:
(1)np.hstack()将[1,2,3]和[4,5,6]水平平铺成[1,2,3,4,5,6]
(2)np.vstack()将[1,2,3]和[4,5,6]垂直堆叠成[[1,2,3],[4,5,6]]
二维数组:
(1)np.hstack()将每个数组分别在水平方向平铺,[[1,2],[3,4]]和[[5,6],[7,8]]水平平铺成[[1,2,5,6],[3,4,7,8]]
(2)np.vstack()则在垂直方向堆叠,堆叠成[[1,2],[3,4],[5,6],[7,8]]
报错情况分析:
(1)无法水平平铺
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6],[7,8]])
想要水平平铺,第一维度的行数应该一致,这种情况arr3的行数为3,arr4的行数为2,无法进行水平拼接,但可以垂直堆叠,结果为:
[[1,2]
[3,4]
[5,6]
[5,6]
[7,8]]
(2)无法垂直堆叠
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6,7],[6,7,8],[7,8,9]])
想要垂直堆叠,第二维度的列数应该一致,这种情况arr3的列数为2,arr4的列数为3,无法进行垂直堆叠,但可以水平平铺,结果为:
[[1,2,5,6,7]
[3,4,6,7,8]
[5,6,7,8,9]]
3.np.concatenate()
np.concatenate 是numpy中对array进行拼接的函数.
a = np.random.random((2,3))
b = np.random.random((4,3))
c = np.concatenate((a,b),axis=0)
d = np.random.random((2,3))
e = np.random.random((2,5))
f = np.concatenate((d,e),axis=1)
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)
结果:
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799]]
[[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.35054413 0.40166885 0.81047908],
[0.05835757 0.41320597 0.56587799],
[0.08120434 0.64616073 0.2518427 ],
[0.5507864 0.59255662 0.05841272],
[0.55273764 0.85560033 0.56225599],
[0.20997285 0.33741127 0.01441785]]
[[0.71757919 0.50921365 0.87668144],
[0.69402294 0.25790403 0.2260284 ]]
[[0.27418893 0.68954413 0.63393597 0.18947664 0.11354748],
[0.38312499 0.68509913 0.22171717 0.1272298 0.10387313]]
[[0.71757919 0.50921365 0.87668144 0.27418893 0.68954413 0.63393597, 0.18947664 0.11354748],
[0.69402294 0.25790403 0.2260284 0.38312499 0.68509913 0.22171717, 0.1272298 0.10387313]]
说明:
(1)axis=0,表示按行拼接,要求列的维度一致
(2)axis=1,表示按列拼接,要求行的维度一致
二、Pytorch的有关操作
1.torch.tensor()
torch.tensor()是为了将数据转为会tensor数据类型,对numpy和list都可以进行转换,方便pytorch进行后续的训练。
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(list1)
tensor2 = torch.tensor(arr1)
结果:
tensor1:tensor([1, 2, 3, 4, 5, 6])
tensor2:tensor([1, 2, 3, 4, 5, 6])
2.torch.unsqueeze()和torch.squeeze()
torch.unsqueeze()和torch.squeeze()是进行调整尺寸的函数.
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(arr1)
tensor2 = tensor1.unsqueeze(0)
tensor3 = tensor2.squeeze(0)
print(tensor2)
print(tensor3)
结果:
tensor([[1, 2, 3, 4, 5, 6]], dtype=torch.int32)
tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)
说明:
torch.unsqueeze(0)表示在既定位置插入维度1,扩展维度,将维度变为torch.size([1,6])
torch.squeeze(0)表示在既定位置中的维度除去,缩减维度.
3.torch.sum()
torch.tensor(a,dim)为求和元素,a为操作对象,dim为求和的维度.
a = torch.ones((2,3))
a1 = torch.sum(a,dim=0)
a2 = torch.sum(a,dim=0,keepdim=True)
b1 = torch.sum(a,dim=1)
b2 = torch.sum(a,dim=1,keepdim=True)
c1 = torch.sum(a)
print(a)
print(a1)
print(b1)
print(c1)
print(a2)
print(b2)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor(6.)
tensor([[2., 2., 2.]])
tensor([[3.],
[3.]])
说明:
(1)如果不加dim,即表示所有数求和,为6
(2)dim = 0,表示按行求和,[2,2,2],dim=0这个维度被去掉
(3)dim = 1,表示按列求和,[3,3],dim=1这个维度被去掉
(4)keepdim = True,表示dim的维度不会被squeeze
例如:a2数据变为了torch.size([1,3])
例如:b2数据变为了torch.size([2,1])
4.torch.repeat()
torch.repeat()用于对张量进行重复扩充.
(1)当参数只有两个时:(行的重复倍数,列的重复倍数),1表示不重复
(2)当参数有三个时:(通道数的重复倍数,行的重复倍数,列重复倍数)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = a.repeat(2,2)
c = a.repeat(2,2,2)
print(b)
print(c)
结果:
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]])
tensor([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]],
[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]]])
5.torch.ones_like()和torch.zeros_like()
torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1或全0张量,这里以torch.ones_like为例.
a = torch.ones((2,3))
print(a)
b = torch.ones_like(a)
print(b)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
说明:torch.ones()如果只传一个参数,生成的张量不是方阵.
6.torch.reshape()
变换张量tensor的形状,注意两个数据类型都是张量.
a = torch.rand(12)
b= torch.reshape(a,(3,4))
c= torch.reshape(a,(-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([0.4659, 0.8868, 0.9371, 0.0049, 0.6908, 0.2511, 0.2538, 0.2651, 0.0320, 0.0593, 0.6062, 0.4983])
tensor([[0.4659, 0.8868, 0.9371, 0.0049],
[0.6908, 0.2511, 0.2538, 0.2651],
[0.0320, 0.0593, 0.6062, 0.4983]])
tensor([[[0.4659, 0.8868, 0.9371],
[0.0049, 0.6908, 0.2511]],
[[0.2538, 0.2651, 0.0320],
[0.0593, 0.6062, 0.4983]]])
说明:
(1)torch.reshape()可以用于处理同样维度的数据,比如把torch.size([3,4])处理为torch.size([2,6])
(2)torch.reshape()可以用于处理不同维度的数据,比如可以把二维变成三维,-1代表根据二者的关系自动匹配该位置的维度.
7.view()
变换张量tensor的形状,类似于torch.reshape(),但没有torch.view()这一函数,用法为在变量后加.view(),个人感觉没有reshape好用
a = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12])
b= a.view((2,6))
c= a.view((-1,2,3))
print(a)
print(b)
print(c)
结果:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
8.expand()
函数返回张量在某一个维度扩展之后的张量,用法为在变量后加.expand(),个人感觉没有repeat好用
如果为.expand(-1,-1)则表示当前维度不进行扩充
a = torch.tensor([1,2,3,4])
b= a.expand((2,4))
print(b)
c = torch.tensor([[1],[2]])
d = c.expand((2,4))
print(d)
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
9.torch.cat()
torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接.
a = torch.ones((2,3))
b = 2*torch.ones((4,3))
c = 3*torch.ones((2,5))
d = torch.cat((a,b),dim=0)
e = torch.cat((a,c),dim=1)
print(a)
print(b)
print(c)
print(d)
print(e)
结果:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.]])
tensor([[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[1., 1., 1., 3., 3., 3., 3., 3.],
[1., 1., 1., 3., 3., 3., 3., 3.]])
说明:
(1)dim=0表示按行拼接,要保证列数的维度一致
(2)dim=1表示按列拼接,要保证行数的维度一致
暂且用到这些函数,其他的还没有遇到,先总结这么多,有一些写错的地方,会在后续的学习中进行更正;后续如果有其他用到的函数,也会进行更新。
边栏推荐
- Why use redis
- Zatan 0516
- [中国近代史] 第六章测验
- 2022泰迪杯数据挖掘挑战赛C题思路及赛后总结
- SRC挖掘思路及方法
- Service ability of Hongmeng harmonyos learning notes to realize cross end communication
- 使用Spacedesk实现局域网内任意设备作为电脑拓展屏
- Record a penetration of the cat shed from outside to inside. Library operation extraction flag
- 编写程序,模拟现实生活中的交通信号灯。
- 深度强化文献阅读系列(一):Courier routing and assignment for food delivery service using reinforcement learning
猜你喜欢
FAQs and answers to the imitation Niuke technology blog project (III)
2. Preliminary exercises of C language (2)
It's never too late to start. The tramp transformation programmer has an annual salary of more than 700000 yuan
4. Branch statements and loop statements
.Xmind文件如何上传金山文档共享在线编辑?
FAQs and answers to the imitation Niuke technology blog project (II)
3. C language uses algebraic cofactor to calculate determinant
Difference and understanding between detected and non detected anomalies
MySQL事务及实现原理全面总结,再也不用担心面试
一段用蜂鸣器编的音乐(成都)
随机推荐
Detailed explanation of redis' distributed lock principle
[面试时]——我如何讲清楚TCP实现可靠传输的机制
5. Function recursion exercise
Why use redis
canvas基础1 - 画直线(通俗易懂)
C语言入门指南
QT meta object qmetaobject indexofslot and other functions to obtain class methods attention
Implementation of count (*) in MySQL
7-1 输出2到n之间的全部素数(PTA程序设计)
使用Spacedesk实现局域网内任意设备作为电脑拓展屏
【九阳神功】2020复旦大学应用统计真题+解析
实验五 类和对象
[modern Chinese history] Chapter 9 test
7-11 机工士姆斯塔迪奥(PTA程序设计)
7-9 制作门牌号3.0(PTA程序设计)
4. Binary search
简单理解ES6的Promise
hashCode()与equals()之间的关系
7-6 矩阵的局部极小值(PTA程序设计)
【九阳神功】2017复旦大学应用统计真题+解析