当前位置:网站首页>torch.autograd.grad求二阶导数
torch.autograd.grad求二阶导数
2022-08-05 00:40:00 【鬼道2022】
1 用法介绍
pytorch中torch.autograd.grad函数主要用于计算并返回输出相对于输入的梯度总和,具体的参数作用如下所示:
torch.tril(input, diagonal=0, *, out=None) * \longrightarrow *Tensor
- outputs(sequence of Tensor):表示微分函数的输出
- inputs (sequence of Tensor):表示微分函数的输入
- grad_outputs (sequence of Tensor):表示“向量-雅克比矩阵”的向量
- retain_graph (bool, optional):表示是否需要将计算图释放掉,当计算二阶导数时需要设置为True
- create_graph (bool, optional):表示是否需要将梯度将会加入到计算图中,当计算高阶导数或者其他计算时会将其设置为需要设置为True
- allow_unused (bool, optional):表示是否只返回输入的梯度,而不返回其他叶子节点的梯度
2 实例讲解
以下给出了具体的二阶导数解析解的数学实例
给定一个向量 x = ( x 1 , x 2 ) ⊤ {\bf{x}}=(x_1,x_2)^{\top} x=(x1,x2)⊤,可以得到向量 y = ( y 1 , y 2 ) ⊤ = ( x 1 2 , x 2 2 ) ⊤ {\bf{y}}=(y_1,y_2)^{\top}=(x^2_1,x^2_2)^{\top} y=(y1,y2)⊤=(x12,x22)⊤。对向量 y {\bf{y}} y的元素求平均可以得到损失函数 l o s s 1 \mathrm{loss}_1 loss1为: l o s s 1 ( x ) = m e a n ( y ) = x 1 2 + x 2 2 2 \mathrm{loss}_1({\bf{x}})=\mathrm{mean}({\bf{y}})=\frac{x_1^2+x^2_2}{2} loss1(x)=mean(y)=2x12+x22向量 y {\bf{y}} y元素的分量分别对 x {\bf{x}} x求偏导,然后相加求平均得到损失函数 l o s s 2 \mathrm{loss}_2 loss2为 { h 1 ( x ) = ∂ y 1 ∂ x = ( 2 x 1 , 0 ) ⊤ h 2 ( x ) = ∂ y 2 ∂ x = ( 0 , 2 x 2 ) ⊤ , l o s s 2 ( x ) = m e a n ( h 1 ( x 1 ) − h 2 ( x 2 ) ) = x 1 − x 2 \left\{\begin{aligned}h_1({\bf{x}})&=\frac{\partial y_1}{\partial {\bf{x}}}=(2x_1,0)^{\top}\\h_2({\bf{x}})&=\frac{\partial y_2}{\partial {\bf{x}}}=(0,2x_2)^{\top}\end{aligned}\right.,\quad \mathrm{loss}_2({\bf{x}})=\mathrm{mean}(h_1({\bf{x}}_1)-h_2({\bf{x}}_2))=x_1-x_2 ⎩⎨⎧h1(x)h2(x)=∂x∂y1=(2x1,0)⊤=∂x∂y2=(0,2x2)⊤,loss2(x)=mean(h1(x1)−h2(x2))=x1−x2将损失函数 l o s s 1 \mathrm{loss}_1 loss1与损失函数 l o s s 2 \mathrm{loss}_2 loss2相加可以得到 l o s s ( x ) = l o s s 1 ( x ) + l o s s 2 ( x ) = x 1 2 + x 2 2 2 + x 1 − x 2 \mathrm{loss}({\bf{x}})=\mathrm{loss}_1({\bf{x}})+\mathrm{loss}_2({\bf{x}})=\frac{x_1^2+x_2^2}{2}+x_1-x_2 loss(x)=loss1(x)+loss2(x)=2x12+x22+x1−x2最终损失函数 l o s s \mathrm{loss} loss对向量 x {\bf{x}} x的偏导数为 ∂ l o s s ∂ x = ( x 1 + 1 , x 2 − 1 ) ⊤ \frac{\partial {\mathrm{loss}}}{\partial{ {\bf{x}}}}=(x_1+1,x_2-1)^{\top} ∂x∂loss=(x1+1,x2−1)⊤
以下为用pytorch实现二阶导数相对应的代码实例:
import torch
x = torch.tensor([5.0, 7.0], requires_grad=True)
y = x**2
loss1 = torch.mean(y)
h1 = torch.autograd.grad(y[0], x, retain_graph = True, create_graph=True)
h2 = torch.autograd.grad(y[1], x, retain_graph = True, create_graph=True)
loss2 = torch.mean(h1[0] - h2[0])
loss = loss1 + loss2
result = torch.autograd.grad(loss, x)
print(result)
当向量 x {\bf{x}} x取值为 ( 5 , 7 ) ⊤ (5,7)^{\top} (5,7)⊤时,根据数学解析解得到的二阶导数为 ( 6 , 6 ) ⊤ (6,6)^{\top} (6,6)⊤,对应的代码运行的实验结果也为 ( 6 , 6 ) (6,6) (6,6)。
边栏推荐
- Pytorch usage and tricks
- 子连接中的参数传递
- 元宇宙:未来我们的每一个日常行为是否都能成为赚钱工具?
- typeScript - Partially apply a function
- ORA-00257
- lua 如何 实现一个unity协程的工具
- 2022杭电多校第三场 L题 Two Permutations
- "WEB Security Penetration Testing" (28) Burp Collaborator-dnslog out-band technology
- 《WEB安全渗透测试》(28)Burp Collaborator-dnslog外带技术
- could not build server_names_hash, you should increase server_names_hash_bucket_size: 32
猜你喜欢
机器学习(公式推导与代码实现)--sklearn机器学习库
简单的顺序结构程序(C语言)
"WEB Security Penetration Testing" (28) Burp Collaborator-dnslog out-band technology
【FreeRTOS】FreeRTOS与stm32内置堆栈的占用情况
电赛必备技能___定时ADC+DMA+串口通信
JUC thread pool (1): FutureTask use
Matlab uses plotting method for data simulation and simulation
[idea] idea configures sql formatting
Redis visual management software Redis Desktop Manager2022
oracle创建用户以后的权限问题
随机推荐
Countdown to 1 day!From August 2nd to 4th, I will talk with you about open source and employment!
英特尔WiFi 7产品将于2024年亮相 最高速度可达5.8Gbps
Software Testing Interview Questions: Qualifying Criteria for Software Acceptance Testing?
JUC线程池(一): FutureTask使用
ORA-00257
After the staged testing is complete, have you performed defect analysis?
leetcode:266. 回文全排列
RK3399平台开发系列讲解(内核调试篇)2.50、嵌入式产品启动速度优化
Helm Chart
2022 Multi-school Second Session K Question Link with Bracket Sequence I
The principle of NMS and its code realization
D - I Hate Non-integer Number (选数的计数dp
2022 Hangzhou Electric Power Multi-School Session 3 Question L Two Permutations
TinyMCE禁用转义
MAUI Blazor 权限经验分享 (定位,使用相机)
SV class virtual method of polymorphism
QSunSync Qiniu cloud file synchronization tool, batch upload
软件测试面试题:关于自动化测试工具?
Inter-process communication and inter-thread communication
oracle创建用户