当前位置:网站首页>scikit-learn——机器学习应用开发的步骤和理解
scikit-learn——机器学习应用开发的步骤和理解
2022-07-29 05:07:00 【m0_65187443】
scikit-learn是一个开源Python语言机器学习工具包,它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和scipy等Python数值计算库,提供了高效的算法实现。
目录
1.数据采集和标记
先采集数据,再对数据进行标记。其中采集数据要就有代表性,以确保最终训练出来模型的准确性。
2.特征选择
选择特征的直观方法:直接使用图片的每个像素点作为一个特征。
数据保存为样本个数×特征个数格式的array对象。scikit-learn使用Numpy的array对象来表示数据,所有的图片数据保存在digits.images里,每个元素都为一个8×8尺寸的灰阶图片。
3.数据清洗
把采集到的、不合适用来做机器学习训练的数据进行预处理,从而转换为合适机器学习的数据。
目的:减少计算量,确保模型稳定性。
4.模型选择
对于不同的数据集,选择不同的模型有不同的效率。因此在选择模型要考虑很多的因素,来提高最终选择模型的契合度。
5.模型训练
在进行模型训练之前,要将数据集划分为训练数据集和测试数据集,再利用划分好的数据集进行模型训练,最后得到我们训练出来的模型参数。
6.模型测试
模型测试的直观方法:用训练出来的模型预测测试数据集,然后将预测出来的结果与真正的结果进行比较,最后比较出来的结果即为模型的准确度。
scikit-learn提供的完成这项工作的方法:
clf . score ( Xtest , Ytest)除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。
7.模型保存与加载
当我们训练出一个满意的模型后即可将模型保存下来,这样当下次需要预测时,可以直接利用此模型进行预测,不用再一次进行模型训练。
8.实例
数据采集和标记
#导入库
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
"""
sk-learn库中自带了一些数据集
此处使用的就是手写数字识别图片的数据
"""
# 导入sklearn库中datasets模块
from sklearn import datasets
# 利用datasets模块中的函数load_digits()进行数据加载
digits = datasets.load_digits()
# 把数据所代表的图片显示出来
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Digit: %i' % label, fontsize=20);
特征选择
# 将数据保存为 样本个数x特征个数 格式的array对象 的数据格式进行输出
# 数据已经保存在了digits.data文件中
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
模型训练
# 把数据分成训练数据集和测试数据集(此处将数据集的百分之二十作为测试数据集)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);
# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
# 使用训练数据集Xtrain和Ytrain来训练模型
clf.fit(Xtrain, Ytrain);模型测试
"""
sklearn.metrics.accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)
normalize:默认值为True,返回正确分类的比例;如果为False,返回正确分类的样本数
"""
# 评估模型的准确度(此处默认为true,直接返回正确的比例,也就是模型的准确度)
from sklearn.metrics import accuracy_score
# predict是训练后返回预测结果,是标签值。
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)模型保存与加载
"""
将测试数据集里的部分图片显示出来
图片的左下角显示预测值,右下角显示真实值
"""
# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])
# 保存模型参数
import joblib
joblib.dump(clf, 'digits_svm.pkl');保存模型参数过程中出现如下错误:
![]()
原因:sklearn.externals.joblib函数是用在0.21及以前的版本中,在最新的版本,该函数应被弃用。
解决方法:将 from sklearn.externals import joblib改为 import joblib
# 导入模型参数,直接进行预测
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Xtest, Ytest)![]()
边栏推荐
- 如何让照片中的人物笑起来?HMS Core视频编辑服务一键微笑功能,让人物笑容更自然
- Unity Metaverse(三)、Protobuf & Socket 实现多人在线
- Network Security Learning - Intranet Security 1
- Learn the first program of database
- Climbing the pit of traffic flow prediction (II): the simplest LSTM predicts traffic flow using tensorflow2
- MySQL regularly calls preset functions to complete data update
- Connection database time zone setting
- Opencv learning 1 (environment configuration)
- 虚拟偶像的歌声原来是这样生成的!
- MySQL定时调用预置函数完成数据更新
猜你喜欢

excel怎么设置行高和列宽?excel设置行高和列宽的方法

Use openmap and ArcGIS to draw maps and transportation networks of any region, and convert OMS data into SHP format

传奇开区网站如何添加流量统计代码

Unity metaverse (III), protobuf & socket realize multi person online

荣耀2023内推,内推码ambubk

1 sentence of code, get asp Net core binds multiple sources to the same class

2021-10-23

学习数据库的第一个程序

如何让照片中的人物笑起来?HMS Core视频编辑服务一键微笑功能,让人物笑容更自然

Mapper agent development
随机推荐
如何安装office2010安装包?office2010安装包安装到电脑上的方法
The song of the virtual idol was originally generated in this way!
Lenovo Savior r7000+ add ssd+ copy and partition the information of the original D disk to the new SSD
关于servlet中实现网站的页面跳转
Open the tutorial of adding and modifying automatically playing music on the open zone website
Youxuan database failed to start and reported network error
关于thymeleaf的配置与使用
Excel卡住了没保存怎么办?Excel还没保存但是卡住了的解决方法
How to make the characters in the photos laugh? HMS core video editing service one click smile function makes people smile more naturally
How does WPS use smart fill to quickly fill data? WPS method of quickly filling data
Use openmap and ArcGIS to draw maps and transportation networks of any region, and convert OMS data into SHP format
Mysql的自连接和联合查询
Reply from the Secretary of jindawei: the company is optimistic about the market prospect of NMN products and has launched a series of products
让你的正则表达式可读性提高一百倍
时间序列分析的表示学习时代来了?
2021-10-11
开区网站打开自动播放音乐的添加跟修改教程
Glory 2023 push, push code ambubk
MySQL定时调用预置函数完成数据更新
使用Jstack、Jconsole和jvisualvm进行死锁分析