当前位置:网站首页>1.线性回归
1.线性回归
2022-07-07 23:11:00 【booze-J】
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential 按顺序构成的模型
from keras.models import Sequential
# Dense 全连接层
from keras.layers import Dense
2.随机生成数据集
# 使用numpy生成100个随机点
x_data = np.random.rand(100)
# 噪音的形状和x_data的形状是一样的
noise = np.random.normal(0,0.01,x_data.shape)
# 设置w=0.1 b=0.2
y_data = x_data*0.1+0.2+noise
# y_data_no_noisy = x_data*0.1+0.2
# 显示随机点
plt.scatter(x_data,y_data)
# plt.scatter(x_data,y_data_no_noisy)
运行效果:
这是添加噪声的情况下y_data = x_data*0.1+0.2+noise
:
不添加噪声的情况下y_data_no_noisy = x_data*0.1+0.2
(w=0.1,b=0.2):
线性回归就是要根据添加噪声的散点图,拟合出一条与不添加噪声的散点图近似的直线。
3.线性回归
# 构建一个顺序模型
model = Sequential()
# 在模型中添加一个全连接层 在jupyter-notebook中,按shift+tab可以显示参数
model.add(Dense(units=1,input_dim=1))
# sgd:Stochastic gradient descent , 随机梯度下降法
# mse:Mean Squared Error , 均方误差
model.compile(optimizer='sgd',loss='mse')
# 训练3001个批次
for step in range(3001):
# 每次训练一个批次 的损失
cost = model.train_on_batch(x_data,y_data)
# 每500个batch打印一次cost
if step%500==0:
print("cost:",cost)
# 打印权值和批次值
W,b = model.layers[0].get_weights()
print("W:",W)
print("b:",b)
# x_data输入网络中得到预测值
y_pred = model.predict(x_data)
# 显示随机点
plt.scatter(x_data,y_data)
# 显示预测结果
plt.plot(x_data,y_pred,"r-",lw=3)
plt.show()
运行效果:
可以看到预测出来的w和b都十分接近我们设置的w和b。
注意
- 在jupyter-notebook中,按shift+tab可以显示参数
- train_on_batch的使用
- compile的使用
边栏推荐
- STL -- common function replication of string class
- 8道经典C语言指针笔试题解析
- Which securities company has a low, safe and reliable account opening commission
- Handwriting a simulated reentrantlock
- DNS series (I): why does the updated DNS record not take effect?
- Kubernetes Static Pod (静态Pod)
- ABAP ALV LVC template
- Play sonar
- Jemter distributed
- 5G NR 系统消息
猜你喜欢
51与蓝牙模块通讯,51驱动蓝牙APP点灯
语义分割模型库segmentation_models_pytorch的详细使用介绍
2022-07-07: the original array is a monotonic array with numbers greater than 0 and less than or equal to K. there may be equal numbers in it, and the overall trend is increasing. However, the number
From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
5G NR 系统消息
C # generics and performance comparison
英雄联盟胜负预测--简易肯德基上校
Deep dive kotlin synergy (XXII): flow treatment
深潜Kotlin协程(二十二):Flow的处理
4.交叉熵
随机推荐
[OBS] the official configuration is use_ GPU_ Priority effect is true
Malware detection method based on convolutional neural network
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
华泰证券官方网站开户安全吗?
基于人脸识别实现课堂抬头率检测
ReentrantLock 公平锁源码 第0篇
Leetcode brush questions
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
Cause analysis and solution of too laggy page of [test interview questions]
8.优化器
Su embedded training - day4
Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
ReentrantLock 公平锁源码 第0篇
Which securities company has a low, safe and reliable account opening commission
Summary of the third course of weidongshan
丸子官网小程序配置教程来了(附详细步骤)
Experience of autumn recruitment in 22 years
SDNU_ ACM_ ICPC_ 2022_ Summer_ Practice(1~2)
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)