当前位置:网站首页>pytorch---进阶篇(函数使用技巧/注意事项)
pytorch---进阶篇(函数使用技巧/注意事项)
2022-07-26 15:02:00 【hei_hei_hei_】
tensor.contiguous()
Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor:主要是为了辅助pytorch中其他函数,返回原始tensor改变纬度后的深拷贝数据。- 常用方法
contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形,因为view操作要求tensor在内存中是连续的(如:tensor.contiguous().view() ),如下:
x = torch.Tensor(2,3)
y = x.permute(1,0) # permute:二维tensor的维度变换,此处功能相当于转置transpose
y.view(-1) # 报错,view使用前需调用contiguous()函数
y = x.permute(1,0).contiguous()
y.view(-1) # OK
- 说明解释
说明:在PyTorch中,有一些对Tensor的操作不会真正改变Tensor的内容,改变的仅仅是Tensor中字节位置的索引。例如:
narrow(),view(),expand(),transpose(),permute()
这些函数是对原始数据纬度的变化,是对原始数据的浅拷贝。在执行这些操作时,pytorch并不会创建新的张量,而是修改了张量中的一些属性,但是二者在内存上是共享的。因此对执行transpose()操作后的张量的改变也会改变原始张量,如下:
x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0]) # 233
在这个例子中,x是连续的,y不是连续的。y的布局方式与重新创建一个tensor的布局方式是不一样的,当对y调用了contiguous()函数之后,pytorch会强制拷贝一份tensor,使得它的布局是连续的,这也是执行view操作的条件。
[:,:,None]用法
- 用途:用于在None纬度上增加一维,新增纬度为1
x = torch.arange(12).reshape(3,4)
y = x[:,:,None]
print(x.shape,'\n',y.shape)
# torch.Size([3, 4])
# torch.Size([3, 4, 1])
不同shape张量计算(+ - * /)
- 说明:当张量的纬度长度相同时不做任何处理,对于纬度长度不同的情况,会自动填充相应的纬度使其满足纬度相同之后再进行运算。填充方式为复制。
x = torch.arange(4).reshape(1,4)
y = torch.arange(4).reshape(4,1)
print(x+y)
a = x.repeat(4,1)
b = y.repeat(1,4)
print(a+b)
register_buffer用法
在pytorch中模型的参数一种nn.Parameter()定义的,包括各种模块中的参数,这种会随着optimazer.step()更新;另一种是buffer,这种不会更新,相当于“常数”,不会在训练中改变
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
DropPath层
- 说明:若x为输入的张量,其通道为[B,C,H,W],那么drop_path的含义为在一个Batch_size中,随机有drop_prob的样本,不经过主干,而直接由分支进行恒等映射。
- 区别于dropout:dropout是对神经元随机失效;而DropPath是对batch中的样本随机失效。
ps:需要导入外部包from timm.models.layers import DropPath - 使用
from timm.models.layers import DropPath
self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()
x = x + self.drop_path(self.mlp(self.norm2(x)))
表明有一些分支(batch中的样本)不经过norm和mlp,直接进行恒等变换。也就是加入了残差。
torch.roll()
torch.roll(input, shifts, dims=None) → Tensor说明:对张量沿指定纬度平移指定的places。
shifts:int or tuple。表示延指定纬度移动的位置个数
dims:表示shifts对应的纬度。若shifts是元组,则dims需要相对应
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
# 向0纬度正方向平移1个位置
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
[1, 2],
[3, 4],
[5, 6]])
# 向0纬度负方向平移1个位置
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
[5, 6],
[7, 8],
[1, 2]])
# 向0纬度正方向平移2个位置,1纬度正方向平移1个位置
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
masked_fill()
Tensor.masked_fill(mask, value) → Tensor:将tensor中mask为true的位置替换为value。此函数不改变原始tensor,返回改变后的tensor。mask为一个张量,与tensor形状一致。
x = torch.arange(24).reshape(2,3,4)
y = x.masked_fill(x>10,10)
print(x,'\n\n',y)
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 10]],
# [[10, 10, 10, 10],
# [10, 10, 10, 10],
# [10, 10, 10, 10]]])
nn.Parameter()
- 用途:本质上仍是一个tensor(tensor的子类)。当它被指定为Module的属性时,会被自动地添加到参数列表中,并将出现在Parameters()迭代器中,会被自动优化
- 使用:
torch.nn.parameter.Parameter(data=None, requires_grad=True)或者torch.nn.Parameter(data=None, requires_grad=True):参数是一个tensor,类型为浮点数
y = torch.arange(24).float() # 可以看做初始化
x = nn.Parameter(y) # 如果在Module中则会被自动优化
torch.meshgrid()
- 用途:用于生成坐标网格。常用于作图
- 使用:
torch.meshgrid(*tensors, indexing=None)
#二维举例
x = torch.arange(2) # 行坐标,长度为2
y = torch.arange(2,5) # 列坐标,长度为3
a1, a2 = torch.meshgrid(x,y) # 返回一个tuple,2个tensor(2,3)
print(a1)
# tensor([[0, 0, 0],
# [1, 1, 1]])
print(a2)
# tensor([[2, 3, 4],
# [2, 3, 4]])
# 三维举例
z = torch.arange(3,7) # 第三维坐标,长度为4
b1, b2, b3 = torch.meshgrid(x,y,z) # 返回一个tuple,3个tensor(2,3,4)
print(b1) # 只有第一维的元素不同,相当于b1 = torch.arange(2).reshape(2,1,1).repeat(1,3,4)
# tensor([[[0, 0, 0, 0],
# [0, 0, 0, 0],
# [0, 0, 0, 0]],
# [[1, 1, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1]]])
tensor.detach()
- 用途: 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。如果继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播。
- 注意:返回的tensor与原始tensor指向同一片内存,因此对返回张量的修改也会影响原始张量。detached tensor无法求导,但原始tensor可以;如果更改了detached tensor,则原始张量的backward会出错。
- 举例:如果我们有两个网络 A,B, 两个关系是这样的 y=A(x),z=B(y)现在我们想用 z.backward() 来为 B网络的参数来求梯度,但是又不想求 A 网络参数的梯度,可以使用detach。
# y=A(x), z=B(y) 求B中参数的梯度,不求A中参数的梯度
y = A(x)
z = B(y.detach())
z.backward()
边栏推荐
- 外文文献查找技巧方法有哪些
- R语言使用lattice包中的histogram函数可视化直方图(histogram plot)、col参数自定义填充色、type参数自定义直方图显示模式(计数或者比例)
- 2023餐饮业展,中国餐饮供应链展,江西餐饮食材展2月举办
- Parallel d-pipeline: a cuckoo hashing implementation for increased throughput
- 下一代视觉Transformer:解锁CNN和Transformer正确结合方法
- JMeter distributed
- The practice of software R & D should start from the design
- R language tests the significance of correlation coefficient: use Cor The test function calculates the value and confidence interval of the correlation coefficient and its statistical significance (if
- Notes (5)
- The IPO of shengtaier technology was terminated: it was planned to raise 560million yuan, and Qiming and Jifeng capital were shareholders
猜你喜欢

cs224w(图机器学习)2021冬季课程学习笔记5

【基础】动态链接库/静态链接库的区别

带你熟悉云网络的“电话簿”:DNS

双屏协作效率翻倍 灵耀X双屏Pro引领双屏科技新潮流

Vs add settings for author information and time information

Data permissions should be designed like this, yyyds!

如何查找国内各大学本科学位论文?

领导抢功劳,我改个变量名让他下岗了

If food manufacturing enterprises want to realize intelligent and collaborative supplier management, it is enough to choose SRM supplier system

晋拓股份上交所上市:市值26亿 张东家族企业色彩浓厚
随机推荐
Advanced Qt development: how to fit the window width and height when displaying (fitwidth+fitheight)
R语言使用lm函数构建多元回归模型(Multiple Linear Regression)、并根据模型系数写出回归方程、使用fitted函数计算出模型的拟合的y值(响应值)向量
兆骑科创高端人才项目引进落地,双创大赛承办,线上直播路演
Google tries to introduce password strength indicator for chromeos to improve online security
Minecraft 1.16.5 module development (52) modify the original biological trophy (lot table)
如何进行学术文献翻译?
Next generation visual transformer: Unlocking the correct combination of CNN and transformer
jetson nano上远程桌面
Leetcode summary
jmeter分布式
Everything is available Cassandra: the fairy database behind Huawei tag
领导抢功劳,我改个变量名让他下岗了
[basic] the difference between dynamic link library and static link library
Remote desktop on Jetson nano
9. Learn MySQL delete statement
如何查询外文文献?
Cs224w (Figure machine learning) 2021 winter course learning notes 5
R language ggplot2 visualization: use the ggdotplot function of ggpubr package to visualize dot plot, set the add parameter to add the mean and standard deviation vertical lines, and set the error.plo
Unity URP入门实战
FOC learning notes - coordinate transformation and simulation verification