当前位置:网站首页>第5讲 使用pytorch实现线性回归
第5讲 使用pytorch实现线性回归
2022-08-05 05:13:00 【长路漫漫 大佬为伴】
使用pytorch实现线性回归
第五讲随笔
广播机制
例如不同形状矩阵之间相加,则会进行广播,扩张到同样的形状再进行运算
广播前:
广播后:
下面也是采用了广播机制,y1,y2,y3并非一个向量,而是一个矩阵,因此w需要进行广播,再与x1,x2,x3进行数乘
init构造函数
init构造函数用来初始化对象
简述 init、new、call 方法
用Module构造的对象,会自动根据计算图,实现backward的过程
使用pytorch实现线性回归
需要注意一些问题
- 1.#Module 中实现了forward,因此下方需要重写forward函数覆盖掉Module中的forward,因此LinearModel必须重写forward
# Module 中实现了forward,因此下方需要重写forward函数覆盖掉Module中的forward
# Linear 也构造于 Module,因此也是可调用对象
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
#torch.nn.Linear实际上在构造一个对象,包含了权重和偏置,继承自Module
# (1,1)是指每一个输入样本x和每一个输出样本y的特征维度,这里数据集中的x和y的特征都是1维的
# 该线性层需要学习的参数是w和b 获取w/b的方式分别是~linear.weight/linear.bias
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
#self.linear(x) 在对象后面加()意味着实现了一个可调用对象
y_pred = self.linear(x)
return y_pred
- 如果将torch.nn.MSELoss的参数设置为size_average=False,在pycharm中会报错size_average and reduce args will be deprecated, please use reduction=‘sum’ ,可能是因为编译器的原因导致

# 构造损失函数和优化器MSE
# MSELoss也继承自 nn.Module
#criterion = torch.nn.MSELoss(size_average=False)中不能设置size_average=False,会出现以下报错
#UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # model.parameters()自动完成参数的初始化操作
- 模型训练的次数问题
如果需要减少训练集的损失,可以加大训练次数,即 for epoch in range(1000)
但是这种做法存在危险,因为训练集上的损失越来越小,测试集上的损失可能越来越大,产生过拟合问题
线性回归实现代码
import torch
# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
# Module 中实现了forward,因此下方需要重写forward函数覆盖掉Module中的forward
# Linear 也构造于 Module,因此也是可调用对象
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
#torch.nn.Linear实际上在构造一个对象,包含了权重和偏置,继承自Module
# (1,1)是指每一个输入样本x和每一个输出样本y的特征维度,这里数据集中的x和y的特征都是1维的
# 该线性层需要学习的参数是w和b 获取w/b的方式分别是~linear.weight/linear.bias
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
#self.linear(x) 在对象后面加()意味着实现了一个可调用对象
y_pred = self.linear(x)
return y_pred
#model是一个callable,即可调用的对象,可以model(x)
model = LinearModel()
# 构造损失函数和优化器MSE
# MSELoss也继承自 nn.Module
#criterion = torch.nn.MSELoss(size_average=False)中不能设置size_average=False,会出现以下报错
#UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # model.parameters()自动完成参数的初始化操作
# 如果需要减少训练集的损失,可以加大训练次数,即 for epoch in range(1000)
# 但是这种做法存在危险,因为训练集上的损失越来越小,测试集上的损失可能越来越大,产生过拟合问题
for epoch in range(100):
y_pred = model(x_data) # forward:predict
loss = criterion(y_pred, y_data) # forward: loss
print(epoch, loss.item())
optimizer.zero_grad()
loss.backward() # backward: autograd,自动计算梯度
optimizer.step() # update 参数,即更新w和b的值
print('w = ', model.linear.weight.item())#weight是一个矩阵,所以取值需要调用item()
print('b = ', model.linear.bias.item())
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
运行结果
0 35.37689208984375
1 15.92350959777832
2 7.260907173156738
3 3.402086019515991
4 1.6818101406097412
5 0.9135867357254028
6 0.569225549697876
。。。。。。。。。。。。
94 0.08291453868150711
95 0.08172288537025452
96 0.08054838329553604
97 0.07939067482948303
98 0.07824991643428802
99 0.07712534815073013
w = 1.8151198625564575
b = 0.4202759563922882
y_pred = tensor([[7.6808]])
边栏推荐
- 1.3 mysql batch insert data
- dedecms dream weaving tag tag does not support capital letters fix
- upload upload pictures to Tencent cloud, how to upload pictures
- Flutter学习三-Flutter基本结构和原理
- 2022 The 4th C.Easy Counting Problem (EGF+NTT)
- 服务器磁盘阵列
- Error creating bean with name ‘configDataContextRefresher‘ defined in class path resource
- Wise Force Deleter强制删除工具
- 入口点注入
- [Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)
猜你喜欢

基于Web的商城后台管理系统的设计与实现

OFDM Lecture 16 5 -Discrete Convolution, ISI and ICI on DMT/OFDM Systems
![coppercam入门手册[6]](/img/d3/a7d44aa19acfb18c5a8cacdc8176e9.png)
coppercam入门手册[6]

Difference between for..in and for..of

Qt制作18帧丘比特表白意中人、是你的丘比特嘛!!!

Flutter学习5-集成-打包-发布

【过一下14】自习室的一天

Structured light 3D reconstruction (1) Striped structured light 3D reconstruction

MySQL基础(一)---基础认知及操作

flex布局青蛙游戏通关攻略
随机推荐
Detailed explanation of each module of ansible
The role of DataContext in WPF
Mysql5.7 二进制 部署
[WeChat applet] WXML template syntax - conditional rendering
Algorithms - ones and zeros (Kotlin)
【无标题】
入口点注入
仪表板展示 | DataEase看中国:数据呈现中国资本市场
u-boot中的u-boot,dm-pre-reloc
C#关于set()和get()方法的理解及使用
Redis哨兵模式配置文件详解
[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)
Redis - 13. Development Specifications
ESP32 485 Illuminance
Flutter学习4-基本UI组件
phone call function
2023 International Conference on Information and Communication Engineering (JCICE 2023)
Error creating bean with name 'configDataContextRefresher' defined in class path resource
Wise Force Deleter强制删除工具
redis事务