当前位置:网站首页>NLP自然语言处理(二)
NLP自然语言处理(二)
2022-07-30 02:56:00 【敷衍zgf】
NLP自然语言处理(二)
一、pytorch反向传播和梯度计算方法
pytorch完成线性回归
- tensor中的require_grad参数
a.设置为True,表示会记录该tensor的计算过程,追踪对于该张量的所有操作,每一次计算都会修改其gard_fn属性,用来记录做过的操作。 - tensor中的grad_fn属性
a.用来保存计算的过程 - tensor不保留计算过程
a. with torch.no_grad()
为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():中,在评估模型中可使用,模型具有requires_grad = True的可训练参数,但是我们不需要在此过程对其进行梯度计算。 - 反向传播:
a. out.backward() 梯度计算,保存到x.gard中
b.导数保存在tensor.grad,默认梯度会累加 - tensor.data
a.获取tensor中值的引用操作(只有值) - tensor.numpy ()
a.当tensor中需要计算梯度的时候,grad_fn不为None的时候,
tensor.data.numpy()、tensor.detach().numpy()能够实现对tensor中的数据的深拷贝,转化为ndarray类型
二、线性回归的实现
基础模型是y = wx+b 其中w和b均为参数,使用 y = 3x + 0.8 来构造数据x,y 最后通过模型应该能得出w和b 的值接近3和0.8
import torch
import matplotlib.pyplot as plt
from numpy import *
learning_rate = 0.01
# 1.准备数据
# y = 3x + 0.8 只有一个x是一维
# 基础模型是y = wx+b 其中w b均为参数,使用 y = 3x + 0.8 来构造数据x,y 最后通过模型应该能得出w b 的值接近3 0.8
# 构造一个500行 1列的数据 rand() 0-1
x = torch.rand([500,1])
y_true = x*3 + 0.8 # x 与 y 都是 500行 1列
# 2.通过模型计算y_predict
# requires_grad=True表示会记录该tensor的计算过程,默认是False
w = torch.rand([1,1],requires_grad=True) # [1,1]是因为x[500,1]与[1,1]相乘是[500,1]
b = torch.tensor(0,requires_grad=True,dtype=torch.float32) # b全为0
# 4.通过循环,反向传播,更新参数
for i in range(2000):
# 3.计算损失值
y_predict = torch.matmul(x, w) + b # matmul()矩阵乘法
loss = (y_true - y_predict).pow(2).mean()
# 先判断w是不是一个数 即不为None
if w.grad is not None :
# 是一个数
w.data.zero_() # 将w _ 就地修改为0 归零操作 每次反向传播前将梯度置0
if b.grad is not None:
b.data.zero_()
loss.backward() # 反向传播
w.data = w.data - learning_rate * w.grad
b.data = b.data - learning_rate * b.grad
if i % 50 == 0 :
print("w,b,loss",w.item(),b.item(),loss.item()) # 通过item获得w和b的值
# 设置大小
plt.figure(figsize=(20,8))
# 散点图
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
# 直线
y_predict = torch.matmul(x, w) + b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c = 'r')
plt.show()


边栏推荐
- JUC (7): Thread Safety Analysis of Variables
- nrm ls 为什么前面不带 *了
- 【ModelArts系列】华为ModelArts Notebook训练yolov3模型(开发环境)
- 一文读懂Elephant Swap,为何为ePLATO带来如此高的溢价?
- 1050 graphics card, why is the graphics card usage ranking on Steam always the top five
- 雪花是否一样问题
- One book 1922 - table tennis
- JS Bom location 楼层导航效果 offsetTop data-n 方括号选择器
- 计算机复试面试题总结
- 运营人必须掌握的6大类26个基本模型
猜你喜欢

新手入门上位机开发 C#语言:PC串口发送数据
CSDN外链解决方法 (2022-07-28测试可用)

Dell's first pure soft product redefines next-generation object storage

【ModelArts系列】华为ModelArts Notebook训练yolov3模型(开发环境)

计算机复试面试题总结

一本通1922——乒乓球

Awesome, Tencent technical experts handed Redis technical notes, and the download volume has exceeded 30W
Linux Jenkins查找缓存文件及删除 (2022-07测试可用)

票房破7.9亿美元,最近这部恐龙爽片你看了吗?

杜教筛【莫比乌斯前缀和,欧拉函数前缀和】推导与模板【一千五百字】
随机推荐
Detailed explanation of carousel picture 2 - carousel pictures through left positioning
Mysql中事务是什么?有什么用?
影响小程序开发费用的三个因素!
Successfully resolved AttributeError: 'PngImageFile' object has no attribute 'imshow'
selenium应用之拉勾简历邀约数据抓取与分析
使用SqlSessionFactory工具类抽取
【机器学习】通俗易懂决策树(原理篇)
信息系统项目管理师核心考点(五十四)配置项分类、状态与版本
复星医药募资44.84亿:高毅资产认购20亿 成第三大股东
机器学习1一回归模型(一)
(RCE)远程代码/命令执行漏洞漏洞练习
【ModelArts系列】华为ModelArts Notebook训练yolov3模型(开发环境)
golang的channel实现原理
重写并自定义依赖的原生的Bean方法
答对这3个面试问题,薪资直涨20K
软件测试面试题及答案解析,2022最强版
One book 1922 - table tennis
【服务器存储数据恢复】华为OceanStor某型号存储raid5数据恢复案例
超详细的MySQL三万字总结
A transaction is in Mysql?What's the use?