NumPy and PyTorch Data Processing
I've recently been reading multi-agent reinforcement learning code, which is full of data concatenation and dimension juggling that made my head spin. Most of it comes down to two packages, NumPy and PyTorch. To understand the code I went through several blog posts on the topic, and this is a systematic summary for easy reference later: the syntax used in that code, plus some constructs I was unfamiliar with.
I. NumPy operations
1. np.nonzero()
np.nonzero() is the NumPy function that returns the positions (array indices) of the non-zero elements of an array.
import numpy as np

a = [0,1,1,1,1,0,0,0,0]
a1 = [[0,1,1],[1,0,1],[1,1,0]]
b = np.nonzero(a)
b1 = np.nonzero(a1)
print(b,b[0])
print(b1,b1[0])
Result:
(array([1, 2, 3, 4], dtype=int64),) [1 2 3 4]
(array([0, 0, 1, 1, 2, 2], dtype=int64), array([1, 2, 0, 2, 0, 1], dtype=int64)) [0 0 1 1 2 2]
Explanation:
(1) For a one-dimensional list, it returns the indices of the non-zero elements directly; the return type is a tuple. b[0] is an ndarray of shape (4,), namely [1, 2, 3, 4].
(2) For a two-dimensional list, it returns a tuple of two arrays: the first holds the row indices, the second the column indices.
Rows: the first row has two non-zero elements, so it contributes 0, 0; likewise the second and third rows each have two non-zero elements, contributing 1, 1 and 2, 2.
Columns: in the first row the elements at indices 1 and 2 are non-zero, in the second row indices 0 and 2, and in the third row indices 0 and 1, so the column array is array([1, 2, 0, 2, 0, 1]); b1[0] returns the row result. A common use of these index arrays is sketched below.
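A minimal sketch (my own addition, not part of the original code) of two common uses of the indices returned by np.nonzero(): fancy indexing to extract the non-zero values, and np.transpose() to pair the indices up as (row, column) coordinates.
import numpy as np

a1 = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
rows, cols = np.nonzero(a1)
print(a1[rows, cols])                 # [1 1 1 1 1 1] -- the non-zero values
print(np.transpose(np.nonzero(a1)))   # one (row, col) pair per line:
# [[0 1]
#  [0 2]
#  [1 0]
#  [1 2]
#  [2 0]
#  [2 1]]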
2. np.hstack() and np.vstack()
np.hstack() and np.vstack() are the NumPy methods for stacking arrays together.
(1) np.hstack(): stacks horizontally, making the rows longer
(2) np.vstack(): stacks vertically, adding more rows
arr1= np.array([1,2,3])
arr2 = np.array([4,5,6])
arr_h = np.hstack([arr1,arr2])
arr_v = np.vstack([arr1,arr2])
print(arr_h)
print(arr_v)
arr3 = np.array([[1,2],[3,4]])
arr4 = np.array([[5,6],[7,8]])
arr1_h = np.hstack([arr3,arr4])
arr1_v = np.vstack([arr3,arr4])
print(arr1_h)
print(arr1_v)
Result:
[1 2 3 4 5 6]
[[1 2 3]
[4 5 6]]
[[1 2 5 6]
[3 4 7 8]]
[[1 2]
[3 4]
[5 6]
[7 8]]
Explanation:
One-dimensional arrays:
(1) np.hstack() lays [1,2,3] and [4,5,6] out horizontally as [1,2,3,4,5,6]
(2) np.vstack() stacks [1,2,3] and [4,5,6] vertically into [[1,2,3],[4,5,6]]
Two-dimensional arrays:
(1) np.hstack() lays the arrays out horizontally: [[1,2],[3,4]] and [[5,6],[7,8]] become [[1,2,5,6],[3,4,7,8]]
(2) np.vstack() stacks them vertically into [[1,2],[3,4],[5,6],[7,8]]
Error analysis:
(1) Cannot stack horizontally
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6],[7,8]])
To stack horizontally, the number of rows (the first dimension) must match. Here arr3 has 3 rows and arr4 has 2, so horizontal concatenation is impossible, but vertical stacking works, giving:
[[1,2]
[3,4]
[5,6]
[5,6]
[7,8]]
(2) Cannot stack vertically
arr3 = np.array([[1,2],[3,4],[5,6]])
arr4 = np.array([[5,6,7],[6,7,8],[7,8,9]])
To stack vertically, the number of columns (the second dimension) must match. Here arr3 has 2 columns and arr4 has 3, so vertical stacking is impossible, but horizontal concatenation works (a runnable check of the failure mode is sketched after the result), giving:
[[1,2,5,6,7]
[3,4,6,7,8]
[5,6,7,8,9]]
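A minimal runnable sketch (my own addition) of the first failure mode and its working alternative; the shapes are the ones from the analysis above.
import numpy as np

arr3 = np.array([[1, 2], [3, 4], [5, 6]])   # shape (3, 2)
arr4 = np.array([[5, 6], [7, 8]])           # shape (2, 2)
try:
    np.hstack([arr3, arr4])                 # row counts differ (3 vs 2)
except ValueError as e:
    print(e)                                # numpy reports the dimension mismatch
print(np.vstack([arr3, arr4]))              # column counts match, so this works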
3. np.concatenate()
np.concatenate is the NumPy function for concatenating arrays along a chosen axis.
a = np.random.random((2,3))
b = np.random.random((4,3))
c = np.concatenate((a,b),axis=0)
d = np.random.random((2,3))
e = np.random.random((2,5))
f = np.concatenate((d,e),axis=1)
print(a)
print(b)
print(c)
print(d)
print(e)
print(f)
Result:
[[0.35054413 0.40166885 0.81047908]
 [0.05835757 0.41320597 0.56587799]]
[[0.08120434 0.64616073 0.2518427 ]
 [0.5507864  0.59255662 0.05841272]
 [0.55273764 0.85560033 0.56225599]
 [0.20997285 0.33741127 0.01441785]]
[[0.35054413 0.40166885 0.81047908]
 [0.05835757 0.41320597 0.56587799]
 [0.08120434 0.64616073 0.2518427 ]
 [0.5507864  0.59255662 0.05841272]
 [0.55273764 0.85560033 0.56225599]
 [0.20997285 0.33741127 0.01441785]]
[[0.71757919 0.50921365 0.87668144]
 [0.69402294 0.25790403 0.2260284 ]]
[[0.27418893 0.68954413 0.63393597 0.18947664 0.11354748]
 [0.38312499 0.68509913 0.22171717 0.1272298  0.10387313]]
[[0.71757919 0.50921365 0.87668144 0.27418893 0.68954413 0.63393597 0.18947664 0.11354748]
 [0.69402294 0.25790403 0.2260284  0.38312499 0.68509913 0.22171717 0.1272298  0.10387313]]
Explanation:
(1) axis=0 concatenates along rows (vertically); the column counts must match
(2) axis=1 concatenates along columns (horizontally); the row counts must match
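For 2-D arrays, np.hstack() and np.vstack() from the previous section are just shorthands for np.concatenate() with axis=1 and axis=0; a quick check (my own sketch, not from the original):
import numpy as np

a = np.random.random((2, 3))
b = np.random.random((4, 3))
c = np.random.random((2, 5))
# For 2-D inputs these pairs produce identical results.
assert np.array_equal(np.vstack([a, b]), np.concatenate((a, b), axis=0))
assert np.array_equal(np.hstack([a, c]), np.concatenate((a, c), axis=1))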
II. PyTorch operations
1. torch.tensor()
torch.tensor() turns data into the tensor data type; both NumPy arrays and Python lists can be converted, which is convenient for subsequent PyTorch training.
import torch

list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(list1)
tensor2 = torch.tensor(arr1)
Result:
tensor1:tensor([1, 2, 3, 4, 5, 6])
tensor2:tensor([1, 2, 3, 4, 5, 6])
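One detail worth knowing (my own addition, easy to verify): torch.tensor() always copies the data, while torch.from_numpy() shares memory with the NumPy array, so in-place changes on one side are visible on the other.
import numpy as np
import torch

arr = np.array([1, 2, 3])
t_copy = torch.tensor(arr)       # copies the data
t_view = torch.from_numpy(arr)   # shares memory with arr
arr[0] = 100
print(t_copy)                    # tensor([1, 2, 3])   -- unaffected
print(t_view)                    # tensor([100, 2, 3]) -- sees the change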
2. torch.unsqueeze() and torch.squeeze()
torch.unsqueeze() and torch.squeeze() are functions for adjusting dimensions.
list1 = [1,2,3,4,5,6]
arr1 = np.array(list1)
tensor1 = torch.tensor(arr1)
tensor2 = tensor1.unsqueeze(0)
tensor3 = tensor2.squeeze(0)
print(tensor2)
print(tensor3)
Result:
tensor([[1, 2, 3, 4, 5, 6]], dtype=torch.int32)
tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)
Explanation:
unsqueeze(0) inserts a dimension of size 1 at the given position, expanding the shape to torch.Size([1, 6]).
squeeze(0) removes the size-1 dimension at the given position, reducing the shape again.
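Note (my own addition): squeeze() only removes dimensions whose size is 1; a larger dimension is left untouched, and calling it with no argument removes every size-1 dimension.
import torch

print(torch.zeros(1, 6).squeeze(0).shape)     # torch.Size([6])    -- size-1 dim removed
print(torch.zeros(2, 6).squeeze(0).shape)     # torch.Size([2, 6]) -- dim 0 has size 2, no change
print(torch.zeros(1, 6, 1).squeeze().shape)   # torch.Size([6])    -- all size-1 dims removed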
3. torch.sum()
torch.sum(a, dim) sums the elements of a tensor; a is the operand and dim is the dimension to sum over.
a = torch.ones((2,3))
a1 = torch.sum(a,dim=0)
a2 = torch.sum(a,dim=0,keepdim=True)
b1 = torch.sum(a,dim=1)
b2 = torch.sum(a,dim=1,keepdim=True)
c1 = torch.sum(a)
print(a)
print(a1)
print(b1)
print(c1)
print(a2)
print(b2)
Result:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor(6.)
tensor([[2., 2., 2.]])
tensor([[3.],
[3.]])
Explanation:
(1) Without dim, all elements are summed, giving 6
(2) dim=0 sums over dimension 0 (collapsing the rows), giving [2, 2, 2]; the dim-0 axis is removed
(3) dim=1 sums over dimension 1 (collapsing the columns), giving [3, 3]; the dim-1 axis is removed
(4) keepdim=True means the dim axis is not squeezed out:
for example, a2 has shape torch.Size([1, 3])
and b2 has shape torch.Size([2, 1])
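keepdim=True is mostly useful when the sum feeds back into a broadcast operation; a typical use (my own sketch, not from the original) is row normalization:
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
row_sums = torch.sum(a, dim=1, keepdim=True)   # shape (2, 1), broadcastable against (2, 3)
print(a / row_sums)                            # each row now sums to 1
# tensor([[0.1667, 0.3333, 0.5000],
#         [0.2667, 0.3333, 0.4000]])
# Without keepdim, row_sums would have shape (2,), which does not broadcast against (2, 3).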
4. torch.repeat()
repeat() tiles a tensor to enlarge it (it is called as a method, a.repeat(...)).
(1) With two arguments: (repeat count for rows, repeat count for columns); 1 means no repetition
(2) With three arguments: (repeat count for channels, repeat count for rows, repeat count for columns); a 1-D sketch follows the example below
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = a.repeat(2,2)
c = a.repeat(2,2,2)
print(b)
print(c)
Result:
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]])
tensor([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]],
[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[7, 8, 9, 7, 8, 9]]])
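For a 1-D tensor, repeat() takes a single count, and passing more counts than the tensor has dimensions prepends new dimensions, which is exactly how c above became 3-D. A small sketch (my own addition):
import torch

v = torch.tensor([1, 2])
print(v.repeat(3))      # tensor([1, 2, 1, 2, 1, 2])
print(v.repeat(2, 3))   # shape (2, 6): a new leading dim, then tiled 3x along the last
# tensor([[1, 2, 1, 2, 1, 2],
#         [1, 2, 1, 2, 1, 2]])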
5. torch.ones_like() and torch.zeros_like()
torch.ones_like() and torch.zeros_like() generate an all-ones or all-zeros tensor with the same shape as a given tensor; torch.ones_like() serves as the example here.
a = torch.ones((2,3))
print(a)
b = torch.ones_like(a)
print(b)
Result:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
Explanation: if torch.ones() is passed a single integer, the generated tensor is one-dimensional, not a square matrix; a quick check below.
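A quick check (my own addition):
import torch

print(torch.ones(3))      # tensor([1., 1., 1.]) -- 1-D, not a 3x3 matrix
print(torch.ones(3, 3))   # both sizes are needed for a 3x3 matrix
print(torch.zeros_like(torch.ones(2, 3)))   # zeros_like mirrors ones_like
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])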
6. torch.reshape()
Transforms the shape of a tensor; note that both the input and the output are tensors.
a = torch.rand(12)
b= torch.reshape(a,(3,4))
c= torch.reshape(a,(-1,2,3))
print(a)
print(b)
print(c)
Result:
tensor([0.4659, 0.8868, 0.9371, 0.0049, 0.6908, 0.2511, 0.2538, 0.2651, 0.0320, 0.0593, 0.6062, 0.4983])
tensor([[0.4659, 0.8868, 0.9371, 0.0049],
[0.6908, 0.2511, 0.2538, 0.2651],
[0.0320, 0.0593, 0.6062, 0.4983]])
tensor([[[0.4659, 0.8868, 0.9371],
[0.0049, 0.6908, 0.2511]],
[[0.2538, 0.2651, 0.0320],
[0.0593, 0.6062, 0.4983]]])
Explanation:
(1) torch.reshape() can reshape within the same number of dimensions, e.g. torch.Size([3, 4]) into torch.Size([2, 6])
(2) torch.reshape() can also change the number of dimensions, e.g. two dimensions into three; -1 marks the dimension that is inferred automatically from the total number of elements
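The element count has to work out: -1 can infer at most one dimension, and the known dimensions must divide the total number of elements. A sketch of the failure mode (my own addition):
import torch

a = torch.rand(12)
print(torch.reshape(a, (4, -1)).shape)   # torch.Size([4, 3]) -- the -1 is inferred as 3
try:
    torch.reshape(a, (5, -1))            # 12 is not divisible by 5
except RuntimeError as e:
    print(e)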
7. view()
Transforms the shape of a tensor, similar to torch.reshape(); but there is no torch.view() function, it is called as a method, .view(). In practice I use it like reshape (one real difference is sketched after the example below).
a = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12])
b= a.view((2,6))
c= a.view((-1,2,3))
print(a)
print(b)
print(c)
Result:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
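One practical difference from reshape (my own note; worth double-checking against the PyTorch docs): view() never copies data, so it fails on tensors whose memory layout is incompatible, e.g. after a transpose, while reshape() falls back to copying.
import torch

a = torch.arange(12).view(3, 4)
b = a.t()                        # transpose: shape (4, 3), non-contiguous memory
try:
    b.view(12)                   # view cannot re-interpret non-contiguous memory
except RuntimeError as e:
    print(e)
print(b.reshape(12).shape)       # torch.Size([12]) -- reshape copies when it must
print(b.contiguous().view(12).shape)   # making it contiguous first also works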
8. expand()
Returns the tensor expanded along certain dimensions. There is no torch.expand() function either; it is called as a method, .expand(). In practice I use it like repeat (one difference is sketched after the example below).
.expand(-1, -1) means the corresponding dimensions are left unexpanded.
a = torch.tensor([1,2,3,4])
b= a.expand((2,4))
print(b)
c = torch.tensor([[1],[2]])
d = c.expand((2,4))
print(d)
Result:
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
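Unlike repeat, expand does not copy any data: it returns a broadcast view, and it can only enlarge dimensions whose size is 1 (my own note):
import torch

c = torch.tensor([[1], [2]])          # shape (2, 1)
print(c.expand(2, 4))                 # OK: the size-1 dim is broadcast to 4
d = torch.tensor([[1, 2], [3, 4]])    # shape (2, 2)
try:
    d.expand(2, 4)                    # dim 1 has size 2, not 1: expand refuses
except RuntimeError as e:
    print(e)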
9. torch.cat()
torch.cat() splices two tensors together; cat is short for concatenate.
a = torch.ones((2,3))
b = 2*torch.ones((4,3))
c = 3*torch.ones((2,5))
d = torch.cat((a,b),dim=0)
e = torch.cat((a,c),dim=1)
print(a)
print(b)
print(c)
print(d)
print(e)
Result:
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.]])
tensor([[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[1., 1., 1., 3., 3., 3., 3., 3.],
[1., 1., 1., 3., 3., 3., 3., 3.]])
Explanation:
(1) dim=0 concatenates along rows; make sure the column counts match
(2) dim=1 concatenates along columns; make sure the row counts match
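As with np.concatenate, every dimension other than the one being concatenated must match, otherwise cat raises an error; a sketch (my own addition) using the tensors above:
import torch

a = torch.ones(2, 3)
c = 3 * torch.ones(2, 5)
try:
    torch.cat((a, c), dim=0)   # column counts differ (3 vs 5), so dim=0 fails
except RuntimeError as e:
    print(e)
print(torch.cat((a, c), dim=1).shape)   # torch.Size([2, 8]) -- row counts match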
These are the functions needed so far; I haven't run into others yet, so this is the summary for now. There may be mistakes, which I'll correct as I keep studying, and if other functions come up later I'll update this post as well.