当前位置:网站首页>【Spark】(task8)SparkML中的pipeline通道建立
【Spark】(task8)SparkML中的pipeline通道建立
2022-06-10 12:36:00 【山顶夕景】
文章目录
一、ML Pipeline机器学习流程
如果样本较少,可以直接使用python对样本进行ML建模,但当需要大规模数据集时,可以使用spark进行分布式内存计算,虽然spark的原生语言是scala,但如果用python写可以用pyspark进行机器学习的pipeline链路建立。
1.1 ML Pipeline构建流程
spark有MLlib机器学习库,比ML Pipeline复杂,先来大概看下ML Pipeline构建机器学习流程:
- 数据准备: 将特征值和预测变量整理成DataFrame
- 建立机器学习流程Pipeline:
StringIndexer:将文字分类特征转化为数字OneHotEncoder:将数字分类特征转化为稀疏向量VectorAssembler:将所有特征字段整合成一个Vector字段DecisionTreeClassfier:训练生成模型
- 训练:训练集使用
pipeline.fit()进行训练,产生pipelineModel - 预测:使用
pipelineModel.transform()预测测试集,产生预测结果
1.2 ML Pipeline组件
注意:pyspark的一些组件和python中的同名组件不完全一样:
DataFrame: 是Spark ML机器学习API处理的数据格式,可以由文本文件、RDD、或者Spark SQL创建,与python 的Dataframe概念相近但是方法完全不同。Transformer:可以使用.transform方法将一个DataFrame转换成另一个DataFrame。Estimator:可以使用.fit方法传入DataFrame,生成一个Transformer。pipeline:可以串联多个Transformer和Estimator建立ML机器学习的工作流。Parameter:以上Transformer和Estimator都可以共享的参数API。
二、以GBDT为栗子
2.0 GBTs介绍
Spark中的GBDT较GBTs——梯度提升树,因为其是基于决策树(Decision Tree,DT)实现的,所以叫GBDT。Spark 中的GBDT算法存在于ml包和mllib包中:
- mllib是基于RDD的,
- ml包则是针对DataFrame的,ml包中的GBDT分为分类和回归。
由于在实际生产环境中使用基于RDD的较多,所以直接使用MLLib包中的GBTs,ML包中的则进行简单说明。

- pipeline:一个 Pipeline 在结构上会包含一个或多个 PipelineStage,每一个 PipelineStage 都会完成一个任务,如数据集处理转化,模型训练,参数设置或数据预测等,这样的 PipelineStage 在 ML 里按照处理问题类型的不同都有相应的定义和实现。
- transformer:是一个pipelineStage,把一个df转为另一个df,一个model可以把一个不包含预测标签的测试数据集 DataFrame 打上标签转化成另一个包含预测标签的 DataFrame,显然这样的结果集可以被用来做分析结果的可视化。
- estimator:操作df数据生成一个transformer,包括fit部分。
2.1 加载libsvm数据
# gbdt_test
import findspark
findspark.init()
import pyspark
from pyspark import SparkConf
from pyspark.ml import Pipeline
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer, VectorIndexer,IndexToString
from pyspark.ml.classification import GBTClassifier
file_path = "file:///home/hadoop/development/RecSys/data"
# def gradientBoostedTreeClassifier(data="data/sample_libsvm_data.txt"):
# def gradientBoostedTreeClassifier(data):
''' GBDT分类器 '''
#加载LIBSVM格式的数据集
data = spark.read.format("libsvm").load(data)
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
data.show(n = 3)
""" +-----+--------------------+ |label| features| +-----+--------------------+ | 0.0|(692,[127,128,129...| | 1.0|(692,[158,159,160...| | 1.0|(692,[124,125,126...| +-----+--------------------+ only showing top 3 rows """
2.2 pipeline链路过程
#训练集、测试集划分
(trainingData, testData) = data.randomSplit([0.7, 0.3])
#print("训练集:\n", trainingData.show(n = 1), "\n")
#print("测试集:\n", testData.show(n = 1))
# 使用10个基分类器
gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)
print("gbt_test:\n", gbt, "\n")
# 建立模型的pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])
print("pipeline:\n", type(pipeline), "\n")
model = pipeline.fit(trainingData)
# 做预测
predictions = model.transform(testData)
#展示前5行数据
predictions.select("prediction", "indexedLabel", "features").show(5)
#展示预测标签与真实标签,计算测试误差 fit part
evaluator = MulticlassClassificationEvaluator(
labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
# predict
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))
gbtModel = model.stages[2]
print('gbtModelSummary: ',gbtModel) #模型摘要
结果如下,从Test Error = 0.12看,即accuracy为98%的效果,以上即一个简单的GBDT分类任务,通过10个基分类器,根据boosting策略根据负梯度的优化:
+-----+--------------------+
|label| features|
+-----+--------------------+
| 0.0|(692,[95,96,97,12...|
+-----+--------------------+
only showing top 1 row
训练集:
None
+-----+--------------------+
|label| features|
+-----+--------------------+
| 0.0|(692,[123,124,125...|
+-----+--------------------+
only showing top 1 row
测试集:
None
gbt_test:
GBTClassifier_eafe5d3c8749
pipeline:
<class 'pyspark.ml.pipeline.Pipeline'>
+----------+------------+--------------------+
|prediction|indexedLabel| features|
+----------+------------+--------------------+
| 1.0| 1.0|(692,[123,124,125...|
| 1.0| 1.0|(692,[124,125,126...|
| 1.0| 1.0|(692,[126,127,128...|
| 1.0| 1.0|(692,[129,130,131...|
| 1.0| 1.0|(692,[150,151,152...|
+----------+------------+--------------------+
only showing top 5 rows
Test Error = 0.12
gbtModelSummary: GBTClassificationModel: uid = GBTClassifier_eafe5d3c8749, numTrees=10, numClasses=2, numFeatures=692
Reference
[1] Spark 排序算法系列之 GBTs 使用方式介绍
[2] MLlib:https://www.jianshu.com/p/4d7003182398
[3] pyspark学习之——流水线Pipeline
边栏推荐
- 六石编程学:以文字处理的位置,谈谈命名
- JTAG-to-AXI Master调试AXI BRAM Controller
- Alibaba cloud ECS server builds MySQL database
- (一)预处理总结
- Const Modified member function
- Ad-pcb schematic diagram learning (1)
- JS translates Arabic numerals into Chinese capital figures, JS converts figures into capital amounts (sorting)
- 向数据库中注册用户名和密码的功能
- (11) Const decorated member function
- (五)类和对象及类的分文件操作(2)
猜你喜欢

微信web开发工具使用教程,公司开发web

Count the number and average value of natural numbers whose sum of bits within 100 is 7

统计100以内的各位数之和为7的自然数的个数及平均值

Can chip learning of max3051

JS translates Arabic numerals into Chinese capital figures, JS converts figures into capital amounts (sorting)

MySQL service evolution

Shadergraph - 302 swimming Dragon

(四)类和对象(1)

CC2642R 蓝牙MCU芯片的学习

FM4057S单节锂电池线性充电芯片的学习
随机推荐
STM32 learning notes (2) -usart (basic application 1)
Stereo Vision-based Semantic 3D Object and Ego-motion Tracking for Autonomous Driving 论文阅读
编写程序,计算2/1+3/2+5/3+8/5.....的值。要求计算前n项之和,保留2位小数(该序列从第二项起,每一项的分子是前一项分子与分母的和,分母是前一项的分子)
【移动机器人】轮式里程计原理
Stm32f407 learning notes (1) -exti interrupt event and NVIC register
FM4057S单节锂电池线性充电芯片的学习
Today, a couple won the largest e-commerce IPO in Hong Kong
Get enumeration values through reflection
JS array to JSON, JSON to array. Array to comma separated string, string to array
Count the number and average value of natural numbers whose sum of bits within 100 is 7
Practical cases, in-depth analysis
JTAG to Axi master debugging Axi Bram controller
阿里云ECS服务器搭建Mysql数据库
JS judgment includes: includes
(八)初始化列表
手机厂商“返祖”,只有苹果说不
Altium Designer重拾之开篇引入
PCB learning notes (2) -3d packaging related
Automatic Mapping of Tailored Landmark Representations for Automated Driving and Map Learning 论文阅读
What if the xshell evaluation period has expired? Follow the steps below to solve the problem!