当前位置:网站首页>关于pytorch反向传播的思考
关于pytorch反向传播的思考
2022-07-27 05:13:00 【Mr_health】
实现一个简单的全来凝结网络,完成的工作是:

然后损失函数就是经典的L2损失:

代码如下,这里我们将偏执设置为0:
class network(nn.Module):
def __init__(self):
super().__init__()
self.w = torch.Tensor([1.0])
self.w.requires_grad = True
def forward(self, x):
output = self.w * x
return output
def gradient(x, y, w):
return 2*x*(x*w - y)
if __name__ == '__main__':
model = network()
x_data = torch.Tensor([1.0,2.0,3.0])
y_data = torch.Tensor([2.0,4.0,6.0])
output = model(x_data)
loss = (output-y_data).pow(2).mean()
loss.backward()
grads = gradient(x_data, y_data, model.w).mean()#手动实现 反向传播
print(model.w.grad.data) #
print(grads)运行的结果是一致的,都为-8.3333
具体推导一下:

下面修改一下loss的形式,相当于每个样本的权重不一样,增加最后一个样本的权重,可以看到两个样本的权重为1/4,最后一个样本的权重为1/2。从这个角度来看,实际上上面按照标准的loss反向传播,实际就是每个样本等权重。
model = network()
x_data = torch.Tensor([1.0,2.0,3.0])
y_data = torch.Tensor([2.0,4.0,6.0])
output = model(x_data)
loss = (output-y_data).pow(2)
loss = (loss[0:2].mean() + loss[2])/2 #修改loss
loss.backward()
# grads = gradient(x_data, y_data, model.w).mean()#手动实现 反向传播
print(model.w.grad.data) #
计算结果是-11.5,下面手动计算下:

边栏推荐
- Gbase 8C - SQL reference 6 SQL syntax (4)
- vscode打造golang开发环境以及golang的debug单元测试
- 4.张量数据类型和创建Tensor
- How to not overwrite the target source data when dBSwitch data migrates data increments
- Day 7. Towards Preemptive Detection of Depression and Anxiety in Twitter
- Seektiger will launch STI fusion mining function to obtain Oka pass
- GBASE 8C——SQL参考6 sql语法(13)
- Only one looper may be created per thread
- Seektiger's okaleido has a big move. Will the STI of ecological pass break out?
- Andorid detects GPU rendering speed and over rendering
猜你喜欢

数字图像处理——第九章 形态学图像处理

Day 9. Graduate survey: A love–hurt relationship

15.GPU加速、minist测试实战和visdom可视化

Jenkins build image automatic deployment

Emoji Emoji for text emotion analysis -improving sentimental analysis accuracy with Emoji embedding

Graph node deployment

12.优化问题实战

jenkins构建镜像自动化部署

Move protocol launched a beta version, and you can "0" participate in p2e

【好文种草】根域名的知识 - 阮一峰的网络日志
随机推荐
Gbase 8C - SQL reference 6 SQL syntax (2)
Rk3288 board HDMI displays logo images of uboot and kernel
数字图像处理 第二章 数字图像基础
Deploy redis with docker for high availability master-slave replication
4.张量数据类型和创建Tensor
Day 9. Graduate survey: A love–hurt relationship
19.上下采样与BatchNorm
Count the quantity in parallel after MySQL grouping
ES对比两个索引的数据差
Day 7. Towards Preemptive Detection of Depression and Anxiety in Twitter
Emoji表情符号用于文本情感分析-Improving sentiment analysis accuracy with emoji embedding
数字图像处理——第九章 形态学图像处理
Gbase 8C - SQL reference 6 SQL syntax (9)
5.索引和切片
15.GPU加速、minist测试实战和visdom可视化
GBASE 8C——SQL参考6 sql语法(12)
2021中大厂php+go面试题(2)
Gbase 8C - SQL reference 6 SQL syntax (5)
GBASE 8C——SQL参考6 sql语法(13)
Uboot supports LCD and HDMI to display different logo images