Bilibili 刘二大人 - Backpropagation
2022-07-06 05:33:00 【宁然也】
Series articles:
Bilibili 刘二大人 - Linear Regression and Gradient Descent
Code
import matplotlib.pyplot as plt
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# w is a tensor with gradient tracking enabled
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=", w)


def forward(x):
    return x * w


# Loss for a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)


# Learning rate
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
    l = 0
    loss_sum = 0
    for (x, y) in zip(x_data, y_data):
        # l is a tensor attached to the computation graph (it tracks gradients);
        # loss_sum accumulates l.data, so it is a tensor without gradient tracking
        l = loss(x, y)
        loss_sum += l.data
        # My question: backpropagation and the weight update run
        # epoch * len(x_data) times. Why not just epoch times?
        l.backward()
        w.data = w.data - alpha * w.grad.data
        # backward() accumulates into w.grad, so clear it after each update
        w.grad.data.zero_()
    w_list.append(w.data)
    epoch_list.append(epoch)
    # loss_sum is a one-element tensor; item() returns its value as a Python float
    loss_list.append(loss_sum.item())
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()
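For reference, the gradient that `l.backward()` accumulates into `w.grad` for one sample can be written out by hand. This is a worked derivation from the `forward` and `loss` definitions above, with $\alpha$ standing for `alpha` and $\hat{y}$ for `y_pred`:

```latex
\begin{aligned}
\hat{y} &= x w, \qquad l(w) = (\hat{y} - y)^2 = (x w - y)^2 \\
\frac{\partial l}{\partial w} &= 2x\,(x w - y) \\
w &\leftarrow w - \alpha\,\frac{\partial l}{\partial w}
\end{aligned}
```

With the initial value $w = 1$ and the first sample $(x, y) = (1, 2)$, the gradient is $2 \cdot 1 \cdot (1 - 2) = -2$, so the first update moves $w$ from 1 to 1.002.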

Code that runs backpropagation and the gradient update epoch times
import matplotlib.pyplot as plt
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# w is a tensor with gradient tracking enabled
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=", w)


def forward(x):
    return x * w


# Mean loss over all samples (not the loss of a single sample)
def loss(xs, ys):
    loss_sum = 0
    for (x, y) in zip(xs, ys):
        y_pred = forward(x)
        loss_sum += (y_pred - y) * (y_pred - y)
    return loss_sum / len(xs)


# Learning rate
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
# One loss computation and one gradient update per epoch
for epoch in range(100):
    # Compute the loss over the whole data set
    l = loss(x_data, y_data)
    l.backward()
    # Gradient update
    w.data = w.data - alpha * w.grad.data
    # backward() accumulates into w.grad, so clear it after each update
    w.grad.data.zero_()
    w_list.append(w.data)
    epoch_list.append(epoch)
    # l is a one-element tensor; item() returns its value as a Python float
    loss_list.append(l.item())
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("mean_loss")
plt.show()

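The plots produced by the two approaches differ somewhat, both in how fast the loss falls and in the final result.
I still have an open question here: I am not sure which approach is more appropriate.
The Bilibili lecturer wrote the first version.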
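One way to see where the difference comes from is to replay both update schedules without the plotting code. The sketch below is not the lecturer's code: it replaces autograd with the hand-derived gradient 2x(xw - y) of the single-sample loss, and the names `grad_single`, `train_per_sample` and `train_full_batch` are made up for this illustration. The data, learning rate, starting weight and epoch count are the ones used in both listings.

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
alpha = 0.001  # same learning rate as in both listings


def grad_single(x, y, w):
    """Hand-derived gradient of (x*w - y)**2 with respect to w."""
    return 2.0 * x * (x * w - y)


def train_per_sample(epochs=100, w=1.0):
    """First listing: one update per sample, len(x_data) updates per epoch."""
    for _ in range(epochs):
        for x, y in zip(x_data, y_data):
            w -= alpha * grad_single(x, y, w)
    return w


def train_full_batch(epochs=100, w=1.0):
    """Second listing: one update per epoch, using the gradient of the mean loss."""
    n = len(x_data)
    for _ in range(epochs):
        mean_grad = sum(grad_single(x, y, w) for x, y in zip(x_data, y_data)) / n
        w -= alpha * mean_grad
    return w


if __name__ == "__main__":
    print("per-sample updates:", train_per_sample())
    print("full-batch updates:", train_full_batch())
```

With the same `alpha`, the per-sample schedule applies three un-averaged gradients per epoch, while the full-batch schedule applies one gradient divided by `len(x_data)`, so the first moves `w` roughly three times as far per epoch and ends closer to the true value 2 after 100 epochs. The two plots also differ in scale because the first listing records the sum of the three sample losses and the second records their mean. Neither schedule is wrong: the first is stochastic (per-sample) gradient descent and the second is full-batch gradient descent, and they would need different learning rates to progress at a comparable pace.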