当前位置:网站首页>【深度学习】:《PyTorch入门到项目实战》:简洁代码实现线性神经网络(附代码)
【深度学习】:《PyTorch入门到项目实战》:简洁代码实现线性神经网络(附代码)
2022-07-28 16:02:00 【JoJo的数据分析历险记】
【深度学习】:《PyTorch入门到项目实战》第三天:简洁代码实现线性神经网络(附代码)
- 本文收录于【深度学习】:《PyTorch入门到项目实战》专栏,此专栏主要记录如何使用
PyTorch实现深度学习笔记,尽量坚持每周持续更新,欢迎大家订阅! - 个人主页:JoJo的数据分析历险记
- 个人介绍:小编大四统计在读,目前保研到统计学top3高校继续攻读统计研究生
- 如果文章对你有帮助,欢迎
关注、点赞、收藏、订阅专栏
参考资料:本专栏主要以沐神《动手学深度学习》为学习资料,记录自己的学习笔记,能力有限,如有错误,欢迎大家指正。同时沐神上传了的教学视频和教材,大家可以前往学习。

文章目录
在上一节我们学习了如何使用pytorch从零实现一个线性回归模型。包括生成数据集,构建损失函数,梯度下降优化求解参数等。和很多其他机器学习框架一样,pytorch中也包含了许多可以自动实现机器学习的包。本章介绍一些如何使用nn简便的实现一个线性回归模型
1.生成数据集
import numpy as np
import torch
from torch.utils import data
接下来我们定义真实的w和b,以及生成模拟数据集,这一步和之前讲的一样
def synthetic_data(w, b, num_examples):
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
2.读取数据集
下面有一点不同的是我们通过data.TensorDataset把数据放在Tensor结构中,如何通过batch_size来设置每次抽取几个样本
def load_array(data_arrays,batch_size,is_train=True):
# TensorDataset:将数据对应成Tensor列表
data_set = data.TensorDataset(*data_arrays)
return data.DataLoader(data_set,batch_size,shuffle=is_train)#将数据加载读取出来,batch_size定义个数
batch_size = 10
data_iter = load_array((features,labels),batch_size)
next(iter(data_iter))
[tensor([[-0.0384, 1.1566],
[-0.9023, -0.6922],
[-0.0652, 1.1757],
[-0.8569, -1.0172],
[ 1.3489, -0.6855],
[ 0.1463, 0.1577],
[ 0.1615, -2.1549],
[-0.0533, -0.3301],
[-0.9913, 0.2226],
[ 0.1432, -0.9537]]),
tensor([[ 0.1836],
[ 4.7540],
[ 0.0802],
[ 5.9541],
[ 9.2256],
[ 3.9620],
[11.8700],
[ 5.2242],
[ 1.4718],
[ 7.7181]])]
3.线性模型搭建
这里我们通过nn搭建一个线性的神经网络
from torch import nn
这里我们使用Sequential类来接收线性层,这里我们只有一个线性神经网络层,其实可以不设置,但是在后续我们介绍的其他算法中,往往都是多层的,因此我们可以把这个当做一个标准化流程。
它的作用是将不同层串在一起,首先将数据传入到第一层,然后将第一层的输出传入到第二层作为输入,以此类推
net = nn.Sequential(nn.Linear(2,1))#第一个参数表示输入特征的纬度,第二个参数表示输出层的纬度
4.初始化参数
在定义net之后,我们需要做的就是定义我们要估计的参数。还是和之前类似,这里我们也需要定义两个参数一个是weight相当于之前的w,一个是bias相当于之前的b
net[0].weight.data.normal_(0, 0.01)#第一层的weight初始化
net[0].bias.data.fill_(0)#第一层的bias初始化
tensor([0.])
5. 定义损失函数
这里我们使用均方误差MSE来作为我们的损失函数
loss = nn.MSELoss()
6. 选择优化方法
这里我们使用随机梯度下降进行优化。从而得到我们的参数.需要传入两个参数,一个是待估计参数,另一个是学习率。我在这里设置为0.03
trainer = torch.optim.SGD(net.parameters(), lr=0.03)
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X) ,y)
trainer.zero_grad()
l.backward()
trainer.step()#更新参数
l = loss(net(features), labels)
print(f'epoch {
epoch + 1}, loss {
l:f}')
epoch 1, loss 0.000102
epoch 2, loss 0.000103
epoch 3, loss 0.000104
7.完整代码
# 导入相关库
import numpy as np
import torch
from torch.utils import data
from torch import nn
''' 定义模拟数据集函数 '''
def synthetic_data(w, b, num_examples):
X = torch.normal(0, 1, (num_examples, len(w)))#生成标准正态分布
y = torch.matmul(X, w) + b#计算y
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
''' 生成数据集 '''
true_w = torch.tensor([2, -3.4])#定义w
true_b = 4.2#定义b
features, labels = synthetic_data(true_w, true_b, 1000)#生成模拟数据集
''' 加载数据集 '''
def load_array(data_arrays,batch_size,is_train=True):
# TensorDataset:将数据对应成Tensor列表
data_set = data.TensorDataset(*data_arrays)
return data.DataLoader(data_set,batch_size,shuffle=is_train)#将数据加载读取出来,batch_size定义个数
batch_size = 10
data_iter = load_array((features,labels),batch_size)
''' 创建线性神经网络 '''
net = nn.Sequential(nn.Linear(2,1))#第一个参数表示输入特征的纬度,第二个参数表示输出层的纬度
net[0].weight.data.normal_(0, 0.01)#第一层的weight初始化
net[0].bias.data.fill_(0)#第一层的bias初始化
''' 定义损失函数MSE '''
loss = nn.MSELoss()
''' 创建SGD优化方法 '''
trainer = torch.optim.SGD(net.parameters(), lr=0.03)#创建SGD优化训练器
''' 正式训练 '''
num_epochs = 3#迭代次数
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X) ,y)#计算损失函数
trainer.zero_grad()
l.backward()
trainer.step()#更新参数
l = loss(net(features), labels)#计算最终的loss
print(f'epoch {
epoch + 1}, loss {
l:f}')
epoch 1, loss 0.000240
epoch 2, loss 0.000099
epoch 3, loss 0.000100
推荐阅读
本章的介绍到此介绍,如果文章对你有帮助,请多多点赞、收藏、评论、关注支持!!
边栏推荐
- HM secondary development - data names and its use
- Interesting kotlin 0x08:what am I
- About mit6.828_ HW9_ Some problems of barriers xv6 homework9
- Interesting kotlin 0x09:extensions are resolved statically
- leetcode647. 回文子串
- Introduction and implementation of queue (detailed explanation)
- 阿里云-武林头条-建站小能手争霸赛
- 关于MIT6.828_HW9_barriers xv6 homework9的一些问题
- Call DLL file without source code
- Signal shielding and processing
猜你喜欢

WSL+Valgrind+Clion

Sort 3-select sort and merge sort (recursive implementation + non recursive implementation)

阿里大哥教你如何正确认识关于标准IO缓冲区的问题

Im im development optimization improves connection success rate, speed, etc

Kubeedge releases white paper on cloud native edge computing threat model and security protection technology

Leetcode daily practice - the number of digits in the offer 56 array of the sword finger

排序5-计数排序

IM即时通讯开发优化提升连接成功率、速度等

LeetCode每日一练 —— 160. 相交链表

排序2-冒泡排序与快速排序(递归加非递归讲解)
随机推荐
About the web docking pin printer, lodop uses
How to set ticdc synchronization data to only synchronize the specified library?
大学生参加六星教育PHP培训,找到了薪水远超同龄人的工作
PHP image upload
阿里云-武林头条-建站小能手争霸赛
【从零开始学习SLAM】将坐标系变换关系发布到 topic tf
日常开发方案设计指北
"Wei Lai Cup" 2022 Niuke summer multi school training camp 3 acfhj
Installation of QT learning
Interesting kotlin 0x09:extensions are resolved statically
asp.net大文件分块上传断点续传demo
El input limit can only input the specified number
LeetCode-学会对无序链表进行插入排序(详解)
ticdc同步数据怎么设置只同步指定的库?
Some suggestions on Oracle SQL tuning
TCP handshake, waving, time wait connection reset and other records
记录开发问题
UNP前六章 回射服务模型 解析
阿里大哥教你如何正确认识关于标准IO缓冲区的问题
"Wei Lai Cup" 2022 Niuke summer multi school training camp 3 a.ancestor lca+ violence count