当前位置:网站首页>【线性神经网络】线性回归 / 基础优化方法
【线性神经网络】线性回归 / 基础优化方法
2022-07-30 05:38:00 【PBemmm】
有点猪脑过载了,,,
线性回归
好像是唯一有显示解的模型,推导没怎么看懂,先空着
线性模型
对于n维输入,n维权重,和一个标准偏差b

损失函数(平方损失)
用来评估实际值与预测值损失

对于n个样本(别忘除n)

随机梯度下降
梯度下降长这样,学习率是步长的超参数

对于每次迭代,都需要对整个样本重新求梯度,代价过大,所以一般选择随机对样本进行取样再求梯度,叫做随机梯度下降。
我们每次采集b个样本,这里的批量大小b是另一个重要的超参数
算法实现
造轮子实现法
生成数据
def synthetic_data(w, b, num_examples): #@save
"""生成y=Xw+b+噪声"""
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)数据分批量
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]定义模型
def linreg(X, w, b): #@save
"""线性回归模型"""
return torch.matmul(X, w) + b定义损失函数
def squared_loss(y_hat, y): #@save
"""均方损失"""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2定义优化算法
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()训练
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
for X, y in data_iter(batch_size, features, labels):
l = loss(net(X, w, b), y) # X和y的小批量损失
# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
# 并以此计算关于[w,b]的梯度
l.sum().backward()
sgd([w, b], lr, batch_size) # 使用参数的梯度更新参数
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')调包实现
生成数据集 / 读取数据集
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)def load_array(data_arrays, batch_size, is_train=True): #@save
"""构造一个PyTorch数据迭代器"""
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)定义模型
from torch import nn
net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)Linear(2,1)指输入特征形状为2,输出为一个标量,形状1
net[0]访问第一层,weight.data和bias.data可以访问参数
损失函数
loss = nn.MSELoss()这个就是现成的均方损失,或平方L2范数
优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)SGD:梯度下降法
optim:实现各种优化算法的包
lr:学习率
训练
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X) ,y)
trainer.zero_grad()
l.backward()
trainer.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')net(X)预测值,输出的是标量预测y,net里面存的是参数权重w和噪音b,输出是1维标量,其实对应的就是前面手写版里的向量内积和b求预测值,y = <X, w> + b
trainer.step()迭代,更新参数,Wt = Wt-1 - lr * 梯度
data_iter是随机批量的数据,后面输出的l = loss(net(feartures), labels)是将更新来的参数代入整个数据中求loss
结果:,,,,挺抽象的

边栏推荐
- php数组实现根据某个键值将相同键值合并生成新二维数组的方法
- IDEA的database使用教程(使用mysql数据库)
- The use of Conluce, an online document management system
- MySQL 灵魂 16 问,你能撑到第几问?
- argparse —— 命令行选项、参数和子命令解析器
- [GLib] What is GType
- 4461. 范围分区(Google Kickstart2022 Round C Problem B)
- It is enough for MySQL to have this article (37k words, just like Bojun!!!)
- 如何使用FirewallD限制网络访问
- What is SOA (Service Oriented Architecture)?
猜你喜欢

CISP-PTE Zhenti Demonstration

optimizer.zero_grad()

Learn FPGA from the underlying structure (6) ---- Distributed RAM (DRAM, Distributed RAM)

JVM之GC 调优工具 Arthas 实战使用(二)

每日练习------输出一个整数的二进制数、八进制数、十六进制数。

Graphic mirror symmetry (schematic diagram)

Solve the problem that the local nacos is not configured but the localhost8848 connection exception always occurs

安装Nuxt.js时出现错误:TypeError:Cannot read property ‘eslint‘ of undefined

IDEA的database使用教程(使用mysql数据库)

Mysql8.+学习笔记
随机推荐
cnpm installation steps
Introduction to Oracle Patch System and Opatch Tool
Prime numbers (Tsinghua University computer test questions) (DAY 86)
MySQL user authorization
[Other] DS5
Thymeleaf简介
解决phpstudy无法启动MySQL服务
More fragrant open source projects than Ruoyi in 2022
MySQL Soul 16 Questions, how many questions can you last?
G Bus Count (Google Kickstart2014 Round D Problem B) (DAY 89)
爬虫数据是如何收集和整理的?
Oracle补丁体系及Opatch工具介绍
号称年薪30万占比最多的专业,你知道是啥嘛?
从驱动表和被驱动表来快速理解MySQL中的内连接和外连接
Pytorch to(device)
2022年比若依更香的开源项目
MySQL 数据库基础知识(系统化一篇入门)
CISP-PTE Zhenti Demonstration
微信小程序开发学习
CISP-PTE真题演示