当前位置:网站首页>【深度学习】关于处理过拟合的一点心得
【深度学习】关于处理过拟合的一点心得
2022-08-02 14:37:00 【折途】
前言
现在的深度学习与传统的机器学习相比,最显著的特点就是一个“深”字,如今深度学习的网络层数就算有个成百上千层也并不奇怪。然而过于强大的神经网络会导致一个问题,那就是过拟合,神经网络可以精确地预测出提供的数据集的结果,可一旦传入从未见过的数据,则准确率低的离谱。
过拟合的一个明显的特征就是训练时,损失值(loss)极低,精度极高接近100%,并且训练集的精度与验证集的精度有着不小差距,那么该如何解决过拟合这一难题呢?
从数据集入手
最直接也是代价最高的办法就是增加数据集的数量,但数据集获取困难,不仅要去寻找图片,还要分类打标签,仅靠个人的努力,耗尽一天的时间也增加不了多少数据。
但可以增强数据,通过旋转、裁剪、添加噪声点……理论上可以获取无限多的数据,虽效果较差,但也是可以尝试的。TensorFlow和PyTorch可以参考我下面这篇博客
【TensorFlow&PyTorch】图像数据增强API_折途的博客-CSDN博客在进行深度学习训练时,遇到训练效果较差、训练集数量小、有过拟合趋向时可以选择加大数据集数量来优化训练模型,但是大多数情况下,增加数据集数量所花费的时间精力是巨大的,所以我们更常用的方法是对现有的数据集进行数据增强。不如实实在在增加数据集数量,但是还是有一定的效果的,性价比高。(只要加几行代码)TensorFlow的API在image下(我用的2.0版本,不同的版本可能API不同,但是基本都可以在iamge下找到)Modulehttps以下列举几个本人认为常用的方法....https://blog.csdn.net/m0_63235356/article/details/125972651?spm=1001.2014.3001.5501另外,小数据集(一共就几千张图片的)可以适当加大测试集在整个数据集的占比,一般数据集,6:1来分配训练集和测试集即可,小数据集可以将比例调至4:1。
从网络层入手
可以在网络层内加入规范化函数,以TensorFlow为例:
layer=tensorflow.keras.layers.Conv2D(filters,kernel_size,strides,padding="same",kernel_regularizer=tensorflow.keras.regularizers.l2(0.001)
其中filters为输出维度
kernel_size为卷积核大小
strides为步长(步距)
padding为填充模式,"same"表示填充后输出大小和输入保持一致
kernel_regularizer可以指定权值,使网络层正则化以缓解过拟合.
还可以在全连接层中添加Dropout层,即随机断开全连接层的部分连接点:
#tensorflow为例:
dropout_layer=tensorflow.keras.layers.Dropout(0.5)
#PyTorch为例:
dropout_layer=troch.nn.Dropout(0.5)
也可以在卷积层后加入规范化层:
#TensorFlow为例:
batchnorm_layer=tensorflow.keras.layers.VatchNormalization()
#PyTorch为例:
batchnorm_layer=troch.nn.BatchNorm2d()
更多网络层可以查看我下面这篇博客
从训练入手:
可以在过拟合之前就将训练停止并保存,将网络参数保持在相对优秀的水平.以TensorFlow为例:
from tensorflow.keras.callbacks import EarlyStopping
early=EarlyStopping(monitor='val_accuracy'
min_delta=0.001
patience=5)
model.fit(train_data,epochs=100,callbacks=[early])
其中monitor为评估标准
min_delta为最小变化量
patience为训练轮数
例子中的意思为若连续5轮训练,精度的提升没有达到0.001,则停止训练.
定义完'early'(名字可以随便起)后,在模型(model)训练(fit)中加入即可.
边栏推荐
猜你喜欢
Apache APISIX 2.15 版本发布,为插件增加更多灵活性
XGBoost 和随机森林在表格数据上优于深度学习?
“绿色低碳+数字孪生“双轮驱动,解码油气管道站升级难点 | 图扑软件
2022-07-23 第六小组 瞒春 学习笔记
两分钟录音就可秒变语言通!火山语音音色复刻技术如何修炼而成?
How to check the WeChat applet server domain name and modify it
2022-07-25 第六小组 瞒春 学习笔记
如何使用Swiper外部插件写一个轮播图
HDU1561 树形背包dp+边界优化 0ms过题
使用 docker 搭建 redis-cluster 集群
随机推荐
2022-07-28 第六小组 瞒春 学习笔记
《数字经济全景白皮书》银行业智能风控科技应用专题分析 发布
2022-07-29 第六小组 瞒春 学习笔记
vite.config.ts 引入 `path` 模块注意点!
李开复花上千万投的缝纫机器人,团队出自大疆
什么是Knife4j?
ELK日志分析系统
PAT甲级 1145 哈希 - 平均查找时间
两分钟录音就可秒变语言通!火山语音音色复刻技术如何修炼而成?
类加载过程
中国服装行业已形成一套完整的产业体系
Window function method for FIR filter design
MySQL 的几种碎片整理方案总结(解决delete大量数据后空间不释放的问题)
Win 10、Win 11 安装 MuJoCo 及 mujoco-py 教程
为什么四个字节的float表示的范围比八个字节的long表示的范围要广
树状DP(记忆化搜索)PAT甲级 1079 1090 1106
2022-07-13 第五小组 瞒春 学习笔记
MySQL 行级锁(行锁、临键锁、间隙锁)
A status code, and access baidu process
Traverse Heap PAT Class A 1155 Heap Path