FE01_OneHot-Scala Application
OneHot encoding is a common way to handle categorical variables. How do we apply it in Scala (Spark ML)? And what happens when a value appears in the test set but not in the training set?
1 Data Construction
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

val builder = SparkSession
  .builder()
  .appName("LR")
  .config("spark.executor.heartbeatInterval", "60s")
  .config("spark.network.timeout", "120s")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "512m")
  .config("spark.dynamicAllocation.enabled", false)
  .config("spark.sql.inMemoryColumnarStorage.compressed", true)
  .config("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
  .config("spark.sql.broadcastTimeout", 600)
  .config("spark.sql.autoBroadcastJoinThreshold", -1)
  .config("spark.sql.crossJoin.enabled", true)
  .master("local[*]")
val spark = builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
builder: org.apache.spark.sql.SparkSession.Builder = org.apache.spark.sql.SparkSession$Builder@...
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@...
import spark.implicits._
println("------------- dfTrain -------------")
var dfTrain = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", null, 0.2, 1),
  (3, 4.7, "b", "hi", 0.2, 0),
  (4, 4.6, "c", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTrain.show()

println("------------- dfTest -------------")
var dfTest = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", "no", 0.2, 1),
  (3, 4.7, "a", "yes", 0.2, 0),
  (4, 4.6, "d", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
// The test set is just a tweaked copy of the training set, used only for testing
dfTest.show()
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| null|0.2| 1|
| 3|4.7| b| hi|0.2| 0|
| 4|4.6| c|hello|0.2| 1|
+---+---+---+-----+---+-----+
------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| no|0.2| 1|
| 3|4.7| a| yes|0.2| 0|
| 4|4.6| d|hello|0.2| 1|
+---+---+---+-----+---+-----+
dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
Note:
- dfTest's categorical features contain values that never appear in dfTrain
- d in x2, and no / yes in x3
2 StringIndexer Transformation
StringIndexer encodes a string column into a column of label indices. The indices range from 0 to the number of labels (the labels found in dfTrain here) and are assigned by label frequency: the more frequent a label, the earlier it is encoded, so the most frequent label gets index 0. If the input column is numeric, it is first cast to string and then indexed.
Before doing onehot we need to perform this StringIndexer transformation.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain).show()
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|
+---+---+---+-----+---+-----+----------+----------+
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
- Unlike the commonly used OneHotEncoderEstimator and VectorAssembler, StringIndexer can only process a single column
- Mapping over the columns with map, as above, keeps this simple
- Looking at the encoding of x2, the most frequent value b is marked 0; x3's hello is encoded as 0
- The missing (null) value of x3 is indexed as 2.0, i.e. it falls into the extra bucket created by handleInvalid("keep") (see the quick check below)
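A quick check of that behavior, as a sketch (this listing is an assumption, not from the original post): applying the indexers fitted on dfTrain to dfTest shows unseen labels landing in the same extra bucket.

// Sketch (assumed): index dfTest with StringIndexers fitted on dfTrain.
// With setHandleInvalid("keep"), unseen labels get the extra index numLabels:
// x2's d -> 3.0, and x3's no/yes share the invalid bucket 2.0 with null.
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTest).show()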
3 OneHotEncoderEstimator
Next, apply onehot encoding to the indexed data.
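A minimal sketch of this step, reusing the columns and indexers from section 2 (the variable name dfTrain1 matches the output below; the exact original listing is not shown):

// Sketch (assumed): chain the StringIndexers with a OneHotEncoderEstimator
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))
val dfTrain1 = new Pipeline()
  .setStages(indexers ++ Array(onehoter))
  .fit(dfTrain)
  .transform(dfTrain)
dfTrain1.show()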
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
- Viewed like this, the sparse output is not very intuitive yet; we are used to data presented in rows and columns (see the one-liner below)
- In x2_onehot, the 3 means there are 3 values in total (a, b, c); the second part is the index that is set; the third part is always 1.0
- In x3_onehot, id=2 is special: both lists are empty []. null does not count as a category, so there are only 2 values, (0, 1); since OneHotEncoderEstimator drops the last index by default (dropLast=true), the "keep" bucket for null maps to the all-zero vector
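For instance, the sparse notation can be expanded by hand (a one-line sketch; Vectors here is from org.apache.spark.ml.linalg, not used in the original listings):

import org.apache.spark.ml.linalg.Vectors
// (3,[1],[1.0]) reads: length 3, position 1 holds 1.0, everything else is 0
println(Vectors.sparse(3, Array(1), Array(1.0)).toArray.mkString(","))  // prints 0.0,1.0,0.0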
Below, the data is transformed again, converting each vector to an array so it reads more intuitively (a sketch follows this list):
- You can see x2_onehot generates three columns following x2_indexed's [0,1,2]; the value is 1.0 at the hit index and 0 otherwise
- For the x3 of id=2, the values are indeed all 0, i.e. onehot effectively ignores null values
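A sketch of that conversion (assumed, using the dfTrain1 from the sketch above):

import org.apache.spark.ml.linalg.Vector
// Expand each onehot vector into a dense comma-separated string, one row at a time
dfTrain1.select("x2_onehot", "x3_onehot")
  .map(r => (r.getAs[Vector](0).toArray.mkString(","), r.getAs[Vector](1).toArray.mkString(",")))
  .show(false)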
If the test set contains values that never appeared in the training set, all of their onehot columns are automatically set to 0, exactly like the handling of null. Let's look at an example.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble))
val p1 = pipeline.fit(dfTrain)
p1.transform(dfTest).show()

// Convert the features column to arrays and print them
import org.apache.spark.ml.linalg.Vector  // needed for the cast below
println("--------------------------------")
p1.transform(dfTest).select("features").map(x => x(0).asInstanceOf[Vector].toArray).take(4).foreach(x => println(x.mkString(",")))
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot| features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
| 2|4.9| b| no|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|(7,[0,1,2],[4.9,0...|
| 3|4.7| a| yes|0.2| 0| 1.0| 2.0|(3,[1],[1.0])| (2,[],[])|(7,[0,1,3],[4.7,0...|
| 4|4.6| d|hello|0.2| 1| 3.0| 0.0| (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
Look at the features column first:
- id=1 needs no explanation; id=2, 3, 4 are all in sparse shorthand
- Take id=3: the 7 means there are 7 elements in total, of which positions [0,1,3] hold [4.7,0.2,1.0] and all the others are 0
- In id=3's x3_onehot, yes is not present in the training set, so after onehot everything is 0
The plain-array printout above shows concretely how the data is laid out.
4 Pipeline
Use a Pipeline to string the data-processing steps together; here is a quick demo.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble, lr))
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
// Fit the full pipeline on the training data
val pipeModel = pipeline.fit(dfTrain)
println("-------------- Train -------------- ")
pipeModel.transform(dfTrain).select("id", "features", "rawPrediction", "probability", "prediction").show()
println("-------------- Test -------------- ")
pipeModel.transform(dfTest).select("id", "features", "rawPrediction", "probability", "prediction").show()
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...| 0.0|
| 4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...| 0.0|
| 4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
To summarize:
- With Scala (Spark ML), onehot needs no special handling for values that appear in the test set but not in the training set: with handleInvalid("keep") they simply become all-zero vectors
- Using a Pipeline makes the whole flow more convenient
2020-04-02, Jiulong Lake, Jiangning District, Nanjing