当前位置:网站首页>scikit-learn——机器学习应用开发的步骤和理解
scikit-learn——机器学习应用开发的步骤和理解
2022-07-29 05:07:00 【m0_65187443】
scikit-learn是一个开源Python语言机器学习工具包,它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和scipy等Python数值计算库,提供了高效的算法实现。
目录
1.数据采集和标记
先采集数据,再对数据进行标记。其中采集数据要就有代表性,以确保最终训练出来模型的准确性。
2.特征选择
选择特征的直观方法:直接使用图片的每个像素点作为一个特征。
数据保存为样本个数×特征个数格式的array对象。scikit-learn使用Numpy的array对象来表示数据,所有的图片数据保存在digits.images里,每个元素都为一个8×8尺寸的灰阶图片。
3.数据清洗
把采集到的、不合适用来做机器学习训练的数据进行预处理,从而转换为合适机器学习的数据。
目的:减少计算量,确保模型稳定性。
4.模型选择
对于不同的数据集,选择不同的模型有不同的效率。因此在选择模型要考虑很多的因素,来提高最终选择模型的契合度。
5.模型训练
在进行模型训练之前,要将数据集划分为训练数据集和测试数据集,再利用划分好的数据集进行模型训练,最后得到我们训练出来的模型参数。
6.模型测试
模型测试的直观方法:用训练出来的模型预测测试数据集,然后将预测出来的结果与真正的结果进行比较,最后比较出来的结果即为模型的准确度。
scikit-learn提供的完成这项工作的方法:
clf . score ( Xtest , Ytest)除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。
7.模型保存与加载
当我们训练出一个满意的模型后即可将模型保存下来,这样当下次需要预测时,可以直接利用此模型进行预测,不用再一次进行模型训练。
8.实例
数据采集和标记
#导入库
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
"""
sk-learn库中自带了一些数据集
此处使用的就是手写数字识别图片的数据
"""
# 导入sklearn库中datasets模块
from sklearn import datasets
# 利用datasets模块中的函数load_digits()进行数据加载
digits = datasets.load_digits()
# 把数据所代表的图片显示出来
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Digit: %i' % label, fontsize=20);
特征选择
# 将数据保存为 样本个数x特征个数 格式的array对象 的数据格式进行输出
# 数据已经保存在了digits.data文件中
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
模型训练
# 把数据分成训练数据集和测试数据集(此处将数据集的百分之二十作为测试数据集)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);
# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
# 使用训练数据集Xtrain和Ytrain来训练模型
clf.fit(Xtrain, Ytrain);模型测试
"""
sklearn.metrics.accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)
normalize:默认值为True,返回正确分类的比例;如果为False,返回正确分类的样本数
"""
# 评估模型的准确度(此处默认为true,直接返回正确的比例,也就是模型的准确度)
from sklearn.metrics import accuracy_score
# predict是训练后返回预测结果,是标签值。
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)模型保存与加载
"""
将测试数据集里的部分图片显示出来
图片的左下角显示预测值,右下角显示真实值
"""
# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])
# 保存模型参数
import joblib
joblib.dump(clf, 'digits_svm.pkl');保存模型参数过程中出现如下错误:
![]()
原因:sklearn.externals.joblib函数是用在0.21及以前的版本中,在最新的版本,该函数应被弃用。
解决方法:将 from sklearn.externals import joblib改为 import joblib
# 导入模型参数,直接进行预测
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Xtest, Ytest)![]()
边栏推荐
- 力扣------对奇偶下标分别排序
- MySQL regularly calls preset functions to complete data update
- The person who goes to and from work on time and never wants to work overtime has been promoted in front of me
- Solution to the fourth game of 2022 Hangzhou Electric Multi school league
- Reply from the Secretary of jindawei: the company is optimistic about the market prospect of NMN products and has launched a series of products
- 【文件下载】Easyexcel快速上手
- Operator operation list of spark
- IDEA中使用注解Test
- js(forEach)出现return无法结束函数的解决方法
- AttributeError: ‘module‘ object has no attribute ‘create_connection‘
猜你喜欢

使用Jstack、Jconsole和jvisualvm进行死锁分析

向往的开源之多YOUNG新生 | 从开源到就业的避坑指南来啦!

玩家访问网站自动弹窗加QQ群方法以及详细代码

Learn matlab to draw geographical map, line scatter bubble density map

【2022新生学习】第三周要点

Introduction of JDBC preparestatement+ database connection pool

TCP三次握手四次挥手

WPS插入超链接无法打开,提示“无法打开指定文件”怎么办!

< El table column> place multiple pictures

五个关联分析,领略数据分析师一大重要必会处理技能
随机推荐
【微信小程序】swiper滑动页面,滑块左右各露出前后的一部分,露出一部分
How to solve the problem of configuring the progress every time Office2010 is opened?
Live in small private enterprises
Activity workflow table structure learning
ODOO开发教程之图表
缓存穿透、缓存击穿、缓存雪崩以及解决方法
怎样监测微型的网站服务
Academic | [latex] super detailed texlive2022+tex studio download installation configuration
深度学习刷SOTA的一堆trick
让你的正则表达式可读性提高一百倍
Use annotation test in idea
玩家访问网站自动弹窗加QQ群方法以及详细代码
开区网站打开自动播放音乐的添加跟修改教程
荣耀2023内推,内推码ambubk
Use openmap and ArcGIS to draw maps and transportation networks of any region, and convert OMS data into SHP format
传奇如何一台服务器配置多个版本微端更新
Connection database time zone setting
MySQL time calculation function
[file download] easyexcel quick start
Mapper agent development