Summary of common PyTorch APIs
2022-06-30 08:32:00 【Wu lele~】
Preface
This post records how to use commonly used PyTorch APIs and the precautions to take with them. Many of these APIs are explained in detail on the PyTorch official website, but I still want to organize my own notes. This article is updated from time to time.
1、torch.sum()
1.1 torch.sum(input, dtype=None) --> Tensor
First converts the tensor input to dtype, then sums all elements of input and returns a zero-dimensional (scalar) tensor.
a = torch.tensor([[1.,3.],[2.,4.]])
b = torch.sum(a,dtype=torch.int) # tensor(10, dtype=torch.int32)
1.2 torch.sum(input, dim, keepdim) --> Tensor
a、There are many tutorials online that explain this API with diagrams; they are easy to understand but also easy to forget. So here is a way to remember it: first get the shape of input, e.g. (2,3); dim=0 corresponds to the 2 and dim=1 corresponds to the 3. If we sum over dim=0, the shape after summation is (1,3), i.e. the dim=0 dimension is compressed to 1. Then, if keepdim=True, the final shape stays (1,3); otherwise it becomes (3,).
a = torch.tensor([[1,2,3],[4,5,6]]) # a.shape: [2,3]
a1 = torch.sum(a, dim=0, keepdim=True) # sum over dim=0, keepdim=True, so a1.shape: torch.Size([1, 3]); a1: tensor([[5, 7, 9]])
a2 = torch.sum(a, dim=0, keepdim=False) # sum over dim=0, keepdim=False, so a2.shape: torch.Size([3]); a2: tensor([5, 7, 9])
In short, summing over a given dim compresses that dim to 1 (and removes it entirely when keepdim=False)!
b、In addition, you will sometimes see dim given as a list or tuple, e.g. dim=[1,0]. The same logic applies: traverse the elements of the list, first compressing dim=1 to 1 and then dim=0 to 1. The effect is the same as calling sum twice, once on each dim.
a = torch.tensor([[1,2,3],[4,5,6]]) # a.shape: [2,3]
a_list = torch.sum(a, dim=(1,0)) # tensor(21)
a1 = torch.sum(a,dim=1) # tensor([ 6, 15])
a2 = torch.sum(a1,dim=0)
print(a2 == a_list) # tensor(True)
2、Tensor.repeat()
This function allocates new memory. To understand it, take generating the 2D coordinates of every pixel of an image as an example: given the horizontal and vertical coordinates x and y, the goal is to produce x = [0,1,2,0,1,2] and y = [0,0,0,1,1,1]. This can be achieved with repeat as follows.
x = torch.tensor([0,1,2]) # size = (3)
y = torch.tensor([0,1]) # size = (2)
xx = x.repeat(len(y)) # size: (3) --> (3*2,) = (6,)
# size: (2) --> view(-1,1): (2,1) --> repeat(1,3): repeated once on dim 0
# and 3 times on dim 1, giving (2,3) --> .view(-1): (6,)
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
print(xx, yy) # xx = [0,1,2,0,1,2] yy = [0,0,0,1,1,1]
The generation of xx is easy to understand: x is simply repeated len(y) times. Let me briefly explain yy: its repeat is (1, 3), i.e. repeat once on dim 0 and 3 times on dim 1, which is why y has to be reshaped first. The key is to figure out what the tensor's size becomes after each step of repeat; from that you can deduce the elements of the tensor.
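For clarity, here is the yy computation above decomposed step by step (a small sketch of the shape changes just described):
tmp = y.view(-1, 1)    # (2,)   --> (2, 1)
tmp = tmp.repeat(1, 3) # (2, 1) --> (2, 3): tensor([[0, 0, 0], [1, 1, 1]])
yy = tmp.view(-1)      # (2, 3) --> (6,):  tensor([0, 0, 0, 1, 1, 1])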
3、torch.expand()
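A minimal sketch, assuming the usual Tensor.expand() is meant: unlike repeat, expand does not copy data. It can only broadcast dimensions of size 1 to a larger size, and it returns a view of the original tensor, so no new memory is allocated.
import torch
x = torch.tensor([[1], [2], [3]]) # shape (3, 1)
y = x.expand(3, 4)                # shape (3, 4); a view, no new memory
# y: tensor([[1, 1, 1, 1],
#            [2, 2, 2, 2],
#            [3, 3, 3, 3]])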
4、torch.max()
4.1 torch.max(input)
Returns the largest value in the tensor input. Code example:
a = torch.randn(2,3)
print(torch.max(a)) # tensor(2.242)
4.2 torch.max(input, dim, keepdim)
This API can be understood by analogy with torch.sum: along dimension dim it returns the maximum values of input together with the indices of those maxima. keepdim=False means the output tensors do not keep the same number of dimensions as input. Code example:
a = torch.Tensor([[1,2,3,4],[8,7,6,5]])
# shape change : [2,4] --> [1,4] --> [4]
max_value, argmax = torch.max(a, dim=0, keepdim=False)
# tensor([8., 7., 6., 5.]) tensor([1, 1, 1, 1])
print(max_value, argmax)
5、torch.Tensor.new_full()
Tensor.new_full(size, fill_value, dtype=None, device=None, requires_grad=False) → Tensor
This API returns a tensor of the given size filled with fill_value. As the code below shows, by default the returned tensor has the same dtype and device as a. Usually you only need to specify the first two parameters.
a = torch.ones((2,), dtype= torch.long) #tensor([1, 1])
b = a.new_full((3,4), 666, dtype= torch.long, requires_grad=False)
print(b) # tensor([[666, 666, 666, 666], [666, 666, 666, 666], [666, 666, 666, 666]])
6、 Slice details
Suppose you have a two-dimensional tensor and perform the following two slices; pay attention to how the resulting dimensions change:
a = torch.Tensor([[1,2,3,4],[5,6,7,8]]) # [2,4]
b = a[:,0] # [2]
c = a[:,0::4] # [2,1]
Slicing with a double colon (a range slice) keeps that dimension, so the result has one more dimension than indexing with a single integer.
7、torch.nonzero
This finds the indices of the non-zero elements of a tensor. The as_tuple parameter controls whether the result is returned as a tuple; it is usually False.
a = torch.Tensor([-1,0,1,1,0]) # Construct a one-dimensional tensor
ind1 = torch.nonzero(a, as_tuple=False) # tensor([[0], [2], [3]])
ind2 = torch.nonzero(a>0, as_tuple=False) # tensor([[2], [3]])
ind3 = torch.nonzero(a, as_tuple=True) # (tensor([0, 2, 3]),)
ind4 = torch.nonzero(a>0, as_tuple=True) # (tensor([2, 3]),)
b = torch.Tensor([[1,0,0,0],[0,0,2,0]]) # Construct a two-dimensional tensor
id1 = torch.nonzero(b, as_tuple=False) # tensor([[0, 0], [1, 2]])
id2 = torch.nonzero(b, as_tuple=True) # (tensor([0, 1]), tensor([0, 2]))
8、torch.randperm(n)
This function returns a random permutation of the integers 0 to n-1, where n is an int.
a = torch.Tensor([2,1,3,5,6])
random_id = torch.randperm(a.numel()) # e.g. tensor([0, 1, 3, 2, 4]) (the permutation is random)
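A common follow-up, continuing the snippet above, is to shuffle a tensor by indexing it with the permutation:
shuffled = a[random_id] # tensor([2., 1., 5., 3., 6.]) for the example permutation above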
9、torch.unique()
Performs a set (deduplication) operation on a tensor.
output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
#output: tensor([1, 2, 3])
output, inverse_indices = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted = True, return_inverse = True)
# output: tensor([1, 2, 3])
# inverse_indices: tensor([0, 2, 1, 2])
output, inverse_indices = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted = True, return_inverse = True)
# output: tensor([1, 2, 3])
# inverse_indices: tensor([[0, 2],[1, 2]])
sorted=True controls whether the returned unique tensor is in ascending order. return_inverse=True makes the function return one more index tensor, which gives, for each element of the original tensor, its index in the unique tensor. For example, 1 sits at position 0 of the unique tensor and 2 sits at position 1.
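As a small illustrative check, indexing the unique tensor with inverse_indices recovers the original tensor:
import torch
out, inv = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
print(out[inv]) # tensor([1, 3, 2, 3]) -- the original tensor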
10、torch.gather()
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
input is the input tensor; dim specifies which dimension to gather along; index is the tensor that provides the indices. The internal logic of this function is as follows:
out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1
Generally speaking, the shape of out is the same as the shape of index. Along dim, each value in index is used as an index into input, and the corresponding element is extracted.
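As a small illustrative sketch (the values here are made up for demonstration), gathering along dim=1 with a per-row index tensor:
import torch
a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]])
idx = torch.tensor([[0, 3],
                    [2, 1]])
out = torch.gather(a, dim=1, index=idx)
# out[i][j] = a[i][idx[i][j]], so out: tensor([[1, 4],
#                                              [7, 6]])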
11、torch.cat()
torch.cat(tensors, dim=0, *, out=None) → Tensor
tensors is a tuple (or sequence) of tensors to concatenate; dim is the dimension along which to concatenate them.
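As a small illustrative sketch (the tensors are made up for demonstration), concatenating two tensors along dim=0 and dim=1:
import torch
a = torch.tensor([[1, 2], [3, 4]]) # shape (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape (2, 2)
c0 = torch.cat((a, b), dim=0)      # shape (4, 2): stacked along rows
c1 = torch.cat((a, b), dim=1)      # shape (2, 4): stacked along columns
# c0: tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
# c1: tensor([[1, 2, 5, 6], [3, 4, 7, 8]])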