当前位置:网站首页>B站刘二大人-反向传播
B站刘二大人-反向传播
2022-07-06 05:33:00 【宁然也】
系列文章:
B站刘二大人-线性回归及梯度下降
文章目录
代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)*(y_pred-y)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
l = 0
loss_sum = 0
for (x ,y) in zip(x_data, y_data):
# l, loss_sum都是张量,无梯度跟踪
l = loss(x,y)
loss_sum += l.data
# 我的疑问:反向传播、梯度更新进行了epoch*len(x_data),
# 为什么不能进行epoch次。
l.backward()
w.data = w.data - alpha*w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
# 获取张量上的值需要转换为numpy
loss_list.append(loss_sum.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()
反向传播、梯度更新进行lepoch次的代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(xs, ys):
# y_pred = forward(x)
# return (y_pred - y)*(y_pred-y)
loss_sum = 0
for(x, y) in zip(xs, ys):
y_pred = forward(x)
loss_sum += (y_pred-y)*(y_pred-y)
return loss_sum/len(xs)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
# 进行epoch次梯度更新、损失函数计算
for epoch in range(100):
# 计算所有数据的损失函数
l = loss(x_data, y_data)
l.backward()
# 梯度更新
w.data = w.data - alpha * w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
loss_list.append(l.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()
两种方式跑出来的图在 损失下降率、结果有一定的差别。
目前我也有疑问,不知道那种合适。
B站老师写的是第一个代码
边栏推荐
- 26file filter anonymous inner class and lambda optimization
- RustDesk 搭建一个自己的远程桌面中继服务器
- Codeforces Round #804 (Div. 2) Editorial(A-B)
- 【torch】|torch.nn.utils.clip_grad_norm_
- 【经验】UltralSO制作启动盘时报错:磁盘/映像容量太小
- [untitled]
- 自建DNS服务器,客户端打开网页慢,解决办法
- Notes, continuation, escape and other symbols
- UCF (summer team competition II)
- [JVM] [Chapter 17] [garbage collector]
猜你喜欢
【经验】win11上安装visio
指針經典筆試題
用StopWatch 统计代码耗时
26file filter anonymous inner class and lambda optimization
Deep learning -yolov5 introduction to actual combat click data set training
Graduation design game mall
PDK工藝庫安裝-CSMC
Unity Vector3. Use and calculation principle of reflect
[leetcode daily question] number of enclaves
Pix2pix: image to image conversion using conditional countermeasure networks
随机推荐
(column 22) typical column questions of C language: delete the specified letters in the string.
Notes, continuation, escape and other symbols
The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
Pickle and savez_ Compressed compressed volume comparison
nacos-高可用seata之TC搭建(02)
Promise summary
[QNX hypervisor 2.2 user manual]6.3.3 using shared memory (shmem) virtual devices
Quantitative description of ANC noise reduction
[force buckle]43 String multiplication
04. 项目博客之日志
Force buckle 1189 Maximum number of "balloons"
jdbc使用call调用存储过程报错
[leetcode] 18. Sum of four numbers
[JVM] [Chapter 17] [garbage collector]
Vulhub vulnerability recurrence 73_ Webmin
28io stream, byte output stream writes multiple bytes
UCF(暑期团队赛二)
Solution of QT TCP packet sticking
Huawei od computer test question 2
Installation de la Bibliothèque de processus PDK - csmc