FE01_OneHot: Application in Scala
One-hot encoding is a common way to handle categorical variables. How do you apply it in Scala? And how do you handle values that appear in the test set but never in the training set?
1 Data construction
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

val builder = SparkSession
  .builder()
  .appName("LR")
  .config("spark.executor.heartbeatInterval", "60s")
  .config("spark.network.timeout", "120s")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "512m")
  .config("spark.dynamicAllocation.enabled", false)
  .config("spark.sql.inMemoryColumnarStorage.compressed", true)
  .config("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
  .config("spark.sql.broadcastTimeout", 600)
  .config("spark.sql.autoBroadcastJoinThreshold", -1)
  .config("spark.sql.crossJoin.enabled", true)
  .master("local[*]")
val spark = builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
builder: org.apache.spark.sql.SparkSession.Builder = org.apache.spark.sql.SparkSession$Builder@...
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@...
import spark.implicits._
println("------------- dfTrain -------------")
var dfTrain = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", null, 0.2, 1),
  (3, 4.7, "b", "hi", 0.2, 0),
  (4, 4.6, "c", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTrain.show()

println("------------- dfTest -------------")
// The test set is just copied from the training data with a few values changed; it is only used for testing
var dfTest = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", "no", 0.2, 1),
  (3, 4.7, "a", "yes", 0.2, 0),
  (4, 4.6, "d", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTest.show()
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| null|0.2| 1|
| 3|4.7| b| hi|0.2| 0|
| 4|4.6| c|hello|0.2| 1|
+---+---+---+-----+---+-----+
------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2| x3| x4|label|
+---+---+---+-----+---+-----+
| 1|5.1| a|hello|0.2| 0|
| 2|4.9| b| no|0.2| 1|
| 3|4.7| a| yes|0.2| 0|
| 4|4.6| d|hello|0.2| 1|
+---+---+---+-----+---+-----+
dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
Note:
- The categorical features in dfTest contain values that never appear in dfTrain
- d in x2; no and yes in x3
2 StringIndexer transformation
StringIndexer encodes a string column into a column of label indices. The indices run from 0 to numLabels - 1 (here, the number of labels found in dfTrain) and are assigned by label frequency, with more frequent labels encoded first, so the most frequent label gets index 0. If the input column is numeric, it is first cast to string and then indexed.
Before one-hot encoding, we first apply the StringIndexer transformation.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain).show()
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|
+---+---+---+-----+---+-----+----------+----------+
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
- Unlike OneHotEncoderEstimator and VectorAssembler, StringIndexer can only process a single column at a time
- So above we simply map over the column names to build one stage per column
- Looking at the x2 encoding, the most frequent value b is indexed 0; in x3, hello is indexed 0
- The null in x3 is indexed 2.0, because setHandleInvalid("keep") assigns missing and unseen values the extra index numLabels (the sketch below shows how the same fitted indexers treat dfTest)
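Since the indexers are fitted on dfTrain, it is worth checking what "keep" does when the fitted stages transform dfTest; a minimal sketch (the expected indices follow from the frequency ordering above, and agree with the section 3 output further down):

val idxModel = new Pipeline().setStages(indexers).fit(dfTrain)
idxModel.transform(dfTest).select("id", "x2", "x2_indexed", "x3", "x3_indexed").show()
// Unseen values are mapped to the extra "keep" index:
//   x2 = "d"         -> 3.0 (training labels: b=0, a=1, c=2)
//   x3 = "no"/"yes"  -> 2.0 (training labels: hello=0, hi=1)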
3 OneHotEncoderEstimator
Next, apply one-hot encoding to the indexed data.
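A minimal sketch of this step, reusing columns and indexers from section 2 (the name dfTrain1 matches the REPL echo under the table; treat this as an illustrative reconstruction rather than the exact original code):

val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))
// Fit indexers + encoder on dfTrain and keep the transformed result around
val dfTrain1 = new Pipeline().setStages(indexers ++ Array(onehoter))
  .fit(dfTrain).transform(dfTrain)
dfTrain1.show()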
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|
| 2|4.9| b| null|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|
| 3|4.7| b| hi|0.2| 0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|
| 4|4.6| c|hello|0.2| 1| 2.0| 0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
- Shown this way, the data is not very intuitive; we are more used to a plain row-and-column layout
- In x2_onehot, the 3 in (3,[1],[1.0]) means the vector has 3 slots in total (a, b, c), the second part lists which slot is set, and the third part is always 1.0 (see the small check below)
- The id=2 row of x3_onehot is special: both lists are empty, (2,[],[]). null does not get a slot of its own, so the vector has only 2 slots, for indices (0, 1)
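To make the sparse notation concrete, here is a small illustrative check using Spark's Vectors factory (not part of the original post):

import org.apache.spark.ml.linalg.Vectors

Vectors.sparse(3, Array(1), Array(1.0)).toDense                   // (3,[1],[1.0]) -> [0.0,1.0,0.0]
Vectors.sparse(2, Array.empty[Int], Array.empty[Double]).toDense  // (2,[],[])     -> [0.0,0.0]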
Below, the one-hot vectors are converted to arrays so the data reads more intuitively (a sketch of the conversion follows these notes):
- You can see that x2_onehot expands into three columns, one for each x2_indexed value in [0,1,2]; a slot is 1.0 when the index matches and 0.0 otherwise
- For id=2's x3, the slots are indeed all 0.0; in other words, one-hot encoding simply ignores null
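A minimal sketch of that conversion, assuming the dfTrain1 built above (it uses the same Vector-to-array trick as the test-set code further down):

import org.apache.spark.ml.linalg.Vector

dfTrain1.select("x2_onehot", "x3_onehot")
  .map(r => (r.getAs[Vector](0).toArray.mkString(","), r.getAs[Vector](1).toArray.mkString(",")))
  .toDF("x2_onehot_arr", "x3_onehot_arr")
  .show(false)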
If the test set contains a value that never appeared in the training set, one-hot encoding automatically leaves every slot of that vector at 0, which is exactly how null is handled. Let's look at an example.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble))
val p1 = pipeline.fit(dfTrain)
p1.transform(dfTest).show()

// Convert the features column to an array, then print it
import org.apache.spark.ml.linalg.Vector  // needed for the cast below; missing from the imports at the top
println("--------------------------------")
p1.transform(dfTest).select("features")
  .map(x => x(0).asInstanceOf[Vector].toArray)
  .take(4)
  .foreach(x => println(x.mkString(",")))
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2| x3| x4|label|x2_indexed|x3_indexed| x2_onehot| x3_onehot| features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| 1|5.1| a|hello|0.2| 0| 1.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
| 2|4.9| b| no|0.2| 1| 0.0| 2.0|(3,[0],[1.0])| (2,[],[])|(7,[0,1,2],[4.9,0...|
| 3|4.7| a| yes|0.2| 0| 1.0| 2.0|(3,[1],[1.0])| (2,[],[])|(7,[0,1,3],[4.7,0...|
| 4|4.6| d|hello|0.2| 1| 3.0| 0.0| (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
Look at the features column first:
- id=1 is printed as a plain dense vector; id=2, 3 and 4 use the sparse shorthand
- Take id=3 as an example: (7,[0,1,3],[4.7,0.2,1.0]) means the vector has 7 elements in total, positions [0,1,3] hold [4.7, 0.2, 1.0], and every other position is 0
- In id=3's x3_onehot, yes never appeared in the training set, so all of its one-hot slots are 0
The printed arrays at the end show concretely how the data is actually stored.
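If you want to know which original feature each of the 7 slots corresponds to, you can inspect the ML attribute metadata that VectorAssembler attaches to the features column; a sketch (assuming the metadata survives the pipeline, which it normally does):

import org.apache.spark.ml.attribute.AttributeGroup

val group = AttributeGroup.fromStructField(p1.transform(dfTest).schema("features"))
// Print slot index -> feature name; expect x1, x4 and one entry per one-hot slot of x2/x3
group.attributes.foreach(_.foreach { a =>
  println(s"${a.index.getOrElse(-1)} -> ${a.name.getOrElse("?")}")
})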
4 Pipeline
Finally, use a Pipeline to string the whole processing flow together and run a quick demo.
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble, lr))
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
// Fit the full pipeline on the training set (implied by the usage below)
val pipeModel = pipeline.fit(dfTrain)

println("-------------- Train -------------- ")
pipeModel.transform(dfTrain)
  .select("id", "features", "rawPrediction", "probability", "prediction")
  .show()
println("-------------- Test -------------- ")
pipeModel.transform(dfTest)
  .select("id", "features", "rawPrediction", "probability", "prediction")
  .show()
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...| 0.0|
| 4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id| features| rawPrediction| probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
| 1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...| 0.0|
| 2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...| 1.0|
| 3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...| 0.0|
| 4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...| 1.0|
+---+--------------------+--------------------+--------------------+----------+
To sum up:
- When one-hot encoding in Spark/Scala, you don't need to worry about the test set containing values the training set never saw: with setHandleInvalid("keep") they simply come out as all-zero vectors
- Wrapping the stages in a Pipeline makes the whole flow much more convenient
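The unseen-value behavior above hinges on setHandleInvalid("keep"). For reference, a quick sketch of the standard StringIndexer alternatives (all three are documented options of the transformer):

new StringIndexer().setInputCol("x2").setOutputCol("x2_indexed").setHandleInvalid("error") // default: throw on unseen or null values
new StringIndexer().setInputCol("x2").setOutputCol("x2_indexed").setHandleInvalid("skip")  // silently drop such rows at transform time
new StringIndexer().setInputCol("x2").setOutputCol("x2_indexed").setHandleInvalid("keep")  // map them to the extra index numLabels, as used throughout this post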
Ref
- https://spark.apache.org/docs/2.4.4/api/scala/index.html#org.apache.spark.ml.feature.StringIndexer
2020-04-02, Jiulonghu, Jiangning District, Nanjing