PyTorch Advanced Notes (Function Usage Tips and Precautions)
2022-07-26 15:22:00 【hei_ hei_ hei_】
tensor.contiguous()
Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor: mainly assists other PyTorch functions. It returns a tensor whose data is laid out contiguously in memory after a dimension change; if the tensor is already contiguous it is returned as-is, otherwise a (deep) copy of the data is made.
- Common usage
contiguous() is generally used together with transpose, permute, and view: after a dimension permutation with transpose or permute, call contiguous(), and then view can reshape the tensor, because view requires the tensor to be contiguous in memory (e.g. tensor.contiguous().view()), as follows:
x = torch.Tensor(2, 3)
y = x.permute(1, 0)  # permute: dimension permutation of a 2-D tensor; here equivalent to transpose
y.view(-1)           # error: view requires a contiguous tensor, so call contiguous() first
y = x.permute(1, 0).contiguous()
y.view(-1)           # OK
- Explanation
In PyTorch, some operations on a Tensor do not really change its contents; they only change how byte positions in memory are indexed. For example:
narrow(), view(), expand(), transpose(), permute()
These functions only change the dimension layout of the original data; they are shallow copies. PyTorch does not create a new tensor for them; it only modifies some metadata of the tensor, and the two tensors share the same memory. As a result, modifying the tensor returned by transpose() also changes the original tensor, as follows:
x = torch.randn(3, 2)
y = x.transpose(0, 1)
x[0, 0] = 233
print(y[0, 0]) # 233
Here x is contiguous and y is not: y's memory layout differs from that of a freshly created tensor of the same shape. When contiguous() is called on y, PyTorch forces a copy of the tensor so that its layout becomes contiguous, which is also the precondition for view.
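This can be verified directly with is_contiguous() and data_ptr() (a small check, assuming torch is imported):
x = torch.randn(3, 2)
y = x.transpose(0, 1)
print(y.is_contiguous())              # False
z = y.contiguous()
print(z.is_contiguous())              # True
print(z.data_ptr() == y.data_ptr())   # False: contiguous() copied the data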
[:,:,None] usage
- Purpose: indexing with None inserts a new dimension at that position; the new dimension has size 1.
x = torch.arange(12).reshape(3,4)
y = x[:,:,None]
print(x.shape,'\n',y.shape)
# torch.Size([3, 4])
# torch.Size([3, 4, 1])
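Indexing with None is equivalent to unsqueeze at the same position, e.g.:
y2 = x.unsqueeze(2)        # same as x[:, :, None]
print(torch.equal(y, y2))  # True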
Arithmetic on tensors of different shapes (+ - * /)
- Explanation: when the tensors have the same dimension sizes, nothing special happens; when the sizes differ, broadcasting automatically expands the size-1 dimensions (as if copying the data along them) so that the shapes match before the operation. The repeat calls below make this expansion explicit.
x = torch.arange(4).reshape(1,4)
y = torch.arange(4).reshape(4,1)
print(x+y)
a = x.repeat(4,1)
b = y.repeat(1,4)
print(a+b)
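The broadcast result and the explicit repeat agree, which can be checked directly:
print(torch.equal(x + y, a + b))  # True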
register_buffer usage
In PyTorch a model has two kinds of state: parameters, defined with nn.Parameter() (including the parameters inside the various sub-modules), which are updated by optimizer.step(); and buffers, which are not updated and behave like "constants" that do not change during training (they are still saved in the module's state_dict).
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
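A minimal self-contained sketch (the class name M and the buffer name idx are just for illustration) showing that a buffer appears in state_dict but not in parameters():
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))         # updated by optimizer.step()
        self.register_buffer("idx", torch.arange(2))  # saved with the model, never updated

m = M()
print([name for name, _ in m.named_parameters()])  # ['w']
print(sorted(m.state_dict().keys()))               # ['idx', 'w']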
DropPath layer
- Explanation: if x is the input tensor with shape [B, C, H, W], drop_path randomly drops a fraction drop_prob of the samples in the batch: those samples skip the main branch entirely and pass through the identity mapping of the shortcut branch instead.
- Difference from dropout: dropout randomly zeroes individual neurons, while DropPath randomly drops whole samples within a batch.
PS: DropPath comes from an external package: from timm.models.layers import DropPath
- Usage
from timm.models.layers import DropPath
self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()
x = x + self.drop_path(self.mlp(self.norm2(x)))
This means that some samples in the batch skip the norm and mlp branch and go through the identity transformation directly, i.e. only the residual connection is applied to them.
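For reference, a minimal sketch of what drop_path does internally, modeled on the common timm implementation (an illustration, not the library's exact code):
import torch

def drop_path(x, drop_prob=0.0, training=False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x.div(keep_prob) * mask  # rescale so the expected value is unchanged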
torch.roll()
torch.roll(input, shifts, dims=None) → Tensor: shifts the tensor along the specified dimensions by the given number of positions.
shifts: int or tuple; the number of positions to move along the corresponding dimension(s)
dims: the dimension(s) that shifts applies to; if shifts is a tuple, dims must be a tuple of the same length
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
# shift by 1 position in the positive direction along dimension 0
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
[1, 2],
[3, 4],
[5, 6]])
# shift by 1 position in the negative direction along dimension 0
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
[5, 6],
[7, 8],
[1, 2]])
# shift by 2 positions along dimension 0 and by 1 position along dimension 1
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
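A typical (hypothetical) use is cyclically shifting a feature map along its spatial dimensions, e.g. for shifted-window attention; the shift is exactly reversible:
img = torch.arange(16).view(1, 4, 4)  # (B, H, W)
shifted = torch.roll(img, shifts=(-1, -1), dims=(1, 2))
restored = torch.roll(shifted, shifts=(1, 1), dims=(1, 2))
print(torch.equal(img, restored))     # True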
masked_fill()
Tensor.masked_fill(mask, value) → Tensor: replaces the elements of the tensor where mask is True with value. This function does not change the original tensor; it returns the modified one. mask is a boolean tensor broadcastable to (typically the same shape as) the tensor.
x = torch.arange(24).reshape(2,3,4)
y = x.masked_fill(x>10,10)
print(x,'\n\n',y)
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 10]],
# [[10, 10, 10, 10],
# [10, 10, 10, 10],
# [10, 10, 10, 10]]])
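There is also an in-place variant, masked_fill_, which modifies the tensor directly:
x.masked_fill_(x > 10, 10)  # in-place: x itself is changed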
nn.Parameter()
- Purpose: essentially still a tensor (a subclass of Tensor). When assigned as an attribute of a Module, it is automatically added to the parameter list, appears in the parameters() iterator, and is therefore optimized automatically.
- Usage:
torch.nn.parameter.Parameter(data=None, requires_grad=True) or torch.nn.Parameter(data=None, requires_grad=True): data is a tensor of floating-point dtype
y = torch.arange(24).float()  # can be seen as the initialization
x = nn.Parameter(y)           # inside a Module, it will be optimized automatically
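A minimal sketch inside a Module (the class name Scale and the attribute gamma are just for illustration):
import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))  # registered automatically

    def forward(self, x):
        return self.gamma * x

m = Scale(4)
print([name for name, _ in m.named_parameters()])  # ['gamma']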
torch.meshgrid()
- Purpose: generates coordinate grids; often used for plotting and for building coordinate tensors.
- Use :
torch.meshgrid(*tensors, indexing=None)
# Two dimensional example
x = torch.arange(2)           # row coordinates, length 2
y = torch.arange(2,5)         # column coordinates, length 3
a1, a2 = torch.meshgrid(x,y)  # returns a tuple of 2 tensors of shape (2, 3)
print(a1)
# tensor([[0, 0, 0],
# [1, 1, 1]])
print(a2)
# tensor([[2, 3, 4],
# [2, 3, 4]])
# 3-D example
z = torch.arange(3,7)               # the third coordinate, length 4
b1, b2, b3 = torch.meshgrid(x,y,z)  # returns a tuple of 3 tensors of shape (2, 3, 4)
print(b1)  # only the first dimension varies; equivalent to b1 = torch.arange(2).reshape(2,1,1).repeat(1,3,4)
# tensor([[[0, 0, 0, 0],
# [0, 0, 0, 0],
# [0, 0, 0, 0]],
# [[1, 1, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1]]])
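Note: PyTorch 1.10+ accepts an explicit indexing argument (the default "ij" matches the behavior shown above; newer versions warn when it is omitted). The outputs are often stacked into a single coordinate grid, reusing x and y from above:
coords = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=0)
print(coords.shape)  # torch.Size([2, 2, 3]): coords[:, i, j] = (x[i], y[j])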
tensor.detach()
- Purpose: returns a new tensor detached from the current computation graph but still pointing to the same storage as the original. The difference is that its requires_grad is False: the returned tensor never needs its gradient computed and has no grad. If you keep computing with this new tensor, backpropagation will stop at the tensor on which detach() was called and cannot propagate further back.
- Note: the returned tensor shares memory with the original, so in-place modifications to it also affect the original. The detached tensor cannot be differentiated, but the original still can; however, if the detached tensor is modified in place, calling backward on the original tensor will raise an error.
- Example: suppose we have two networks A and B with y = A(x), z = B(y). We want z.backward() to compute gradients for B's parameters but not for A's; detach can be used for this.
# y = A(x), z = B(y): compute gradients for B's parameters, not for A's
y = A(x)
z = B(y.detach())
z.backward()
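A runnable sketch of the same idea, with two small Linear layers standing in for A and B (the names are illustrative):
import torch
import torch.nn as nn

A = nn.Linear(3, 3)
B = nn.Linear(3, 1)
x = torch.randn(2, 3)

y = A(x)
z = B(y.detach()).sum()  # detach cuts the graph between A and B
z.backward()

print(A.weight.grad)              # None: no gradient flowed back into A
print(B.weight.grad is not None)  # True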
References:
- detach() and detach_() in PyTorch
- pytorch-detach, detach_