当前位置:网站首页>torch.autograd.grad求二阶导数
torch.autograd.grad求二阶导数
2022-08-05 00:40:00 【鬼道2022】
1 用法介绍
pytorch中torch.autograd.grad函数主要用于计算并返回输出相对于输入的梯度总和,具体的参数作用如下所示:
torch.tril(input, diagonal=0, *, out=None) * \longrightarrow *Tensor
- outputs(sequence of Tensor):表示微分函数的输出
- inputs (sequence of Tensor):表示微分函数的输入
- grad_outputs (sequence of Tensor):表示“向量-雅克比矩阵”的向量
- retain_graph (bool, optional):表示是否需要将计算图释放掉,当计算二阶导数时需要设置为True
- create_graph (bool, optional):表示是否需要将梯度将会加入到计算图中,当计算高阶导数或者其他计算时会将其设置为需要设置为True
- allow_unused (bool, optional):表示是否只返回输入的梯度,而不返回其他叶子节点的梯度
2 实例讲解
以下给出了具体的二阶导数解析解的数学实例
给定一个向量 x = ( x 1 , x 2 ) ⊤ {\bf{x}}=(x_1,x_2)^{\top} x=(x1,x2)⊤,可以得到向量 y = ( y 1 , y 2 ) ⊤ = ( x 1 2 , x 2 2 ) ⊤ {\bf{y}}=(y_1,y_2)^{\top}=(x^2_1,x^2_2)^{\top} y=(y1,y2)⊤=(x12,x22)⊤。对向量 y {\bf{y}} y的元素求平均可以得到损失函数 l o s s 1 \mathrm{loss}_1 loss1为: l o s s 1 ( x ) = m e a n ( y ) = x 1 2 + x 2 2 2 \mathrm{loss}_1({\bf{x}})=\mathrm{mean}({\bf{y}})=\frac{x_1^2+x^2_2}{2} loss1(x)=mean(y)=2x12+x22向量 y {\bf{y}} y元素的分量分别对 x {\bf{x}} x求偏导,然后相加求平均得到损失函数 l o s s 2 \mathrm{loss}_2 loss2为 { h 1 ( x ) = ∂ y 1 ∂ x = ( 2 x 1 , 0 ) ⊤ h 2 ( x ) = ∂ y 2 ∂ x = ( 0 , 2 x 2 ) ⊤ , l o s s 2 ( x ) = m e a n ( h 1 ( x 1 ) − h 2 ( x 2 ) ) = x 1 − x 2 \left\{\begin{aligned}h_1({\bf{x}})&=\frac{\partial y_1}{\partial {\bf{x}}}=(2x_1,0)^{\top}\\h_2({\bf{x}})&=\frac{\partial y_2}{\partial {\bf{x}}}=(0,2x_2)^{\top}\end{aligned}\right.,\quad \mathrm{loss}_2({\bf{x}})=\mathrm{mean}(h_1({\bf{x}}_1)-h_2({\bf{x}}_2))=x_1-x_2 ⎩⎨⎧h1(x)h2(x)=∂x∂y1=(2x1,0)⊤=∂x∂y2=(0,2x2)⊤,loss2(x)=mean(h1(x1)−h2(x2))=x1−x2将损失函数 l o s s 1 \mathrm{loss}_1 loss1与损失函数 l o s s 2 \mathrm{loss}_2 loss2相加可以得到 l o s s ( x ) = l o s s 1 ( x ) + l o s s 2 ( x ) = x 1 2 + x 2 2 2 + x 1 − x 2 \mathrm{loss}({\bf{x}})=\mathrm{loss}_1({\bf{x}})+\mathrm{loss}_2({\bf{x}})=\frac{x_1^2+x_2^2}{2}+x_1-x_2 loss(x)=loss1(x)+loss2(x)=2x12+x22+x1−x2最终损失函数 l o s s \mathrm{loss} loss对向量 x {\bf{x}} x的偏导数为 ∂ l o s s ∂ x = ( x 1 + 1 , x 2 − 1 ) ⊤ \frac{\partial {\mathrm{loss}}}{\partial{ {\bf{x}}}}=(x_1+1,x_2-1)^{\top} ∂x∂loss=(x1+1,x2−1)⊤
以下为用pytorch实现二阶导数相对应的代码实例:
import torch
x = torch.tensor([5.0, 7.0], requires_grad=True)
y = x**2
loss1 = torch.mean(y)
h1 = torch.autograd.grad(y[0], x, retain_graph = True, create_graph=True)
h2 = torch.autograd.grad(y[1], x, retain_graph = True, create_graph=True)
loss2 = torch.mean(h1[0] - h2[0])
loss = loss1 + loss2
result = torch.autograd.grad(loss, x)
print(result)
当向量 x {\bf{x}} x取值为 ( 5 , 7 ) ⊤ (5,7)^{\top} (5,7)⊤时,根据数学解析解得到的二阶导数为 ( 6 , 6 ) ⊤ (6,6)^{\top} (6,6)⊤,对应的代码运行的实验结果也为 ( 6 , 6 ) (6,6) (6,6)。
边栏推荐
- 00、数组及字符串常用的 API(详细剖析)
- 电赛必备技能___定时ADC+DMA+串口通信
- 电子行业MES管理系统的主要功能与用途
- Software Testing Interview Questions: About Automated Testing Tools?
- Raw and scan of gorm
- 软件测试面试题:测试用例通常包括那些内容?
- Zombie and orphan processes
- Software Testing Interview Questions: What is Software Testing?The purpose and principle of software testing?
- what?测试/开发程序员要被淘汰了?年龄40被砍到了32?一瞬间,有点缓不过神来......
- BC(转)[js]js计算两个时间相差天数
猜你喜欢
随机推荐
软件开发工具的技术要素
QSunSync 七牛云文件同步工具,批量上传
数据类型及输入输出初探(C语言)
Software testing interview questions: What is the difference between load testing, capacity testing, and strength testing?
NMS原理及其代码实现
2022 The Third J Question Journey
Software Testing Interview Questions: What's the Key to a Good Test Plan?
元宇宙:未来我们的每一个日常行为是否都能成为赚钱工具?
Software Testing Interview Questions: What's the Difference Between Manual Testing and Automated Testing?
[FreeRTOS] FreeRTOS and stm32 built-in stack occupancy
Software Testing Interview Questions: What aspects should be considered when designing test cases, i.e. what aspects should different test cases test against?
Matlab uses plotting method for data simulation and simulation
00、数组及字符串常用的 API(详细剖析)
如何用 Solidity 创建一个“Hello World”智能合约
2022 Multi-school Second Session K Question Link with Bracket Sequence I
tiup uninstall
Software testing interview questions: What are the three modules of LoadRunner?
软件基础的理论
Software testing interview questions: How many types of software are there?
GO中sync包自由控制并发的方法









