当前位置:网站首页>Lecture 5 Using pytorch to implement linear regression
Lecture 5 Using pytorch to implement linear regression
2022-08-05 05:23:00 【A long way to go】
使用pytorch实现线性回归
The fifth lecture
广播机制
For example, adding between matrices of different shapes,will broadcast,Expand to the same shape and perform the operation
广播前:
广播后:
The broadcast mechanism is also used below,y1,y2,y3is not a vector,而是一个矩阵,因此w需要进行广播,再与x1,x2,x3进行数乘
init构造函数
initConstructors are used to initialize objects
简述 init、new、call 方法
用Module构造的对象,automatically according to the calculation graph,实现backward的过程
使用pytorch实现线性回归
需要注意一些问题
- 1.#Module 中实现了forward,Therefore the following needs to be rewrittenforward函数覆盖掉Module中的forward,因此LinearModel必须重写forward
# Module 中实现了forward,Therefore the following needs to be rewrittenforward函数覆盖掉Module中的forward
# Linear also constructed in Module,Hence also a callable object
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
#torch.nn.Linearactually constructing an object,Weights and biases are included,继承自Module
# (1,1)refers to each input samplexand every output sampley的特征维度,这里数据集中的x和y的特征都是1维的
# 该线性层需要学习的参数是w和b 获取w/b的方式分别是~linear.weight/linear.bias
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
#self.linear(x) add after the object()Means a callable object is implemented
y_pred = self.linear(x)
return y_pred
- 如果将torch.nn.MSELoss的参数设置为size_average=False,在pycharm中会报错size_average and reduce args will be deprecated, please use reduction=‘sum’ ,Maybe it's because of the compiler

# 构造损失函数和优化器MSE
# MSELoss也继承自 nn.Module
#criterion = torch.nn.MSELoss(size_average=False)中不能设置size_average=False,会出现以下报错
#UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # model.parameters()自动完成参数的初始化操作
- The number of times the model is trained
Reduce the loss on the training set if needed,You can increase the number of training sessions,即 for epoch in range(1000)
But there are dangers in doing so,Because the loss on the training set is getting smaller and smaller,The loss on the test set may get bigger and bigger,产生过拟合问题
Linear regression implementation code
import torch
# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
# Module 中实现了forward,Therefore the following needs to be rewrittenforward函数覆盖掉Module中的forward
# Linear also constructed in Module,Hence also a callable object
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
#torch.nn.Linearactually constructing an object,Weights and biases are included,继承自Module
# (1,1)refers to each input samplexand every output sampley的特征维度,这里数据集中的x和y的特征都是1维的
# 该线性层需要学习的参数是w和b 获取w/b的方式分别是~linear.weight/linear.bias
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
#self.linear(x) add after the object()Means a callable object is implemented
y_pred = self.linear(x)
return y_pred
#model是一个callable,即可调用的对象,可以model(x)
model = LinearModel()
# 构造损失函数和优化器MSE
# MSELoss也继承自 nn.Module
#criterion = torch.nn.MSELoss(size_average=False)中不能设置size_average=False,会出现以下报错
#UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # model.parameters()自动完成参数的初始化操作
# Reduce the loss on the training set if needed,You can increase the number of training sessions,即 for epoch in range(1000)
# But there are dangers in doing so,Because the loss on the training set is getting smaller and smaller,The loss on the test set may get bigger and bigger,产生过拟合问题
for epoch in range(100):
y_pred = model(x_data) # forward:predict
loss = criterion(y_pred, y_data) # forward: loss
print(epoch, loss.item())
optimizer.zero_grad()
loss.backward() # backward: autograd,自动计算梯度
optimizer.step() # update 参数,即更新w和b的值
print('w = ', model.linear.weight.item())#weight是一个矩阵,So the value needs to be calleditem()
print('b = ', model.linear.bias.item())
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
运行结果
0 35.37689208984375
1 15.92350959777832
2 7.260907173156738
3 3.402086019515991
4 1.6818101406097412
5 0.9135867357254028
6 0.569225549697876
............
94 0.08291453868150711
95 0.08172288537025452
96 0.08054838329553604
97 0.07939067482948303
98 0.07824991643428802
99 0.07712534815073013
w = 1.8151198625564575
b = 0.4202759563922882
y_pred = tensor([[7.6808]])
边栏推荐
- Flutter真机运行及模拟器运行
- Flutter learning 2-dart learning
- Flutter real machine running and simulator running
- 学习总结week3_1函数
- Database experiment five backup and recovery
- SQL(二) —— join窗口函数视图
- 第二讲 Linear Model 线性模型
- jvm three heap and stack
- 电话溥功能
- pycharm中调用Matlab配置:No module named ‘matlab.engine‘; ‘matlab‘ is not a package
猜你喜欢

MySQL Foundation (1) - Basic Cognition and Operation

The role of DataContext in WPF

【cesium】Load and locate 3D Tileset

DOM及其应用

结构光三维重建(一)条纹结构光三维重建

第三讲 Gradient Tutorial梯度下降与随机梯度下降

pycharm中调用Matlab配置:No module named ‘matlab.engine‘; ‘matlab‘ is not a package

shell函数
![[cesium] element highlighting](/img/99/504ca9802db83eb33bc6d91b34fa84.png)
[cesium] element highlighting
![[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)](/img/86/9c9a2541f2b7089ae47e9832fffdb3.png)
[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)
随机推荐
位运算符与逻辑运算符的区别
flex布局青蛙游戏通关攻略
【Transfer】What is etcd
vscode+pytorch使用经验记录(个人记录+不定时更新)
结构光三维重建(一)条纹结构光三维重建
coppercam入门手册[6]
Error creating bean with name 'configDataContextRefresher' defined in class path resource
jvm 三 之堆与栈
u-boot debugging and positioning means
【过一下16】回顾一下七月
Multi-threaded query results, add List collection
number_gets the specified number of decimals
Returned object not currently part of this pool
entry point injection
redis 缓存清除策略
Flutter learning - the beginning
Geek卸载工具
第二讲 Linear Model 线性模型
RL强化学习总结(一)
2022 The 4th C.Easy Counting Problem (EGF+NTT)