当前位置:网站首页>【Kaggle比赛常用trick】K折交叉验证、TTA
【Kaggle比赛常用trick】K折交叉验证、TTA
2022-07-30 12:04:00 【满船清梦压星河HK】
一、什么是k折交叉验证?
在训练阶段,我们一般不会使用全部的数据进行训练,而是采用交叉验证的方式来训练。交叉验证(Cross Validation,CV)是机器学习模型的重要环节之一。它可以增强随机性,从有限的数据中获得更全面的信息,减少噪声干扰,从而缓解过拟合,增强模型的泛化能力。
比赛一般会只给我们训练集,但是测试集我们是看不到的,所以我们一般会将训练集按照一定的方式划分为训练集和验证集。训练集用于模型的训练,验证集用于本地验证,选取最好的pt权重文件,再提交到比赛官网进行测试集的验证。所以如何划分训练集和验证集,让我们最大限度的利用训练集,学习有效的特征,是至关重要的。交叉验证就是做这个事的。
交叉验证步骤:
- 将整个数据集划分为大小相等的K个部分;
- 每次选取其中一份作为验证集,其余K-1份作为训练集进行训练;
- 重复K次,直至每一份数据都被当作验证集验证了一遍;
- 模型的最终精度是通过K个子模型的平均精度来计算的;
下面这个图可以比较好的诠释上面这个过程:
我们一般不会自己实现这个功能,一般都是调用SKLearn包直接使用,SKlearn帮我们实现了KFold、Stratified KFold、Group KFold和Stratified Group KFold四种方式,下面我一一介绍它们的区别和用法。
二、常见的几种交叉验证方式
2.1、KFold
KFold是最简单的一种K折交叉验证,它的具体步骤如下图所示是一个4折交叉验证,橘色代表验证集(1份),蓝色代表训练集(3份),整个数据集有三个类别(对应图中三种颜色的分布情况);这些数据属于很多个不同的组;

可以看的很清楚,这种K折交叉验证,有两个缺点:
- 不适应于数据集样本不均衡的情况,因为很可能会把整个少数的类别划分为验证集或训练集;
- 不适应于时间序列问题;
2.2、Stratified(分层) KFold
上面讲到,KFold不适应于数据不平衡的问题,所以Stratified KFold(分层)交叉验证就是专门来解决这个问题的。如下图,在分层交叉验证中,数据集依然被划分为K组,但是验证组的目标类别是从各个类中分层抽取出来的,是均匀的,所以就不会存在少数类别被全部划分为验证集或训练集。
特点:可以解决数据不平衡问题,但是不适应于时间序列问题。
2.3、Group (分组)KFold
GroupKFold是KFold一个变体,目的在于将group严格分开,就是说同一个group的数据只能出现在训练集或者验证集,不能同时出现在训练集和验证集,如下图:
特点:可以将数据的group完全分开,避免高度相似的样本既出现在训练集又测试在验证集。
2.4、Stratified Group KFold
Group KFold和Stratified KFold的合体,如下图:

特点:可以将数据的group和标签的class完全分层划开,避免出现样本高度相似和标签分布不均的问题。
2.5、Time Series Split
可以解决时间序列相关的问题。对于时间序列数据集,根据时间将数据分为训练和验证,也称为前向链接方法或滚动交叉验证。

使用方式举例:
skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups = df["case"])):
三、什么是TTA?
TTA,即Test time augmention,测试时增强。数据增强一般是出现在训练阶段,使用数据增强一般都能提升性能。而测试时数据增强是指在测试的时候,将原图进行数据增强(比如水平翻转、垂直翻转、对角线翻转、旋转等,这里假设使用了3种数据增强),可以得到4张测试图片,对这四张测试图片分布进行推理,得到推理结果。再对三张增强后的推理结果再变换回来(比如我对原图进行水平翻转,得到的mask,再对mask进行水平翻转)。最后就得到了4张预测结果,对这四张预测结果mask对应位置相加取平均,就得到了最终的mask预测果。
使用方式举例:
model = build_model(CFG, test_flag=True)
model.load_state_dict(torch.load(sub_ckpt_path))
model.eval()
y_preds = model(images) # [b, c, w, h]
y_preds = torch.nn.Sigmoid()(y_preds)
masks += y_preds
#x,y,xy flips as TTA
if CFG.tta:
flips = [[-1]] # 水平翻转
for f in flips:
images_f = torch.flip(images, f)
y_preds = model(images_f) # [b, c, w, h]
y_preds = torch.flip(y_preds, f)
y_preds = torch.nn.Sigmoid()(y_preds)
masks += y_preds
if CFG.tta:
total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold * 2
else:
total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold
Reference
知乎: 常见交叉验证方法汇总
边栏推荐
猜你喜欢

【32. 图中的层次(图的广度优先遍历)】

contentDocument contentWindow, canvas, svg, iframe

解码Redis最易被忽视的CPU和内存占用高问题

数据湖(十八):Flink与Iceberg整合SQL API操作

A tutorial on how to build a php environment under win

JD.com was brutally killed by middleware on two sides. After 30 days of learning this middleware booklet, it advanced to Ali.

Beijing, Shanghai and Guangzhou offline events丨The most unmissable technology gatherings at the end of the year are all gathered

刷屏了!!!

打破原则引入SQL,MongoDB到底想要干啥???

Program environment and preprocessing (detailed)
随机推荐
Transfer Learning Technology Training
横向对比5种常用的注册中心,无论是用于面试还是技术选型,都非常有帮助
Verilog grammar basics HDL Bits training 08
New:WebKitX ActiveX :::Crack
win下怎么搭建php环境的方法教程
维护数千规模MySQL实例,数据库灾备体系构建指南
saltstack学习3模块
The method of judging the same variable without the if branch
PyQt5快速开发与实战 8.4 设置窗口背景 && 8.5 不规则窗口的显示
Redis master-slave replication
Homework 7.29 correlation function directory and file attributes related functions
Farmers on the assembly line: I grow vegetables in a factory
和数集团:让智慧城市更智慧,让现实生活更美好
Apifox generates interface documentation tutorial and operation steps
备战金九银十!2022面试必刷大厂架构面试真题汇总+阿里七面面经+架构师简历模板分享
云原生应用的概念和云原生应用的 15 个特征
限时招募!淘宝无货源副业,800/天,不限经验,男女皆可,仅限前200名!
What happened when the computer crashed?
Based on MySQL database, Redis cache, MQ message middleware, ES high availability scheme of search engine parsing
Execution order of select, from, join, on where groupby, etc. in MySQL