当前位置:网站首页>FE01_OneHot-Scala应用
FE01_OneHot-Scala应用
2022-08-04 17:58:00 【51CTO】
OneHot是处理类别型变量常用的处理方法,scala中如果应用呢?如果测试集中出现训练集中没有value,怎么处理?
1 数据构造
import org. apache. spark. ml.{ Model, Pipeline, PipelineModel, PipelineStage}
import org. apache. spark. ml. classification. LogisticRegression
import org. apache. spark. ml. feature.{ OneHotEncoderEstimator, StringIndexer}
import org. apache. spark. sql. functions. _
import org. apache. spark. ml. feature. VectorAssembler
import org. apache. spark. sql.{ DataFrame, Row, SparkSession}
val builder = SparkSession
. builder()
. appName( "LR")
. config( "spark.executor.heartbeatInterval", "60s")
. config( "spark.network.timeout", "120s")
. config( "spark.serializer", "org.apache.spark.serializer.KryoSerializer")
. config( "spark.kryoserializer.buffer.max", "512m")
. config( "spark.dynamicAllocation.enabled", false)
. config( "spark.sql.inMemoryColumnarStorage.compressed", true)
. config( "spark.sql.inMemoryColumnarStorage.batchSize", 10000)
. config( "spark.sql.broadcastTimeout", 600)
. config( "spark.sql.autoBroadcastJoinThreshold", - 1)
. config( "spark.sql.crossJoin.enabled", true)
. master( "local[*]")
val spark = builder. getOrCreate()
spark. sparkContext. setLogLevel( "ERROR")
import spark. implicits. _
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
builder: org.apache.spark.sql.SparkSession.Builder = [email protected]
spark: org.apache.spark.sql.SparkSession = [email protected]
import spark.implicits._
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
println( "------------- dfTrain -------------")
var dfTrain = Seq(
( 1, 5.1, "a", "hello", 0.2, 0),
( 2, 4.9, "b", null, 0.2, 1),
( 3, 4.7, "b", "hi", 0.2, 0),
( 4, 4.6, "c", "hello", 0.2, 1)
). toDF( "id", "x1", "x2", "x3", "x4", "label")
dfTrain. show()
println( "------------- dfTest -------------")
var dfTest = Seq(
( 1, 5.1, "a", "hello", 0.2, 0),
( 2, 4.9, "b", "no", 0.2, 1),
( 3, 4.7, "a", "yes", 0.2, 0),
( 4, 4.6, "d", "hello", 0.2, 1)
). toDF( "id", "x1", "x2", "x3", "x4", "label")
// 测试集直接copy就行了,仅用来测试
dfTest. show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| null|0.2| 1|
| 3|4.7| b| hi|0.2| 0|
| 4|4.6| c|hello|0.2| 1|
+---+---+---+-----+---+-----+
------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| no|0.2| 1|
| 3|4.7| a| yes|0.2| 0|
| 4|4.6| d|hello|0.2| 1|
+---+---+---+-----+---+-----+
dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
注意:
- dTest的类别型特征中出现了dTrain中没有的value
- x2中的d,x3中的no,yes
2 StringIndexer转换
StringIndexer能把字符型字段编码成标签索引,索引的范围为0到标签数量(这里指的是dfTrain中了),索引构建的顺序为标签的频率,优先编码频率较大的标签,所以出现频率最高的标签为0号。如果输入的是数值型的,我们会把它转化成字符型,然后再对其进行编码。
做onehot之前我们需要先做stringIndex转换。
val columns = Array( "x2", "x3")
val indexers: Array[ PipelineStage] = columns. map { colName =>
new StringIndexer(). setInputCol( colName). setOutputCol( colName + "_indexed"). setHandleInvalid( "keep")
}
new Pipeline(). setStages( indexers). fit( dfTrain). transform( dfTrain). show()
- 1.
- 2.
- 3.
- 4.
- 5.
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|
+---+---+---+-----+---+-----+----------+----------+
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- StringIndexr和常用的OneHotEncoderEstimator、VectorAssembler不同,只能单个特征的处理
- 上面用map方法做了简单的处理
- 从x2的编码看,频数最多的b标记为0,x3的hello被编辑为0
- x3的缺失值被标记为2.0
3 OneHotEncoderEstimator
下面对转换后的数据进行onehot编码
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 从数据上看,还不是很直观。我们习惯行列方式的数据展示方式
- x2_onehot中,3表示一共有3个值(a,b,c),第二个表示分别是哪个索引,第三个都是1
- x3_onehot中,id=2的有点特殊,均为[],应该是指null不算,即值个个数为2,(0,1)
下面对数据做了转换,把vector转换成array,看起来更直观一些
- 可以看到x2_onehot按x2_indexed的[0,1,2]生成三列,值为0,1,命中就1否则为0
- 而对于id=2的x3,确实都是0,0,即onehot其实不考虑null值
如果测试集中有训练集中没出现的值,onehot时会自动把所有列的值都改成0,也是就和null处理方法一致。直接看个例子
val columns = Array( "x2", "x3")
val indexers: Array[ PipelineStage] = columns. map { colName =>
new StringIndexer(). setInputCol( colName). setOutputCol( colName + "_indexed"). setHandleInvalid( "keep")
}
val onehoter = new OneHotEncoderEstimator()
. setInputCols( columns. map( _ + "_indexed"))
. setOutputCols( columns. map( x => x + "_onehot"))
val featureCol = Array( "x1", "x4") ++ columns. map( x => x + "_onehot")
val assemble = new VectorAssembler()
. setInputCols( featureCol)
. setOutputCol( "features")
val pipeline = new Pipeline(). setStages( indexers ++ Array( onehoter, assemble))
val p1 = pipeline. fit( dfTrain)
p1. transform( dfTest). show()
// 把特征列转成array,再打印出来
println( "--------------------------------")
p1. transform( dfTest). select( "features"). map( x => x( 0). asInstanceOf[ Vector]. toArray). take( 4). foreach( x => println( x. mkString( ",")))
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot| features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
| 2|4.9| b| no|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|(7,[0,1,2],[4.9,0...|
| 3|4.7| a| yes|0.2| 0| 1.0| 2.0|(3,[1],[1.0])| (2,[],[])|(7,[0,1,3],[4.7,0...|
| 4|4.6| d|hello|0.2| 1| 3.0| 0.0| (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
先看features列
- id=1没啥好说的,id=2,3,4都是简写的方式
- 以id=3为例,7表示一共7个元素,其中第[0,1,3]是[4.7,0.2,0.0,1],别的都是0
- id=3的x3_onehot中,yes是训练集中没出现的,所以onehot之后都是0
后面打印输出直观的说明了数据储存的具体方式
4 Pipeline
用Pipeline把数据处理的流程串起来,简单跑个demo
val columns = Array( "x2", "x3")
val indexers: Array[ PipelineStage] = columns. map { colName =>
new StringIndexer(). setInputCol( colName). setOutputCol( colName + "_indexed"). setHandleInvalid( "keep")
}
val onehoter = new OneHotEncoderEstimator()
. setInputCols( columns. map( _ + "_indexed"))
. setOutputCols( columns. map( x => x + "_onehot"))
val featureCol = Array( "x1", "x4") ++ columns. map( x => x + "_onehot")
val assemble = new VectorAssembler()
. setInputCols( featureCol)
. setOutputCol( "features")
val lr = new LogisticRegression().
setMaxIter( 10).
setRegParam( 0.01)
val pipeline = new Pipeline(). setStages( indexers ++ Array( onehoter, assemble, lr))
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
println( "-------------- Train -------------- ")
pipeModel. transform( dfTrain). select( "id", "features", "rawPrediction", "probability", "prediction"). show()
println( "-------------- Test -------------- ")
pipeModel. transform( dfTest). select( "id", "features", "rawPrediction", "probability", "prediction"). show()
- 1.
- 2.
- 3.
- 4.
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...| 0.0|
| 4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...| 0.0|
| 4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
总结下:
- Scala做onehot不用担心测试集中出现训练集中没有value
- 用pipeline更方便
Ref
[2] https://spark.apache.org/docs/2.4.4/api/scala/index.html#org.apache.spark.ml.feature.StringIndexer
2020-04-02 于南京市江宁区九龙湖
边栏推荐
猜你喜欢
[Web Automation Test] Quick Start with Playwright, 5 minutes to get started
动态数组底层是如何实现的
44. 通配符匹配 ●●● & HJ71 字符串通配符 ●●
leetcode 14. 最长公共前缀
基于层次分析法的“内卷”指数分析
npm配置国内镜像(淘宝镜像)
启动项目(瑞吉外卖)
解决错误:The package-lock.json file was created with an old version of npm
Cholesterol-PEG-DBCO,CLS-PEG-DBCO,胆固醇-聚乙二醇-二苯基环辛炔科研试剂
JS兼容问题总结
随机推荐
Matlab画图1
Fork/Join框架
2018读书记
"Involution" Index Analysis Based on AHP
C. LIS or Reverse LIS?
How to recruit programmers
谁能解答?从mysql的binlog读取数据到kafka,但是数据类型有Insert,updata,
leetcode 13. 罗马数字转整数
报道称任天堂在2023年3月前不会推出任何新硬件产品
PT100铂热电阻三种测温方法介绍
LVS+Keepalived群集
R语言缺失时间序列的填充及合并:补齐时间序列数据中所有缺失的时间索引、使用merge函数合并日期补齐之后的时间序列数据和另外一个时间序列数据(补齐左侧数据)
Cholesterol-PEG-DBCO,CLS-PEG-DBCO,胆固醇-聚乙二醇-二苯基环辛炔科研试剂
《中国综合算力指数》《中国算力白皮书》《中国存力白皮书》《中国运力白皮书》在首届算力大会上重磅发出
字节二面被问到mysql事务与锁问题,我蚌埠住了
2022 May 1 Mathematical Modeling Question C Explanation
R语言使用ggpubr包的ggsummarystats函数可视化柱状图(通过ggfunc参数设置)、在可视化图像的下方添加描述性统计结果表格、palette参数配置柱状图及统计数据的颜色
《机器学习理论到应用》电子书免费下载
解决错误:The package-lock.json file was created with an old version of npm
租房小程序登顶码云热门