当前位置:网站首页>B站刘二大人-反向传播
B站刘二大人-反向传播
2022-07-06 05:33:00 【宁然也】
系列文章:
B站刘二大人-线性回归及梯度下降
文章目录
代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)*(y_pred-y)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
l = 0
loss_sum = 0
for (x ,y) in zip(x_data, y_data):
# l, loss_sum都是张量,无梯度跟踪
l = loss(x,y)
loss_sum += l.data
# 我的疑问:反向传播、梯度更新进行了epoch*len(x_data),
# 为什么不能进行epoch次。
l.backward()
w.data = w.data - alpha*w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
# 获取张量上的值需要转换为numpy
loss_list.append(loss_sum.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()

反向传播、梯度更新进行lepoch次的代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(xs, ys):
# y_pred = forward(x)
# return (y_pred - y)*(y_pred-y)
loss_sum = 0
for(x, y) in zip(xs, ys):
y_pred = forward(x)
loss_sum += (y_pred-y)*(y_pred-y)
return loss_sum/len(xs)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
# 进行epoch次梯度更新、损失函数计算
for epoch in range(100):
# 计算所有数据的损失函数
l = loss(x_data, y_data)
l.backward()
# 梯度更新
w.data = w.data - alpha * w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
loss_list.append(l.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()

两种方式跑出来的图在 损失下降率、结果有一定的差别。
目前我也有疑问,不知道那种合适。
B站老师写的是第一个代码
边栏推荐
- 毕业设计游戏商城
- Notes, continuation, escape and other symbols
- Safe mode on Windows
- 数字经济破浪而来 ,LTD是权益独立的Web3.0网站?
- Remember an error in MySQL: the user specified as a definer ('mysql.infoschema '@' localhost ') does not exist
- Zoom and pan image in Photoshop 2022
- [force buckle]43 String multiplication
- 03. 开发博客项目之登录
- The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
- 【SQL server速成之路】——身份验证及建立和管理用户账户
猜你喜欢

Safe mode on Windows
![[JVM] [Chapter 17] [garbage collector]](/img/f4/e6ff0e3edccf23399ec12b7913749a.jpg)
[JVM] [Chapter 17] [garbage collector]
![[Tang Laoshi] C -- encapsulation: classes and objects](/img/4e/30d2d4652ea2d4cd5fa7cbbb795863.jpg)
[Tang Laoshi] C -- encapsulation: classes and objects

03. Login of development blog project

应用安全系列之三十七:日志注入

Codeforces Round #804 (Div. 2) Editorial(A-B)

Figure database ongdb release v-1.0.3

JS array list actual use summary

Sword finger offer II 039 Maximum rectangular area of histogram

29io stream, byte output stream continue write line feed
随机推荐
Sword finger offer II 039 Maximum rectangular area of histogram
Self built DNS server, the client opens the web page slowly, the solution
[cloud native] 3.1 kubernetes platform installation kubespher
Excel转换为Lua的配置文件
B站刘二大人-多元逻辑回归 Lecture 7
Can the feelings of Xi'an version of "Coca Cola" and Bingfeng beverage rush for IPO continue?
flutter 实现一个有加载动画的按钮(loadingButton)
59. Spiral matrix
LeetCode_ String inversion_ Simple_ 557. Reverse word III in string
Promotion hung up! The leader said it wasn't my poor skills
2022半年总结
Force buckle 1189 Maximum number of "balloons"
Codeforces Round #804 (Div. 2) Editorial(A-B)
Zoom and pan image in Photoshop 2022
指針經典筆試題
js Array 列表 实战使用总结
剑指 Offer II 039. 直方图最大矩形面积
Questions d'examen écrit classiques du pointeur
Installation de la Bibliothèque de processus PDK - csmc
Qt TCP 分包粘包的解决方法