当前位置:网站首页>B站刘二大人-反向传播
B站刘二大人-反向传播
2022-07-06 05:33:00 【宁然也】
系列文章:
B站刘二大人-线性回归及梯度下降
文章目录
代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)*(y_pred-y)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
l = 0
loss_sum = 0
for (x ,y) in zip(x_data, y_data):
# l, loss_sum都是张量,无梯度跟踪
l = loss(x,y)
loss_sum += l.data
# 我的疑问:反向传播、梯度更新进行了epoch*len(x_data),
# 为什么不能进行epoch次。
l.backward()
w.data = w.data - alpha*w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
# 获取张量上的值需要转换为numpy
loss_list.append(loss_sum.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()

反向传播、梯度更新进行lepoch次的代码
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w是张量且设置了梯度跟踪
w = torch.Tensor([1.0])
w.requires_grad = True
print("w=",w)
def forward(x):
return x*w
# 单个数据的损失函数
def loss(xs, ys):
# y_pred = forward(x)
# return (y_pred - y)*(y_pred-y)
loss_sum = 0
for(x, y) in zip(xs, ys):
y_pred = forward(x)
loss_sum += (y_pred-y)*(y_pred-y)
return loss_sum/len(xs)
# 学习率
alpha = 0.001
epoch_list = []
w_list = []
loss_list = []
# 进行epoch次梯度更新、损失函数计算
for epoch in range(100):
# 计算所有数据的损失函数
l = loss(x_data, y_data)
l.backward()
# 梯度更新
w.data = w.data - alpha * w.grad.data
w.grad.data.zero_()
w_list.append(w.data)
epoch_list.append(epoch)
loss_list.append(l.data.numpy()[0])
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_sum")
plt.show()

两种方式跑出来的图在 损失下降率、结果有一定的差别。
目前我也有疑问,不知道那种合适。
B站老师写的是第一个代码
边栏推荐
- Using stopwatch to count code time
- Summary of deep learning tuning tricks
- Quantitative description of ANC noise reduction
- Notes, continuation, escape and other symbols
- Web Security (V) what is a session? Why do I need a session?
- Check the useful photo lossless magnification software on Apple computer
- 【torch】|torch.nn.utils.clip_grad_norm_
- Installation de la Bibliothèque de processus PDK - csmc
- [imgui] unity MenuItem shortcut key
- How to use PHP string query function
猜你喜欢

Huawei equipment is configured with OSPF and BFD linkage

flutter 实现一个有加载动画的按钮(loadingButton)

26file filter anonymous inner class and lambda optimization

PDK工艺库安装-CSMC

Review of double pointer problems

03. Login of development blog project

29io stream, byte output stream continue write line feed

02. 开发博客项目之数据存储

Promise summary

注释、接续、转义等符号
随机推荐
MySQL advanced learning summary 9: create index, delete index, descending index, and hide index
Note the various data set acquisition methods of jvxetable
指針經典筆試題
Pytorch代码注意的细节,容易敲错的地方
网站进行服务器迁移前应做好哪些准备?
Yyds dry inventory SSH Remote Connection introduction
【经验】win11上安装visio
Solution of QT TCP packet sticking
How to get list length
Vulhub vulnerability recurrence 72_ uWSGI
The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
用StopWatch 统计代码耗时
Promotion hung up! The leader said it wasn't my poor skills
Nacos - TC Construction of High available seata (02)
Talking about the type and function of lens filter
How to download GB files from Google cloud hard disk
pix2pix:使用条件对抗网络的图像到图像转换
Promotion hung up! The leader said it wasn't my poor skills
Easy to understand IIC protocol explanation
巨杉数据库再次亮相金交会,共建数字经济新时代