当前位置:网站首页>2.非线性回归
2.非线性回归
2022-07-07 23:11:00 【booze-J】
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential 按顺序构成的模型
from keras.models import Sequential
# Dense 全连接层
from keras.layers import Dense,Activation
from tensorflow.keras.optimizers import SGD
2.随机生成数据集
# 使用numpy生成200个随机点
# 在-0.5~0.5生成200个点
x_data = np.linspace(-0.5,0.5,200)
noise = np.random.normal(0,0.02,x_data.shape)
# y = x^2 + noise
y_data = np.square(x_data) + noise
# 显示随机点
plt.scatter(x_data,y_data)
plt.show()
运行结果:
3.非线性回归
# 构建一个顺序模型
model = Sequential()
# 按shift+tab可以显示参数
# 1-10-1 输入一个神经层,10个隐藏层,输出一个神经层
model.add(Dense(units=10,input_dim=1))
# 添加激活函数 激活函数默认情况下是线性的,但是我们是非线性回归,所以要对激活函数进行修改
# 添加激活函数 方式一:直接添加activation参数
# model.add(Dense(units=10,input_dim=1,activation="relu"))
# 添加激活函数 方式二:直接添加Activation激活层
model.add(Activation("tanh"))
model.add(Dense(units=1,input_dim=10))
# model.add(Dense(units=1,input_dim=10,activation="relu"))
model.add(Activation("tanh"))
# 定义优化算法 提高学习率可以降低迭代次数
sgd = SGD(lr=0.3)
# sgd:Stochastic gradient descent , 随机梯度下降法 默认的sgd学习率表较小 所以需要的迭代次数就比较多 消耗的时间也就更多
# mse:Mean Squared Error , 均方误差
model.compile(optimizer=sgd,loss='mse')
# 训练3001个批次
for step in range(3001):
# 每次训练一个批次
cost = model.train_on_batch(x_data,y_data)
# 每500个batch打印一次cost
if step%500==0:
print("cost:",cost)
# 打印权值和批次值
W,b = model.layers[0].get_weights()
print("W:",W)
print("b:",b)
# x_data输入网络中得到预测值
y_pred = model.predict(x_data)
# 显示随机点
plt.scatter(x_data,y_data)
# 显示预测结果
plt.plot(x_data,y_pred,"r-",lw=3)
plt.show()
运行结果:
注意
- 1.进行非线性回归的要注意修改激活函数,因为激活函数默认情况下是线性的。
- 2.默认的sgd学习率表较小 所以需要的迭代次数就比较多 消耗的时间也就更多,所以可以自定义sgd学习率进行学习。
- 3.整体上代码和线性回归的代码非常类似,只是修改了一下数据集,网络模型,激活函数,以及优化器。
边栏推荐
- v-for遍历元素样式失效
- 赞!idea 如何单窗口打开多个项目?
- ABAP ALV LVC template
- 【笔记】常见组合滤波电路
- Kubernetes Static Pod (静态Pod)
- 华泰证券官方网站开户安全吗?
- 韦东山第三期课程内容概要
- Jouer sonar
- Lecture 1: the entry node of the link in the linked list
- Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
猜你喜欢
取消select的默认样式的向下箭头和设置select默认字样
What does interface testing test?
RPA cloud computer, let RPA out of the box with unlimited computing power?
An error is reported during the process of setting up ADG. Rman-03009 ora-03113
基于微信小程序开发的我最在行的小游戏
Binder core API
4.交叉熵
What if the testing process is not perfect and the development is not active?
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
随机推荐
Redis, do you understand the list
[OBS] the official configuration is use_ GPU_ Priority effect is true
5g NR system messages
Jouer sonar
51与蓝牙模块通讯,51驱动蓝牙APP点灯
Summary of weidongshan phase II course content
German prime minister says Ukraine will not receive "NATO style" security guarantee
Marubeni official website applet configuration tutorial is coming (with detailed steps)
取消select的默认样式的向下箭头和设置select默认字样
玩轉Sonar
5G NR 系统消息
How is it most convenient to open an account for stock speculation? Is it safe to open an account on your mobile phone
Four stages of sand table deduction in attack and defense drill
CVE-2022-28346:Django SQL注入漏洞
2022-07-07: the original array is a monotonic array with numbers greater than 0 and less than or equal to K. there may be equal numbers in it, and the overall trend is increasing. However, the number
深潜Kotlin协程(二十二):Flow的处理
3 years of experience, can't you get 20K for the interview and test post? Such a hole?
Langchao Yunxi distributed database tracing (II) -- source code analysis
Reentrantlock fair lock source code Chapter 0
Malware detection method based on convolutional neural network