当前位置:网站首页>1.线性回归
1.线性回归
2022-07-07 23:11:00 【booze-J】
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential 按顺序构成的模型
from keras.models import Sequential
# Dense 全连接层
from keras.layers import Dense
2.随机生成数据集
# 使用numpy生成100个随机点
x_data = np.random.rand(100)
# 噪音的形状和x_data的形状是一样的
noise = np.random.normal(0,0.01,x_data.shape)
# 设置w=0.1 b=0.2
y_data = x_data*0.1+0.2+noise
# y_data_no_noisy = x_data*0.1+0.2
# 显示随机点
plt.scatter(x_data,y_data)
# plt.scatter(x_data,y_data_no_noisy)
运行效果:
这是添加噪声的情况下y_data = x_data*0.1+0.2+noise
:
不添加噪声的情况下y_data_no_noisy = x_data*0.1+0.2
(w=0.1,b=0.2):
线性回归就是要根据添加噪声的散点图,拟合出一条与不添加噪声的散点图近似的直线。
3.线性回归
# 构建一个顺序模型
model = Sequential()
# 在模型中添加一个全连接层 在jupyter-notebook中,按shift+tab可以显示参数
model.add(Dense(units=1,input_dim=1))
# sgd:Stochastic gradient descent , 随机梯度下降法
# mse:Mean Squared Error , 均方误差
model.compile(optimizer='sgd',loss='mse')
# 训练3001个批次
for step in range(3001):
# 每次训练一个批次 的损失
cost = model.train_on_batch(x_data,y_data)
# 每500个batch打印一次cost
if step%500==0:
print("cost:",cost)
# 打印权值和批次值
W,b = model.layers[0].get_weights()
print("W:",W)
print("b:",b)
# x_data输入网络中得到预测值
y_pred = model.predict(x_data)
# 显示随机点
plt.scatter(x_data,y_data)
# 显示预测结果
plt.plot(x_data,y_pred,"r-",lw=3)
plt.show()
运行效果:
可以看到预测出来的w和b都十分接近我们设置的w和b。
注意
- 在jupyter-notebook中,按shift+tab可以显示参数
- train_on_batch的使用
- compile的使用
边栏推荐
- Jemter distributed
- 韦东山第三期课程内容概要
- 8道经典C语言指针笔试题解析
- 5G NR 系统消息
- After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
- How does the markdown editor of CSDN input mathematical formulas--- Latex syntax summary
- 13. Model saving and loading
- Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
- 新库上线 | CnOpenData中华老字号企业名录
- Service Mesh的基本模式
猜你喜欢
Jouer sonar
5g NR system messages
8道经典C语言指针笔试题解析
letcode43:字符串相乘
3 years of experience, can't you get 20K for the interview and test post? Such a hole?
国外众测之密码找回漏洞
How does the markdown editor of CSDN input mathematical formulas--- Latex syntax summary
New library online | cnopendata China Star Hotel data
12. RNN is applied to handwritten digit recognition
What has happened from server to cloud hosting?
随机推荐
Class head up rate detection based on face recognition
4.交叉熵
Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
German prime minister says Ukraine will not receive "NATO style" security guarantee
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
手机上炒股安全么?
22年秋招心得
5.过拟合,dropout,正则化
图像数据预处理
Redis, do you understand the list
5g NR system messages
新库上线 | CnOpenData中华老字号企业名录
How is it most convenient to open an account for stock speculation? Is it safe to open an account on your mobile phone
Malware detection method based on convolutional neural network
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
串口接收一包数据
Marubeni official website applet configuration tutorial is coming (with detailed steps)
取消select的默认样式的向下箭头和设置select默认字样
牛客基础语法必刷100题之基本类型
51与蓝牙模块通讯,51驱动蓝牙APP点灯