当前位置:网站首页>8.优化器
8.优化器
2022-07-07 23:11:00 【booze-J】
文章
一、优化器
常见的一些优化器有:SGD、Adagrad、Adadelta、RMSprop、Adam、Adamax、Nadam、TFOptimizer等等。
1.SGD(Stochastic gradient descent)
标准梯度下降法:
标准梯度下降法计算所有样本汇总误差,然后根据总误差来更新权值。
随机梯度下降法:
随机梯度下降法随机抽取一个样本来计算误差,然后更新权值。
批量梯度下降法:
批量梯度下降算是一种折中的方案,从总样本中选取一个批次(比如一共有10000个样本,随机选取100个样本作为一个batch),然后计算这个batch的总误差,根据总误差来更新权值。
标准梯度下降法:速度慢,效果好
随机梯度下降法:速度快,效果差
2.Momentum
γ \gamma γ动力,通常设置为0.9。
当前权值的改变会受到上一次权值改变的影响,类似于小球向下滚动的时候带上了惯性。这样可以加快小球的向下的速度。
3.NAG(Nesterov accelerated gradient)
γ \gamma γ动力,通常设置为0.9。
4.Adagrad
ε:避免分母为0,取值一般是1e-8。
Adagrad主要的优势在于不需要人为的调节学习率,它可以自动调节。它的缺点在于,随着迭代次数的增多,学习率也会越来越低,最终会趋向于0。
5.RMSprop
γ \gamma γ动力,通常设置为0.9。
RMSprop是Adagrad的改进,RMSprop不会出现学习率越来越低的问题,而且也能自己调节学习率,可以得到一个比较好的效果。
6.Adadelta
γ \gamma γ动力,通常设置为0.9。
Adadelta也是Adagrad的改进,Adadelta不需要使用学习率也可以达到一个很好的效果。
7.Adam
β1:通常取0.9,β2:通常取0.999。
Adam是常用的一种优化器。Adam会存储之前衰减的平方梯度,同时它也会保存之前衰减的梯度。经过一些处理之后再用来更新权值W。
效果对比:
二、优化器的简单使用
以使用Adam优化器为例:
修改4.交叉熵中的
# 定义优化器
sgd = SGD(lr=0.2)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=sgd,
loss="categorical_crossentropy",
metrics=['accuracy']
)
变化为
# 定义优化器
sgd = SGD(lr=0.2)
adam = Adam(lr=0.001)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=adam,
loss="categorical_crossentropy",
metrics=['accuracy']
)
使用前需要先导入from tensorflow.keras.optimizers import SGD,Adam
。
运行结果:
完整代码
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD,Adam
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([
# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)
adam = Adam(lr=0.001)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=adam,
loss="categorical_crossentropy",
metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
边栏推荐
- New library launched | cnopendata China Time-honored enterprise directory
- 华为交换机S5735S-L24T4S-QA2无法telnet远程访问
- Service mesh introduction, istio overview
- Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
- Basic mode of service mesh
- Basic types of 100 questions for basic grammar of Niuke
- From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
- Deep dive kotlin synergy (XXII): flow treatment
- 詹姆斯·格雷克《信息简史》读后感记录
- Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
猜你喜欢
他们齐聚 2022 ECUG Con,只为「中国技术力量」
Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
QT establish signal slots between different classes and transfer parameters
letcode43:字符串相乘
【笔记】常见组合滤波电路
搭建ADG过程中复制报错 RMAN-03009 ORA-03113
华为交换机S5735S-L24T4S-QA2无法telnet远程访问
The standby database has been delayed. Check that the MRP is wait_ for_ Log, apply after restarting MRP_ Log but wait again later_ for_ log
Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
随机推荐
什么是负载均衡?DNS如何实现负载均衡?
The weight of the product page of the second level classification is low. What if it is not included?
fabulous! How does idea open multiple projects in a single window?
基于人脸识别实现课堂抬头率检测
Fofa attack and defense challenge record
【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
DNS 系列(一):为什么更新了 DNS 记录不生效?
玩轉Sonar
Installation and configuration of sublime Text3
动态库基本原理和使用方法,-fPIC 选项的来龙去脉
新库上线 | CnOpenData中国星级酒店数据
接口测试进阶接口脚本使用—apipost(预/后执行脚本)
备库一直有延迟,查看mrp为wait_for_log,重启mrp后为apply_log但过一会又wait_for_log
tourist的NTT模板
New library online | cnopendata China Star Hotel data
An error is reported during the process of setting up ADG. Rman-03009 ora-03113
Su embedded training - Day3
QT adds resource files, adds icons for qaction, establishes signal slot functions, and implements
手写一个模拟的ReentrantLock
Handwriting a simulated reentrantlock