FE01_OneHot-Scala Application
OneHot encoding is a common way to handle categorical variables. How do we apply it in Scala (Spark ML)? And if a value appears in the test set but not in the training set, how is it handled?
1 Data Construction
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

val builder = SparkSession
  .builder()
  .appName("LR")
  .config("spark.executor.heartbeatInterval", "60s")
  .config("spark.network.timeout", "120s")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "512m")
  .config("spark.dynamicAllocation.enabled", false)
  .config("spark.sql.inMemoryColumnarStorage.compressed", true)
  .config("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
  .config("spark.sql.broadcastTimeout", 600)
  .config("spark.sql.autoBroadcastJoinThreshold", -1)
  .config("spark.sql.crossJoin.enabled", true)
  .master("local[*]")
val spark = builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

import spark.implicits._
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
builder: org.apache.spark.sql.SparkSession.Builder = org.apache.spark.sql.SparkSession$Builder@...
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@...
import spark.implicits._
println( "------------- dfTrain -------------")
var dfTrain = Seq(
( 1, 5.1, "a", "hello", 0.2, 0),
( 2, 4.9, "b", null, 0.2, 1),
( 3, 4.7, "b", "hi", 0.2, 0),
( 4, 4.6, "c", "hello", 0.2, 1)
). toDF( "id", "x1", "x2", "x3", "x4", "label")
dfTrain. show()
println( "------------- dfTest -------------")
var dfTest = Seq(
( 1, 5.1, "a", "hello", 0.2, 0),
( 2, 4.9, "b", "no", 0.2, 1),
( 3, 4.7, "a", "yes", 0.2, 0),
( 4, 4.6, "d", "hello", 0.2, 1)
). toDF( "id", "x1", "x2", "x3", "x4", "label")
// Test set directlycopy就行了,仅用来测试
dfTest. show()
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| null|0.2| 1|
| 3|4.7| b| hi|0.2| 0|
| 4|4.6| c|hello|0.2| 1|
+---+---+---+-----+---+-----+
------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| no|0.2| 1|
| 3|4.7| a| yes|0.2| 0|
| 4|4.6| d|hello|0.2| 1|
+---+---+---+-----+---+-----+
dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
Note:
- the categorical features of dfTest contain values that never appear in dfTrain
- d in x2; no and yes in x3
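A small sketch to verify this, using except on the dfTrain/dfTest defined above (this check is an addition, not part of the original walkthrough):

// List the values that appear in the test set but not in the training set.
dfTest.select("x2").except(dfTrain.select("x2")).show()  // d
dfTest.select("x3").except(dfTrain.select("x3")).show()  // no, yes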
2 StringIndexer Transformation
StringIndexer encodes a string column into label indices. The indices range from 0 up to the number of labels (the labels here being the values seen in dfTrain), assigned by label frequency: the more frequent a label, the earlier it is encoded, so the most frequent label gets index 0. If the input column is numeric, it is first cast to string and then indexed.
Before doing onehot we need to apply the StringIndexer transformation.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain).show()
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|
+---+---+---+-----+---+-----+----------+----------+
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
- Unlike the commonly used OneHotEncoderEstimator and VectorAssembler, StringIndexer can only process a single column at a time
- Mapping over the columns, as above, keeps this simple
- Looking at the encoding of x2, the most frequent value b is marked 0; in x3, hello is encoded as 0
- The missing (null) value in x3 is marked 2.0
3 OneHotEncoderEstimator
Next, apply onehot encoding to the indexed data:
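A minimal sketch of this step, reusing the columns and indexers from section 2 (dfTrain1 is the assumed name of the result, matching the output below):

// One-hot encode the indexed columns (dropLast is true by default).
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))
val dfTrain1 = new Pipeline()
  .setStages(indexers ++ Array(onehoter))
  .fit(dfTrain)
  .transform(dfTrain)
dfTrain1.show()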
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
- Viewed like this the data is not very intuitive yet; we are used to seeing data laid out in rows and columns
- In x2_onehot, the 3 means there are 3 values in total (a, b, c); the second part tells which index is active, and the third part is always 1.0
- In x3_onehot, id=2 is a bit special: both parts are empty ([]). This means null does not get its own slot, i.e. the number of encoded values is 2, covering indices (0, 1) (OneHotEncoderEstimator drops the last category by default, so the "keep" index 2.0 maps to the all-zero vector)
Below, the data is transformed once more, converting each vector into an array so it reads more intuitively.
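A minimal sketch of that conversion, assuming the dfTrain1 from above:

import org.apache.spark.ml.linalg.Vector
// Flatten the sparse one-hot vectors into plain comma-separated arrays.
dfTrain1.select("x2_onehot", "x3_onehot")
  .map(r => (r.getAs[Vector](0).toArray.mkString(","),
             r.getAs[Vector](1).toArray.mkString(",")))
  .toDF("x2_onehot_arr", "x3_onehot_arr")
  .show(false)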
- You can see that x2_onehot generates three columns, following x2_indexed's values [0, 1, 2]; each slot is 1 when its index is hit, 0 otherwise
- For the x3 of id=2, every slot is indeed 0, i.e. onehot simply does not account for the null value
If the test set contains values that never appeared in the training set, all of their onehot columns automatically become 0, exactly the same treatment nulls get. Let's look at an example:
import org.apache.spark.ml.linalg.Vector  // needed for the cast when printing the vectors

val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble))
val p1 = pipeline.fit(dfTrain)
p1.transform(dfTest).show()

// Convert the features column to arrays, then print them
println("--------------------------------")
p1.transform(dfTest).select("features").map(x => x(0).asInstanceOf[Vector].toArray).take(4).foreach(x => println(x.mkString(",")))
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot| features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
| 2|4.9| b| no|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|(7,[0,1,2],[4.9,0...|
| 3|4.7| a| yes|0.2| 0| 1.0| 2.0|(3,[1],[1.0])| (2,[],[])|(7,[0,1,3],[4.7,0...|
| 4|4.6| d|hello|0.2| 1| 3.0| 0.0| (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
Look at the features column first:
- id=1 needs no explanation; id=2, 3, 4 are all printed in sparse shorthand
- Take id=3 as an example: the 7 means the vector has 7 elements in total, positions [0, 1, 3] hold [4.7, 0.2, 1.0], and everything else is 0
- In the x3_onehot of id=3, yes is not present in the training set, so after onehot all of its slots are 0
The array printout above illustrates the concrete storage layout at a glance.
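For illustration, the same sparse notation can be reproduced directly with org.apache.spark.ml.linalg.Vectors (a small sketch, not part of the original walkthrough):

import org.apache.spark.ml.linalg.Vectors
// (7,[0,1,3],[4.7,0.2,1.0]) means: size 7, nonzeros at positions 0, 1 and 3.
val v = Vectors.sparse(7, Array(0, 1, 3), Array(4.7, 0.2, 1.0))
println(v)                        // (7,[0,1,3],[4.7,0.2,1.0])
println(v.toArray.mkString(","))  // 4.7,0.2,0.0,1.0,0.0,0.0,0.0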
4 Pipeline
Use a Pipeline to string the whole processing flow together and run a simple demo:
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble, lr))
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
println( "-------------- Train -------------- ")
pipeModel. transform( dfTrain). select( "id", "features", "rawPrediction", "probability", "prediction"). show()
println( "-------------- Test -------------- ")
pipeModel. transform( dfTest). select( "id", "features", "rawPrediction", "probability", "prediction"). show()
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...| 0.0|
| 4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...| 0.0|
| 4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
To summarize:
- With Scala (Spark ML), onehot needs no special handling for values that appear in the test set but not in the training set
- Using a Pipeline makes the whole flow more convenient
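The behavior above hinges on the handleInvalid setting of StringIndexer; as a reference, a small sketch of the options (as documented for Spark 2.4):

// handleInvalid controls what StringIndexer does with unseen or null labels:
//   "error" (default): fail with an exception
//   "skip":            drop the affected rows
//   "keep":            assign them the extra index numLabels (used throughout this post);
//                      OneHotEncoderEstimator then drops that last category by default,
//                      so unseen values encode to the all-zero vector
val idxSketch = new StringIndexer()
  .setInputCol("x2")
  .setOutputCol("x2_indexed")
  .setHandleInvalid("keep")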
2020-04-02, Jiulonghu, Jiangning District, Nanjing