Computing second-order derivatives with torch.autograd.grad
2022-08-05 00:40:00 【鬼道2022】
1 Usage
In PyTorch, torch.autograd.grad computes and returns the sum of gradients of the outputs with respect to the inputs. Its signature and main parameters are:
torch.autograd.grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False) → Tuple[Tensor, ...]
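- outputs (sequence of Tensor): the outputs of the function being differentiated
- inputs (sequence of Tensor): the inputs with respect to which the gradients are computed
- grad_outputs (sequence of Tensor): the vector in the vector-Jacobian product; it can be omitted when the output is a scalar
- retain_graph (bool, optional): if False, the graph used to compute the gradient is freed after the call; it defaults to the value of create_graph. Set it to True when the same graph must be differentiated again, as when computing second-order derivatives
- create_graph (bool, optional): if True, the graph of the derivative itself is constructed, so the returned gradient can be differentiated again. Set it to True when computing higher-order derivatives or when the gradient takes part in further computation
- allow_unused (bool, optional): if False, passing an input that was not used to compute the outputs raises an error; if True, the gradient returned for such an input is None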
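The roles of grad_outputs and allow_unused can be illustrated with a minimal sketch. The tensor `z` and the vector `v` below are assumptions introduced only for this illustration; they do not appear in the original example.

```python
import torch

x = torch.tensor([5.0, 7.0], requires_grad=True)
z = torch.tensor([1.0], requires_grad=True)  # hypothetical input, never used to compute y
y = x ** 2  # non-scalar output, so grad_outputs must be supplied

# v^T J with v = (1, 0) picks out the first row of the Jacobian: (2*x1, 0)
v = torch.tensor([1.0, 0.0])
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)

# z does not contribute to y; with allow_unused=True its gradient comes back as None
gx, gz = torch.autograd.grad(y.sum(), [x, z], allow_unused=True)
print(vjp, gx, gz)
```

The first call passes retain_graph=True because the graph of `y` is differentiated a second time by the next call.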
2 Worked example
The following example derives the analytic solution by hand.
Given a vector $\mathbf{x}=(x_1,x_2)^{\top}$, define the vector $\mathbf{y}=(y_1,y_2)^{\top}=(x_1^2,x_2^2)^{\top}$. Averaging the elements of $\mathbf{y}$ gives the loss $\mathrm{loss}_1$:

$$\mathrm{loss}_1(\mathbf{x})=\mathrm{mean}(\mathbf{y})=\frac{x_1^2+x_2^2}{2}$$

Differentiating each component of $\mathbf{y}$ with respect to $\mathbf{x}$, then averaging the elements of the difference of the two gradients, gives the loss $\mathrm{loss}_2$:

$$\left\{\begin{aligned}h_1(\mathbf{x})&=\frac{\partial y_1}{\partial \mathbf{x}}=(2x_1,0)^{\top}\\h_2(\mathbf{x})&=\frac{\partial y_2}{\partial \mathbf{x}}=(0,2x_2)^{\top}\end{aligned}\right.,\quad \mathrm{loss}_2(\mathbf{x})=\mathrm{mean}(h_1(\mathbf{x})-h_2(\mathbf{x}))=\frac{2x_1-2x_2}{2}=x_1-x_2$$

Adding $\mathrm{loss}_1$ and $\mathrm{loss}_2$ gives

$$\mathrm{loss}(\mathbf{x})=\mathrm{loss}_1(\mathbf{x})+\mathrm{loss}_2(\mathbf{x})=\frac{x_1^2+x_2^2}{2}+x_1-x_2$$

Finally, the gradient of $\mathrm{loss}$ with respect to $\mathbf{x}$ is

$$\frac{\partial \mathrm{loss}}{\partial \mathbf{x}}=(x_1+1,x_2-1)^{\top}$$
The corresponding PyTorch code is:
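import torch
x = torch.tensor([5.0, 7.0], requires_grad=True)
y = x**2
# loss1 = (x1^2 + x2^2) / 2
loss1 = torch.mean(y)
# h1 = (2*x1, 0) and h2 = (0, 2*x2); create_graph=True keeps them differentiable
h1 = torch.autograd.grad(y[0], x, retain_graph=True, create_graph=True)
h2 = torch.autograd.grad(y[1], x, retain_graph=True, create_graph=True)
# torch.autograd.grad returns a tuple, hence the [0]; loss2 = x1 - x2
loss2 = torch.mean(h1[0] - h2[0])
loss = loss1 + loss2
# Differentiate through both the original graph and the gradient graph
result = torch.autograd.grad(loss, x)
print(result)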
When $\mathbf{x}=(5,7)^{\top}$, the analytic solution above, which differentiates through the first-order gradients $h_1$ and $h_2$, gives $(x_1+1,x_2-1)^{\top}=(6,6)^{\top}$. Running the code produces the same result, $(6,6)$.
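For comparison, a plain second derivative of a scalar function follows the same pattern of two chained calls. The sketch below uses $y=x^3$ at $x=3$, a function chosen purely for illustration and not taken from the example above; analytically $dy/dx=3x^2=27$ and $d^2y/dx^2=6x=18$.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative 3*x^2; create_graph=True records how it depends on x
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)

# Second derivative 6*x, obtained by differentiating the first derivative
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)

print(dy_dx.item(), d2y_dx2.item())
```

Without create_graph=True in the first call, `dy_dx` would carry no graph and the second call would raise an error.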