Dataset Splitting and Cross-Validation
2022-07-31 05:09:00 【Erosion_ww】
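Splitting the dataset
Training set, validation set, test set
- The training set is used to build (fit) the model.
- The validation set is used during model building to fine-tune the model (for example, to choose hyperparameters) and can be reused many times. When a validation set is used, part of the training data is held back as the validation set, and the model is trained on the remaining training data.
- The test set is used to evaluate the final model, for example to measure its accuracy.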
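A minimal sketch of a three-way split, assuming scikit-learn and the iris data used later in this post. The 60% / 20% / 20% proportions and the use of two successive `train_test_split` calls are illustrative choices, not something the original prescribes.

```python
# Sketch: carve a dataset into training, validation and test sets with two
# successive calls to sklearn's train_test_split (60% / 20% / 20% here).
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# First hold out 20% of all samples as the final test set.
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)
# Then take 25% of the remaining 80% (i.e. 20% of the total) as the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0, stratify=y_rest
)

print(len(X_train), len(X_val), len(X_test))  # 90 30 30
```

The validation set is then used for tuning, and the test set is touched only once, at the very end.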
Hold-out method
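The data is split once into a training set and a test set. For example, scikit-learn's train_test_split uses 75% of the data as the training set and 25% as the test set by default.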
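A small sketch of the hold-out method, assuming scikit-learn. No `test_size` is passed, so the default 25% test fraction applies; on the 150-sample iris set that yields 112 training and 38 test samples (sklearn rounds the test count up).

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
# No test_size given, so the default split (25% test) is used.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
print(X_train.shape[0], X_test.shape[0])  # 112 38

# max_iter raised so the lbfgs solver converges on the unscaled features.
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
holdout_score = model.score(X_test, y_test)  # accuracy on the held-out 25%
print("hold-out accuracy:", holdout_score)
```

The score depends on which samples happen to land in the test set, which is the weakness the cross-validation methods below address.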
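Cross-validation
1. Leave-one-out cross-validation
The dataset is split into k subsets, where k equals the number of samples, so each subset holds exactly one sample. In each round that single sample is the test set and all remaining samples form the training set. Because every model is trained on all but one sample, the result comes very close to the performance expected from training on the whole dataset. However, one model must be trained per sample, so the cost is very high; the method suits small datasets.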
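A sketch of leave-one-out with scikit-learn's `LeaveOneOut` splitter, assuming the same iris data and a logistic-regression model. It makes the cost visible: 150 samples means 150 separate fits.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneOut, cross_val_score

X, y = load_iris(return_X_y=True)
loo = LeaveOneOut()
print(loo.get_n_splits(X))  # 150: one model per sample

# Each fold tests on exactly one sample, so each score is either 0.0 or 1.0.
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=loo)
print("LOO accuracy:", scores.mean())
```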
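2. K-fold cross-validation
The dataset is split into k subsets. In each round, k-1 subsets form the training set and the remaining subset is the test set. This is done k times, and the average accuracy over the k rounds is reported as the result. For comparison, train_test_split uses a 3:1 training-to-test ratio by default, whereas 5-fold cross-validation gives 4:1 and 10-fold gives 9:1. In general, the more data the model is trained on, the higher its accuracy tends to be.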
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn import metrics
data = load_iris()  # load the iris dataset
X = data.data
y = data.target
kf = KFold(n_splits=5, random_state=None)  # 5-fold cross-validation (no shuffling)
i = 1
for train_index, test_index in kf.split(X, y):
    print('\n{} of kfold {}'.format(i, kf.n_splits))
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    # max_iter raised so the lbfgs solver converges on the unscaled features
    model = LogisticRegression(random_state=1, max_iter=1000)
    model.fit(X_train, y_train)
    pred_test = model.predict(X_test)
    score = metrics.accuracy_score(y_test, pred_test)
    print('accuracy_score', score)
    i += 1
# predicted probability of class 1, for the last fold's test set only
pred = model.predict_proba(X_test)[:, 1]
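To make the 4:1 and 9:1 ratios concrete, here is a plain-Python sketch of how k-fold indices can be generated without shuffling. `kfold_indices` is a hypothetical helper written for illustration; it mimics how scikit-learn's `KFold` sizes its folds (the first `n_samples % k` folds get one extra sample) but is not its actual code.

```python
def kfold_indices(n_samples, k):
    """Yield (train_idx, test_idx) lists for k contiguous, unshuffled folds."""
    # The first n_samples % k folds get one extra sample.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_idx = list(range(start, start + size))
        # Everything outside the current test fold is used for training.
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, test_idx
        start += size

for k in (5, 10):
    sizes = [(len(tr), len(te)) for tr, te in kfold_indices(150, k)]
    print(k, sizes[0])  # 5 (120, 30) -> 4:1, then 10 (135, 15) -> 9:1
```

Every sample appears in exactly one test fold, which is why averaging the k scores uses the whole dataset for evaluation.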
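3. Stratified cross-validation
Stratification rearranges the data so that every fold is a good representative of the whole dataset.
For example, in a binary classification problem the original data has two classes (F and M) in a ratio of roughly 1:3. With 5 folds, stratification keeps the F:M ratio in every fold the same as in the original data (1:3).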
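A sketch that checks the 1:3 example directly, assuming scikit-learn. The labels are hypothetical: 25 samples of class F (encoded as 0) and 75 of class M (encoded as 1), sorted by class. `StratifiedKFold` keeps 5 F and 15 M in every test fold, while plain unshuffled `KFold` on the same sorted labels produces badly skewed folds.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced labels: 25 of class F (0) and 75 of class M (1), i.e. F:M = 1:3.
y = np.array([0] * 25 + [1] * 75)
X = np.zeros((len(y), 1))  # features are irrelevant to how the folds are built

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
strat_counts = [np.bincount(y[test], minlength=2).tolist() for _, test in skf.split(X, y)]
print(strat_counts)  # every fold: [5, 15] -> F:M = 1:3

# Plain, unshuffled KFold on the same sorted labels, for contrast.
kf = KFold(n_splits=5)
plain_counts = [np.bincount(y[test], minlength=2).tolist() for _, test in kf.split(X)]
print(plain_counts)  # [[20, 0], [5, 15], [0, 20], [0, 20], [0, 20]]
```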
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
data = load_iris()
X = data.data
y = data.target
# 5 stratified folds; shuffle before splitting, fixed seed for reproducibility
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_index, test_index in skf.split(X, y):
    print("Train:", train_index, "Validation:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
# cross_val_score fits and scores the model on every fold produced by skf
model = LogisticRegression(max_iter=1000)
scores = cross_val_score(model, X, y, cv=skf)
print("stratified cross validation scores:{}".format(scores))
print("Mean score of stratified cross validation:{:.2f}".format(scores.mean()))
4. Repeated cross-validation
This simply repeats k-fold cross-validation n times, shuffling the data differently on each repetition.
from sklearn.model_selection import RepeatedKFold
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn import metrics
data = load_iris()
X = data.data
y = data.target
# 5-fold CV repeated 2 times; each repetition shuffles differently (10 splits total)
kf = RepeatedKFold(n_splits=5, n_repeats=2, random_state=None)
for train_index, test_index in kf.split(X):
    print("Train:", train_index, "Validation:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
i = 1
for train_index, test_index in kf.split(X, y):
    # get_n_splits() returns n_splits * n_repeats, i.e. the total number of splits
    print('\n{} of kfold {}'.format(i, kf.get_n_splits()))
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model = LogisticRegression(random_state=1, max_iter=1000)
    model.fit(X_train, y_train)
    pred_test = model.predict(X_test)
    score = metrics.accuracy_score(y_test, pred_test)
    print('accuracy_score', score)
    i += 1
# predicted probability of class 1, for the last split's test set only
pred = model.predict_proba(X_test)[:, 1]
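Repetition can also be combined with stratification. The sketch below swaps in scikit-learn's `RepeatedStratifiedKFold` (a different splitter from the `RepeatedKFold` used above) and lets `cross_val_score` run all the fits. The fold and repeat counts are arbitrary illustrative choices.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

X, y = load_iris(return_X_y=True)
# 5 folds x 3 repeats = 15 fits; each repeat reshuffles before splitting.
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=rskf)
print(len(scores))  # 15
print("mean = {:.3f}, std = {:.3f}".format(scores.mean(), scores.std()))
```

Reporting the standard deviation alongside the mean shows how much the estimate depends on the particular shuffle, which is the main reason to repeat k-fold at all.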