当前位置:网站首页>第六章 网络学习相关技巧5(超参数验证)
第六章 网络学习相关技巧5(超参数验证)
2022-06-24 19:42:00 【追寻远方的人】
神经网络中,除了权重和偏置等参数,超参数(hyper-parameter)也经常出现。这里所说的超参数是指,比如各层的神经元数量、batch大小、参数更新时的学习率或权值衰减等。如果这些超参数没有设置合适的值,模型的性能就会很差。虽然超参数的取值非常重要,但是在决定超参数的过程中一般会伴随很多的试错。本节将介绍尽可能高效地寻找超参数的值的方法。
6.1验证数据
我们使用的数据集分成了训练数据和测试数据,训练数据用于学习,测试数据用于评估泛化能力。由此,就可以评估是否只过度拟合了训练数据(是否发生了过拟合),以及泛化能力如何等。
然而,我们也需要对超参数设置各种各样的值以进行验证。因此,调整超参数时,必须使用超参数专用的确认数据。用于调整超参
数的数据,一般称为验证数据(validation data)。我们使用这个验证数据来评估超参数的好坏。
【注】训练数据用于参数(权重和偏置)的学习,验证数据用于超参数的性能评估。测试数据是为了确认泛化能力,要在最后使用(比较理想的是只用一次)。
根据不同的数据集,有的会事先分成训练数据、验证数据、测试数据三部分,有的只分成训练数据和测试数据两部分,有的则不进行分割。在这种情况下,用户需要自行进行分割。如果是MNIST数据集,获得验证数据的最简单的方法就是从训练数据中事先分割20%作为验证数据。
代码实现如下:
# coding: utf-8
import os
import sys
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
from common.util import shuffle_dataset
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
# 打乱训练数据
x_train, t_train = shuffle_dataset(x_train, t_train)
# 分割验证数据
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate) # 验证数据集的数量
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]
6.2超参数的最优化
进行超参数的最优化时,逐渐缩小超参数的“好值”的存在范围非常重要。
逐渐缩小范围:是指一开始先大致设定一个范围,从这个范围中随机选出一个超参数(采样),用这个采样到的值进行识别精度的评估;然后,多次重复该操作,观察识别精度的结果,根据这个结果缩小超参数的“好值”的范围。通过重复这一操作,就可以逐渐确定超参数的合适范围。
【注】在进行神经网络的超参数的最优化时,与网格搜索等有规律的搜索相比,随机采样的搜索方式效果更好。这是因为在多个超参数中,各个超参数对最终的识别精度的影响程度不同。
超参数的范围只要“大致地指定”就可以了。所谓“大致地指定”,是指像0.001(10^ −3 )到1000(10^ 3 )这样,以“10的阶乘”的尺度指定范围(也表述为“用对数尺度(log scale)指定”)。在Python中可以写成 10 ** np.random.uniform(-3, 3) 。
在超参数的最优化中,要注意的是深度学习需要很长时间(比如,几天或几周)。因此,在超参数的搜索中,需要尽早放弃那些不符合逻辑的超参数。于是,在超参数的最优化中,减少学习的epoch,缩短一次评估所需的时间是一个不错的办法。
6.2.1优化步骤
1、设定超参数的范围。
2、从设定的超参数范围中随机采样。
3、使用步骤2中采样到的超参数的值进行学习,通过验证数据评估识别精度(但是要将epoch设置得很小)。
4、重复步骤2和步骤3(100次等),根据它们的识别精度的结果,缩小超参数的范围。
反复进行上述操作,不断缩小超参数的范围,在缩小到一定程度时,从该范围中选出一个超参数的值。这就是进行超参数的最优化的一种方法。
【注】在超参数的最优化中,如果需要更精炼的方法,可以使用贝叶斯最优化(Bayesian optimization)。
6.3实现
使用MNIST数据集进行超参数的最优化。这里我们将学习率和控制权值衰减强度的系数(下文称为“权值衰减系数”)这两个超参数的搜索问题作为对象。
通过从0.001(10 −3 )到1000(10 3 )这样的对数尺度的范围中随机采样进行超参数的验证。这在Python中可以写成 10 ** np.random.uniform(-3, 3) 。在该实验中,权值衰减系数的初始范围为10 −8 到10 −4 ,学习率的初始范围为10 ^−6 到10 ^−2 。此时,超参数的随机采样的代码如下所示。
weight_decay = 10 ** np.random.uniform(-8, -4)
lr = 10 ** np.random.uniform(-6, -2)
6.3.1案例
例:使用MNIST数据以权值衰减系数为10 −8 到10 −4 、学习率为10 −6 到10 −2 的范围进行实验。
文件目录如下:

funtions.py, gradient.py, layers.py, multi_layer_net.py, optimizer.py, util.py)见前面博文
trainer.py见该博文
6.3.2代码及结果
6.3.2.1结果
运行hyperparameter_optimization.py 结果如下:

【注】按识别精度从高到低的顺序排列了验证数据的学习的变化。从图中可知,直到“Best-5”左右,学习进行得都很顺利。

【注】从这个结果可以看出,学习率在0.001到0.01、权值衰减系数在10 −8 到10 −6 之间时,学习可以顺利进行。像这样,观察可以使学习顺利进行的超参数的范围,从而缩小值的范围。然后,在这个缩小的范围中重复相同的操作。这样就能缩小到合适的超参数的存在范围,然后在某个阶段,选择一个最终的超参数的值。
6.3.2.2代码实现
hyperparameter_optimization.py 代码实现如下:
# coding: utf-8
import sys, os
sys.path.append(os.pardir)
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# 削减学习数据
x_train = x_train[:500]
t_train = t_train[:500]
# 训练集与验证集分离
validation_rate = 0.20
validation_num = int(x_train.shape[0] * validation_rate)
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
t_val = t_train[:validation_num]
x_train = x_train[validation_num:]
t_train = t_train[validation_num:]
def __train(lr, weight_decay, epocs=50):
network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
output_size=10, weight_decay_lambda=weight_decay)
trainer = Trainer(network, x_train, t_train, x_val, t_val,
epochs=epocs, mini_batch_size=100,
optimizer='sgd', optimizer_param={
'lr': lr}, verbose=False)
trainer.train()
return trainer.test_acc_list, trainer.train_acc_list
# 超参数随机搜索
optimization_trial = 100
results_val = {
}
results_train = {
}
for _ in range(optimization_trial):
# 指定搜索超参数的范围===============
weight_decay = 10 ** np.random.uniform(-8, -4)
lr = 10 ** np.random.uniform(-6, -2)
# ================================================
val_acc_list, train_acc_list = __train(lr, weight_decay)
print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
results_val[key] = val_acc_list
results_train[key] = train_acc_list
# 绘制图表========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0
for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):
print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)
plt.subplot(row_num, col_num, i+1)
plt.title("Best-" + str(i+1))
plt.ylim(0.0, 1.0)
if i % 5: plt.yticks([])
plt.xticks([])
x = np.arange(len(val_acc_list))
plt.plot(x, val_acc_list)
plt.plot(x, results_train[key], "--")
i += 1
if i >= graph_draw_num:
break
plt.show()
边栏推荐
- Uncover the secrets of Huawei cloud enterprise redis issue 16: acid'true' transactions beyond open source redis
- 关于某手滑块的一些更新(6-18,js逆向)
- Learn about redlock
- laravel学习笔记
- 【nvm】
- Blogs personal blog test point (manual test)
- Laravel user authorization
- Laravel pagoda security configuration
- Epics record reference 2 -- epics process database concept
- 23研考生注意啦!备考期间最容易中招的骗局,居然是它们?!
猜你喜欢

Attention, postgraduate candidates! They are the easiest scams to get caught during the preparation period?!

Learn about redlock

Uncover the secrets of Huawei cloud enterprise redis issue 16: acid'true' transactions beyond open source redis

【js】-【栈、队-应用】-学习笔记

What kind of processor architecture is ARM architecture?

Tech Talk 活动回顾|云原生 DevOps 的 Kubernetes 技巧

01_ Getting started with the spingboot framework

Online group chat and dating platform test point
How should we measure agile R & D projects?

07_SpingBoot 实现 RESTful 风格
随机推荐
Research Report on market evaluation and investment direction of Chinese dermatology drugs (2022 Edition)
laravel model 注意事项
Push markdown format information to the nailing robot
laravel 定时任务
A big factory interview must ask: how to solve the problem of TCP reliable transmission? 8 pictures for you to learn in detail
canvas 实现图片新增水印
【js】-【数组、栈、队列、链表基础】-笔记
Servlet
How to submit the shopee opening and settlement flow?
Selection (026) - what is the output of the following code?
EMI的主要原因-工模电流
QT to place the form in the lower right corner of the desktop
F29oc analysis
Selection (025) - what is the output of the following code?
[Wuhan University] information sharing of the first and second postgraduate entrance examinations
. Net 7 Preview 1 has been officially released
Docker installation redis- simple without pit
Spark 离线开发框架设计与实现
Listen to the markdown file and hot update next JS page
Research Report on research and investment prospects of China's container coating industry (2022 Edition)