
FE01_OneHot-Scala Application

2022-08-04 18:00:00 51CTO


  OneHot encoding is a common way to handle categorical variables. How is it applied in Scala with Spark ML? And if a value appears in the test set but not in the training set, how is it handled?

1 Building the data

      
      
```scala
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
// OneHotEncoderEstimator exists in Spark 2.3/2.4; in Spark 3.x it was renamed OneHotEncoder
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.ml.feature.VectorAssembler
// Spark's ML vector type; without this import, `Vector` means scala.collection.immutable.Vector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

val builder = SparkSession
  .builder()
  .appName("LR")
  .config("spark.executor.heartbeatInterval", "60s")
  .config("spark.network.timeout", "120s")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "512m")
  .config("spark.dynamicAllocation.enabled", false)
  .config("spark.sql.inMemoryColumnarStorage.compressed", true)
  .config("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
  .config("spark.sql.broadcastTimeout", 600)
  .config("spark.sql.autoBroadcastJoinThreshold", -1)
  .config("spark.sql.crossJoin.enabled", true)
  .master("local[*]")
val spark = builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._
```
      
      
      
      
```scala
println("------------- dfTrain -------------")
val dfTrain = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", null, 0.2, 1),
  (3, 4.7, "b", "hi", 0.2, 0),
  (4, 4.6, "c", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTrain.show()

println("------------- dfTest -------------")
// The test set is just a modified copy of the training set, used only for testing
val dfTest = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", "no", 0.2, 1),
  (3, 4.7, "a", "yes", 0.2, 0),
  (4, 4.6, "d", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTest.show()
```
      
      
```
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2|   x3| x4|label|
+---+---+---+-----+---+-----+
|  1|5.1|  a|hello|0.2|    0|
|  2|4.9|  b| null|0.2|    1|
|  3|4.7|  b|   hi|0.2|    0|
|  4|4.6|  c|hello|0.2|    1|
+---+---+---+-----+---+-----+

------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2|   x3| x4|label|
+---+---+---+-----+---+-----+
|  1|5.1|  a|hello|0.2|    0|
|  2|4.9|  b|   no|0.2|    1|
|  3|4.7|  a|  yes|0.2|    0|
|  4|4.6|  d|hello|0.2|    1|
+---+---+---+-----+---+-----+

dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
```

Note:

  • dfTest contains categorical values that do not appear in dfTrain:
  • d in x2, and no and yes in x3
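A quick way to spot such values before fitting anything is a per-column set difference. The sketch below is plain Scala (no Spark needed), with the column values copied by hand from dfTrain and dfTest; the `unseen` helper is hypothetical, not part of Spark.

```scala
// Hypothetical helper: categories that occur in `test` but never in `train`.
// Nulls are ignored because StringIndexer does not count them as labels.
def unseen(train: Seq[String], test: Seq[String]): Set[String] =
  test.filter(_ != null).toSet -- train.filter(_ != null).toSet

// The x2 and x3 columns of dfTrain and dfTest, copied by hand
val x2Train = Seq("a", "b", "b", "c")
val x2Test  = Seq("a", "b", "a", "d")
val x3Train = Seq("hello", null, "hi", "hello")
val x3Test  = Seq("hello", "no", "yes", "hello")

println(unseen(x2Train, x2Test)) // Set(d)
println(unseen(x3Train, x3Test)) // Set(no, yes)
```

On the real DataFrames, the same check can be written with `Dataset.except`, e.g. `dfTest.select("x2").except(dfTrain.select("x2")).show()`.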

2 StringIndexer transformation

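  StringIndexer encodes a string column into label indices ranging from 0 to numLabels - 1, where the labels are those of the data it is fitted on (here dfTrain). By default the indices follow label frequency: more frequent labels get smaller indices, so the most frequent label is 0. Numeric input columns are first cast to strings and then encoded. The setHandleInvalid("keep") call below matters for everything that follows: nulls and labels not seen during fitting are put into one extra index, numLabels, instead of raising an error (the default setting is "error").
  Before one-hot encoding, the categorical columns have to go through StringIndexer first.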

      
      
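```scala
val columns = Array("x2", "x3")
// StringIndexer (Spark 2.4) takes a single input column, so build one indexer per column.
// handleInvalid = "keep" maps nulls and labels unseen during fit to the extra index numLabels.
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain).show()
```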
      
      
```
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|
|  2|4.9|  b| null|0.2|    1|       0.0|       2.0|
|  3|4.7|  b|   hi|0.2|    0|       0.0|       1.0|
|  4|4.6|  c|hello|0.2|    1|       2.0|       0.0|
+---+---+---+-----+---+-----+----------+----------+

columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
```
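  • Unlike the commonly used OneHotEncoderEstimator and VectorAssembler, StringIndexer in Spark 2.4 processes only one column at a time (multi-column support via setInputCols arrived in Spark 3.0)
  • Mapping over the column names, as above, keeps this simple
  • In x2, the most frequent value b is encoded as 0; in x3, hello is encoded as 0
  • The missing value in x3 is encoded as 2.0, i.e. numLabels, because handleInvalid is "keep"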
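The index assignment above can be mimicked in a few lines of plain Scala. This is a sketch, not Spark's implementation: `fitLabels` and `indexFor` are hypothetical helpers, ties in frequency are broken alphabetically (an assumption that happens to reproduce a = 1, c = 2 here; Spark 2.4 does not promise a tie order), and nulls plus unseen values share the extra index numLabels, as with handleInvalid = "keep".

```scala
// Sketch of StringIndexer's default "frequencyDesc" ordering.
// Assumption: ties are broken alphabetically; nulls are not counted as labels.
def fitLabels(train: Seq[String]): Seq[String] =
  train.filter(_ != null)
    .groupBy(identity)
    .toSeq
    .sortBy { case (label, hits) => (-hits.size, label) } // most frequent first
    .map(_._1)

// handleInvalid = "keep": nulls and unseen labels share the extra index labels.size
def indexFor(labels: Seq[String], value: String): Double = {
  val i = labels.indexOf(value)
  if (i >= 0) i.toDouble else labels.size.toDouble
}

val x2Labels = fitLabels(Seq("a", "b", "b", "c"))           // b, a, c
val x3Labels = fitLabels(Seq("hello", null, "hi", "hello")) // hello, hi

// Training rows of x2 and x3, then the test rows of x2
println(Seq("a", "b", "b", "c").map(indexFor(x2Labels, _)).mkString(", "))         // 1.0, 0.0, 0.0, 2.0
println(Seq("hello", null, "hi", "hello").map(indexFor(x3Labels, _)).mkString(", ")) // 0.0, 2.0, 1.0, 0.0
println(Seq("a", "b", "a", "d").map(indexFor(x2Labels, _)).mkString(", "))         // 1.0, 0.0, 1.0, 3.0
```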

3 OneHotEncoderEstimator

Next, one-hot encode the indexed columns:

      
      
```scala
val columnsIndexed = Array("x2", "x3")
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columnsIndexed.map(_ + "_indexed"))
  .setOutputCols(columnsIndexed.map(_ + "_onehot"))
```
      
      
```
columnsIndexed: Array[String] = Array(x2, x3)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_7d9f19c5de00
```
      
      
```scala
val dfTrain1 = new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain)
onehoter.fit(dfTrain1).transform(dfTrain1).show()
```
      
      
```
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|    x2_onehot|    x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|(3,[1],[1.0])|(2,[0],[1.0])|
|  2|4.9|  b| null|0.2|    1|       0.0|       2.0|(3,[0],[1.0])|    (2,[],[])|
|  3|4.7|  b|   hi|0.2|    0|       0.0|       1.0|(3,[0],[1.0])|(2,[1],[1.0])|
|  4|4.6|  c|hello|0.2|    1|       2.0|       0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+

dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
```
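  • The vectors are not very intuitive yet; we are used to seeing data laid out in rows and columns
  • They use sparse notation (size, [indices], [values]). In x2_onehot the size is 3, one position per training value (a, b, c); the second element is the index of the active position and the third holds its value, always 1.0. Strictly, the encoder sees four categories for x2, including the extra index 3 reserved by handleInvalid = "keep", and by default drops the last one (dropLast = true), which leaves 3 positions
  • In x3_onehot, id=2 is special: both lists are empty ([]), so the vector is all zeros. The size is 2 (hello, hi): the null is mapped to index 2, which is the dropped last category, so it gets no position of its own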

Below, the vectors are converted to arrays, which makes them easier to read.

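  • x2_onehot expands the values [0, 1, 2] of x2_indexed into three columns: the column matching the index is 1 and the others are 0
  • For id=2, x3_onehot is indeed [0.0, 0.0]: the null gets no column of its own after one-hot encoding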
      
      
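```scala
// Convert the two one-hot vectors to plain arrays so they print as dense rows
// (Vector here is org.apache.spark.ml.linalg.Vector)
onehoter.fit(dfTrain1).transform(dfTrain1)
  .map(row => (row.getAs[Vector]("x2_onehot").toArray, row.getAs[Vector]("x3_onehot").toArray))
  .show()
```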
      
      
```
+---------------+----------+
|             _1|        _2|
+---------------+----------+
|[0.0, 1.0, 0.0]|[1.0, 0.0]|
|[1.0, 0.0, 0.0]|[0.0, 0.0]|
|[1.0, 0.0, 0.0]|[0.0, 1.0]|
|[0.0, 0.0, 1.0]|[1.0, 0.0]|
+---------------+----------+
```
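Those arrays can be reproduced with a small plain-Scala sketch of one-hot encoding with dropLast. The assumption, consistent with the output above, is that the encoder counts the extra "keep" index as a category: x2 has 4 categories (b, a, c, unknown) and x3 has 3 (hello, hi, unknown). Dropping the last one leaves sizes 3 and 2, and any value mapped to the dropped slot becomes an all-zero vector. `oneHot` and `fmt` are hypothetical helpers.

```scala
// Sketch of one-hot encoding with dropLast.
// numCategories is assumed to include the extra "keep" index (the unknown slot).
def oneHot(index: Int, numCategories: Int, dropLast: Boolean = true): Array[Double] = {
  val size = if (dropLast) numCategories - 1 else numCategories
  val v = Array.fill(size)(0.0)
  if (index < size) v(index) = 1.0 // indices >= size (the dropped last slot) stay all zeros
  v
}

def fmt(v: Array[Double]): String = v.mkString("[", ", ", "]")

// x2: b, a, c + unknown -> 4 categories; x3: hello, hi + unknown -> 3 categories
println(fmt(oneHot(1, 4)))                   // x2 = a    -> [0.0, 1.0, 0.0]
println(fmt(oneHot(3, 4)))                   // x2 = d    -> [0.0, 0.0, 0.0]
println(fmt(oneHot(2, 3)))                   // x3 = null -> [0.0, 0.0]
println(fmt(oneHot(2, 3, dropLast = false))) // x3 = null -> [0.0, 0.0, 1.0]
```

In Spark, the corresponding switch is `OneHotEncoderEstimator.setDropLast(false)`: unseen values and nulls would then get a column of their own instead of an all-zero vector.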

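If the test set contains values that never appear in the training set, all of their one-hot columns become 0, exactly as with the null above. This follows from the settings used here: StringIndexer with handleInvalid = "keep" maps both unseen values and nulls to the extra index numLabels, and OneHotEncoderEstimator drops the last category by default (dropLast = true), so that index is encoded as an all-zero vector. Here is an example: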

      
      
```scala
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(_ + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble))
// Fit every stage on the training set only, then apply the fitted chain to the test set
val p1 = pipeline.fit(dfTrain)
p1.transform(dfTest).show()

// Convert the features column to arrays and print them
println("--------------------------------")
p1.transform(dfTest).select("features")
  .map(row => row.getAs[Vector](0).toArray)
  .take(4)
  .foreach(arr => println(arr.mkString(",")))
```
      
      
```
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|    x2_onehot|    x3_onehot|            features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
|  2|4.9|  b|   no|0.2|    1|       0.0|       2.0|(3,[0],[1.0])|    (2,[],[])|(7,[0,1,2],[4.9,0...|
|  3|4.7|  a|  yes|0.2|    0|       1.0|       2.0|(3,[1],[1.0])|    (2,[],[])|(7,[0,1,3],[4.7,0...|
|  4|4.6|  d|hello|0.2|    1|       3.0|       0.0|    (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+

--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0

columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
```

First, the features column:

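  • id=1 is shown as a dense vector, while id=2, 3 and 4 are shown in sparse notation: VectorAssembler stores each assembled vector in whichever form takes less space, and these three rows have only three non-zeros out of seven
  • Take id=3: 7 is the vector length, positions [0,1,3] hold [4.7, 0.2, 1.0], and every other position is 0
  • In x3_onehot for id=3, yes does not appear in the training set, so its one-hot encoding is all zeros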

The four comma-separated rows printed after the table spell out the same vectors position by position, which shows directly how the data is stored.
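The features column is simply x1 and x4 followed by the two one-hot blocks, 2 + 3 + 2 = 7 positions. The plain-Scala sketch below rebuilds two test rows, converts the sparse triple (7,[0,1,3],[4.7,0.2,1.0]) back to a dense array, and applies a size rule for choosing dense or sparse storage. That rule (sparse when 1.5 × (nnz + 1) < size) is my reading of Spark's `Vector.compressed`, so treat it as an assumption; `assembleRow`, `toDense` and `prefersSparse` are hypothetical helpers.

```scala
// Sketch of the features layout: x1, x4, then the x2 and x3 one-hot blocks
def assembleRow(x1: Double, x4: Double, x2Hot: Array[Double], x3Hot: Array[Double]): Array[Double] =
  Array(x1, x4) ++ x2Hot ++ x3Hot

// Sparse notation (size, [indices], [values]) back to a dense array
def toDense(size: Int, indices: Array[Int], values: Array[Double]): Array[Double] = {
  val out = Array.fill(size)(0.0)
  indices.zip(values).foreach { case (i, x) => out(i) = x }
  out
}

// Assumed size rule for dense vs sparse storage: sparse when 1.5 * (nnz + 1) < size
def prefersSparse(v: Array[Double]): Boolean = {
  val nnz = v.count(_ != 0.0)
  1.5 * (nnz + 1.0) < v.length
}

// Test rows 1 and 3: x2 = a -> [0, 1, 0]; x3 = hello -> [1, 0]; x3 = yes (unseen) -> [0, 0]
val row1 = assembleRow(5.1, 0.2, Array(0.0, 1.0, 0.0), Array(1.0, 0.0))
val row3 = assembleRow(4.7, 0.2, Array(0.0, 1.0, 0.0), Array(0.0, 0.0))

println(row1.mkString(","))  // 5.1,0.2,0.0,1.0,0.0,1.0,0.0
println(prefersSparse(row1)) // false: 4 non-zeros, 1.5 * 5 = 7.5 is not < 7
println(prefersSparse(row3)) // true: 3 non-zeros, 1.5 * 4 = 6.0 < 7
println(toDense(7, Array(0, 1, 3), Array(4.7, 0.2, 1.0)).sameElements(row3)) // true
```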

4 Pipeline

Use a Pipeline to chain the processing steps together and run a quick demo, this time with a logistic regression as the final stage:

      
      
```scala
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer().setInputCol(colName).setOutputCol(colName + "_indexed").setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(_ + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val lr = new LogisticRegression().
  setMaxIter(10).
  setRegParam(0.01)
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble, lr))
```
      
      
```
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
```
      
      
```scala
val pipeModel = pipeline.fit(dfTrain)
```
      
      
```
pipeModel: org.apache.spark.ml.PipelineModel = pipeline_e4e6f0fbeabe
```
      
      
```scala
println("-------------- Train -------------- ")
pipeModel.transform(dfTrain).select("id", "features", "rawPrediction", "probability", "prediction").show()
println("-------------- Test -------------- ")
pipeModel.transform(dfTest).select("id", "features", "rawPrediction", "probability", "prediction").show()
```
      
      
```
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id|            features|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
|  1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...|       0.0|
|  2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...|       1.0|
|  3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...|       0.0|
|  4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...|       1.0|
+---+--------------------+--------------------+--------------------+----------+

-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id|            features|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
|  1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...|       0.0|
|  2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...|       1.0|
|  3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...|       0.0|
|  4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...|       1.0|
+---+--------------------+--------------------+--------------------+----------+
```
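Test rows 1 and 2 get exactly the same scores as training rows 1 and 2, because test row 2's unseen "no" is encoded just like the training row's null. Rows 3 and 4 score differently because their x2/x3 values changed, with yes and d encoded as zero blocks. The probability column can be checked against rawPrediction: for binary logistic regression, Spark ML stores rawPrediction as [-m, m] for the margin m and applies the logistic function element-wise, so probability(0) = sigmoid(rawPrediction(0)). A small plain-Scala check using the truncated values printed above (`sigmoid` is a hypothetical helper):

```scala
// Logistic function; for binary logistic regression, probability(0) = sigmoid(rawPrediction(0))
def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

// rawPrediction(0) of the four test rows, copied from the truncated show() output
val raw0 = Seq(3.14394279953138, -2.6533348312718, 2.40266618214349, -0.7089322949842)

// Compare each result with the probability column of the Test table
raw0.foreach { z =>
  println(f"rawPrediction(0) = $z%.4f  ->  P(label = 0) = ${sigmoid(z)}%.4f")
}
```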

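To sum up:

  • With handleInvalid = "keep" on StringIndexer, values that appear in the test set but not in the training set need no special handling: combined with the encoder's default dropLast = true they become all-zero one-hot vectors, just like nulls (with the default handleInvalid = "error", StringIndexer would throw instead)
  • A Pipeline is more convenient: the whole chain is fitted once on the training set and then applied unchanged to the test set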
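What makes the Pipeline convenient is its fit/transform contract: stages are fitted in order, each on the output of the stages before it, and the fitted chain reapplies the state learned from the training set (such as StringIndexer's label list) to new data. The toy version below is plain Scala and deliberately simplified; `ToyEstimator`, `ToyTransformer`, `ToyIndexer` and the other names are hypothetical stand-ins, not Spark's classes.

```scala
// Toy model of the fit/transform contract behind a Pipeline (hypothetical names, not Spark's API)
type Rows = Seq[Map[String, Any]]

trait ToyTransformer { def transform(rows: Rows): Rows }
trait ToyEstimator { def fit(rows: Rows): ToyTransformer }

// Fit stages in order; each stage is fitted on the output of the previously fitted stages
def fitPipeline(stages: Seq[ToyEstimator], train: Rows): Seq[ToyTransformer] =
  stages.foldLeft((train, Seq.empty[ToyTransformer])) { case ((data, models), stage) =>
    val model = stage.fit(data)
    (model.transform(data), models :+ model)
  }._2

// Apply the fitted models, in order, to new data
def transformAll(models: Seq[ToyTransformer], rows: Rows): Rows =
  models.foldLeft(rows)((acc, model) => model.transform(acc))

// Fitted indexer: nulls and unseen values get index labels.size ("keep"-style)
final case class ToyIndexerModel(col: String, labels: Seq[String]) extends ToyTransformer {
  def transform(rows: Rows): Rows = rows.map { row =>
    val i = labels.indexOf(row(col))
    row.updated(s"${col}_indexed", (if (i >= 0) i else labels.size).toDouble)
  }
}

// Indexer estimator: learns a frequency-ordered label list (ties alphabetical) from the data it sees
final case class ToyIndexer(col: String) extends ToyEstimator {
  def fit(rows: Rows): ToyTransformer = {
    val labels = rows.map(_(col)).collect { case s: String => s }
      .groupBy(identity).toSeq
      .sortBy { case (label, hits) => (-hits.size, label) }
      .map(_._1)
    ToyIndexerModel(col, labels)
  }
}

val train: Rows = Seq(Map("x2" -> "a"), Map("x2" -> "b"), Map("x2" -> "b"), Map("x2" -> "c"))
val test: Rows = Seq(Map("x2" -> "a"), Map("x2" -> "d"))

val models = fitPipeline(Seq(ToyIndexer("x2")), train)
println(transformAll(models, test).map(_("x2_indexed")).mkString(", ")) // 1.0, 3.0
```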


2020-04-02, Jiulong Lake, Jiangning District, Nanjing


Original site

Copyright notice
This article was created by [51CTO]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/216/202208041755447592.html