当前位置:网站首页>8.优化器
8.优化器
2022-07-07 23:11:00 【booze-J】
文章
一、优化器
常见的一些优化器有:SGD、Adagrad、Adadelta、RMSprop、Adam、Adamax、Nadam、TFOptimizer等等。
1.SGD(Stochastic gradient descent)
标准梯度下降法:
标准梯度下降法计算所有样本汇总误差,然后根据总误差来更新权值。
随机梯度下降法:
随机梯度下降法随机抽取一个样本来计算误差,然后更新权值。
批量梯度下降法:
批量梯度下降算是一种折中的方案,从总样本中选取一个批次(比如一共有10000个样本,随机选取100个样本作为一个batch),然后计算这个batch的总误差,根据总误差来更新权值。
标准梯度下降法:速度慢,效果好
随机梯度下降法:速度快,效果差

2.Momentum
γ \gamma γ动力,通常设置为0.9。
当前权值的改变会受到上一次权值改变的影响,类似于小球向下滚动的时候带上了惯性。这样可以加快小球的向下的速度。
3.NAG(Nesterov accelerated gradient)
γ \gamma γ动力,通常设置为0.9。


4.Adagrad

ε:避免分母为0,取值一般是1e-8。
Adagrad主要的优势在于不需要人为的调节学习率,它可以自动调节。它的缺点在于,随着迭代次数的增多,学习率也会越来越低,最终会趋向于0。
5.RMSprop

γ \gamma γ动力,通常设置为0.9。
RMSprop是Adagrad的改进,RMSprop不会出现学习率越来越低的问题,而且也能自己调节学习率,可以得到一个比较好的效果。
6.Adadelta

γ \gamma γ动力,通常设置为0.9。
Adadelta也是Adagrad的改进,Adadelta不需要使用学习率也可以达到一个很好的效果。
7.Adam

β1:通常取0.9,β2:通常取0.999。
Adam是常用的一种优化器。Adam会存储之前衰减的平方梯度,同时它也会保存之前衰减的梯度。经过一些处理之后再用来更新权值W。
效果对比:


二、优化器的简单使用
以使用Adam优化器为例:
修改4.交叉熵中的
# 定义优化器
sgd = SGD(lr=0.2)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=sgd,
loss="categorical_crossentropy",
metrics=['accuracy']
)
变化为
# 定义优化器
sgd = SGD(lr=0.2)
adam = Adam(lr=0.001)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=adam,
loss="categorical_crossentropy",
metrics=['accuracy']
)
使用前需要先导入from tensorflow.keras.optimizers import SGD,Adam。
运行结果:
完整代码
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD,Adam
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([
# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)
adam = Adam(lr=0.001)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=adam,
loss="categorical_crossentropy",
metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
边栏推荐
- Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch
- Malware detection method based on convolutional neural network
- Handwriting a simulated reentrantlock
- 哪个券商公司开户佣金低又安全,又靠谱
- 接口测试要测试什么?
- Hotel
- German prime minister says Ukraine will not receive "NATO style" security guarantee
- SDNU_ ACM_ ICPC_ 2022_ Summer_ Practice(1~2)
- 炒股开户怎么最方便,手机上开户安全吗
- tourist的NTT模板
猜你喜欢

New library online | cnopendata China Star Hotel data

【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础

【obs】官方是配置USE_GPU_PRIORITY 效果为TRUE的
![[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output](/img/79/f5cffe62d5d1e4a69b6143aef561d9.png)
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output

Service mesh introduction, istio overview
![Cause analysis and solution of too laggy page of [test interview questions]](/img/8d/3ca92ce5f9cdc85d52dbcd826e477d.jpg)
Cause analysis and solution of too laggy page of [test interview questions]

What does interface testing test?

3 years of experience, can't you get 20K for the interview and test post? Such a hole?

FOFA-攻防挑战记录

RPA cloud computer, let RPA out of the box with unlimited computing power?
随机推荐
Deep dive kotlin synergy (XXII): flow treatment
A brief history of information by James Gleick
NTT template for Tourism
22年秋招心得
AI遮天传 ML-初识决策树
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
jemter分布式
Hotel
Codeforces Round #804 (Div. 2)(A~D)
Summary of the third course of weidongshan
German prime minister says Ukraine will not receive "NATO style" security guarantee
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出
股票开户免费办理佣金最低的券商,手机上开户安全吗
Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
Basic types of 100 questions for basic grammar of Niuke
Codeforces Round #804 (Div. 2)(A~D)
What does interface testing test?
They gathered at the 2022 ecug con just for "China's technological power"
Fofa attack and defense challenge record
The method of server defense against DDoS, Hangzhou advanced anti DDoS IP section 103.219.39 x