当前位置:网站首页>【线性神经网络】线性回归 / 基础优化方法
【线性神经网络】线性回归 / 基础优化方法
2022-07-30 05:38:00 【PBemmm】
有点猪脑过载了,,,
线性回归
好像是唯一有显示解的模型,推导没怎么看懂,先空着
线性模型
对于n维输入,n维权重,和一个标准偏差b

损失函数(平方损失)
用来评估实际值与预测值损失

对于n个样本(别忘除n)

随机梯度下降
梯度下降长这样,学习率是步长的超参数

对于每次迭代,都需要对整个样本重新求梯度,代价过大,所以一般选择随机对样本进行取样再求梯度,叫做随机梯度下降。
我们每次采集b个样本,这里的批量大小b是另一个重要的超参数
算法实现
造轮子实现法
生成数据
def synthetic_data(w, b, num_examples): #@save
"""生成y=Xw+b+噪声"""
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)数据分批量
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]定义模型
def linreg(X, w, b): #@save
"""线性回归模型"""
return torch.matmul(X, w) + b定义损失函数
def squared_loss(y_hat, y): #@save
"""均方损失"""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2定义优化算法
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()训练
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
for X, y in data_iter(batch_size, features, labels):
l = loss(net(X, w, b), y) # X和y的小批量损失
# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
# 并以此计算关于[w,b]的梯度
l.sum().backward()
sgd([w, b], lr, batch_size) # 使用参数的梯度更新参数
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')调包实现
生成数据集 / 读取数据集
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)def load_array(data_arrays, batch_size, is_train=True): #@save
"""构造一个PyTorch数据迭代器"""
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)定义模型
from torch import nn
net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)Linear(2,1)指输入特征形状为2,输出为一个标量,形状1
net[0]访问第一层,weight.data和bias.data可以访问参数
损失函数
loss = nn.MSELoss()这个就是现成的均方损失,或平方L2范数
优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)SGD:梯度下降法
optim:实现各种优化算法的包
lr:学习率
训练
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X) ,y)
trainer.zero_grad()
l.backward()
trainer.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')net(X)预测值,输出的是标量预测y,net里面存的是参数权重w和噪音b,输出是1维标量,其实对应的就是前面手写版里的向量内积和b求预测值,y = <X, w> + b
trainer.step()迭代,更新参数,Wt = Wt-1 - lr * 梯度
data_iter是随机批量的数据,后面输出的l = loss(net(feartures), labels)是将更新来的参数代入整个数据中求loss
结果:,,,,挺抽象的

边栏推荐
- 安装pytorch
- PyCharm usage tutorial (more detailed, picture + text)
- JVM 内存结构 超详细学习笔记(一)
- [Mysql] CONVERT函数
- 【Pytorch】torch.manual_seed()、torch.cuda.manual_seed() 解释
- Error: listen EADDRINUSE: address already in use 127.0.0.1:3000
- It's time to have to learn English, give yourself multiple paths
- 分布式事务之 LCN框架的原理和使用(二)
- MySQL (2)
- net start mysql MySQL 服务正在启动 . MySQL 服务无法启动。 服务没有报告任何错误。
猜你喜欢

MySQL stored procedure

从驱动表和被驱动表来快速理解MySQL中的内连接和外连接

MySQL (2)

It is enough for MySQL to have this article (37k words, just like Bojun!!!)

Error: npm ERR code EPERM

cmd(命令行)操作或连接mysql数据库,以及创建数据库与表

cnpm installation steps

ClickHouse data insert, update and delete operations SQL

分布式事务之 Seata框架的原理和实战使用(三)

MySQL Soul 16 Questions, how many questions can you last?
随机推荐
cnpm installation steps
坠落的蚂蚁(北京大学考研机试题)
4461. Range Partition (Google Kickstart2022 Round C Problem B)
mysql 中 in 的用法
Error: listen EADDRINUSE: address already in use 127.0.0.1:3000
idea设置自动带参数的方法注释(有效)
How is crawler data collected and organized?
Introduction to Oracle Patch System and Opatch Tool
4、nerf(pytorch)
The difference between asyncawait and promise
MySQL 用户授权
Navicat cannot connect to mysql super detailed processing method
Pytorch to(device)
G Bus Count (Google Kickstart2014 Round D Problem B) (DAY 89)
cmd(命令行)操作或连接mysql数据库,以及创建数据库与表
破纪录者(Google Kickstart2020 Round D Problem A)
[Other] DS5
2022 SQL big factory high-frequency practical interview questions (detailed analysis)
It's time to have to learn English, give yourself multiple paths
navicat无法连接mysql超详细处理方法