Spark Feature Engineering: One-hot and Multi-hot
2022-08-02 03:28:00 【Code_LT】
One-hot
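import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

/**
 * One-hot encoding example function
 * @param samples movie samples dataframe
 */
def oneHotEncoderExample(samples: DataFrame): Unit = {
  // The encoder expects a numeric category index, so cast the string movieId to an integer
  val samplesWithIdNumber = samples.withColumn("movieIdNumber", col("movieId").cast(IntegerType))
  // OneHotEncoderEstimator is the Spark 2.3/2.4 class name; Spark 3.x renamed it to OneHotEncoder
  val oneHotEncoder = new OneHotEncoderEstimator()
    .setInputCols(Array("movieIdNumber"))
    .setOutputCols(Array("movieIdVector"))
    .setDropLast(false) // keep a slot for every category instead of dropping the last one
  val oneHotEncoderSamples = oneHotEncoder.fit(samplesWithIdNumber).transform(samplesWithIdNumber)
  oneHotEncoderSamples.printSchema()
  oneHotEncoderSamples.show(10)
}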
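To make the shape of the encoder output easier to follow, here is a minimal plain-Scala sketch (no Spark required) of what one-hot encoding with setDropLast(false) produces: a vector with one slot per category index from 0 to the largest index, with a single 1.0 at the row's own index. The object OneHotSketch, the case class OneHot, and the function encode are hypothetical names used only for this illustration; they are not part of Spark.

```scala
object OneHotSketch {
  // Sparse one-hot vector: total size plus the single active index (its value is 1.0)
  final case class OneHot(size: Int, index: Int) {
    // Expand to a dense array, mostly useful for inspection
    def toDense: Array[Double] =
      Array.tabulate(size)(i => if (i == index) 1.0 else 0.0)
  }

  // Assumption: ids are non-negative category indexes.
  // Without dropping the last category, the vector size is (largest index + 1).
  def encode(ids: Seq[Int]): Seq[OneHot] = {
    require(ids.nonEmpty && ids.forall(_ >= 0), "ids must be non-empty and non-negative")
    val size = ids.max + 1
    ids.map(id => OneHot(size, id))
  }

  def main(args: Array[String]): Unit =
    // Print in the same (size,[index],[value]) notation that Spark uses for sparse vectors
    encode(Seq(1, 2, 10)).foreach(v => println(s"(${v.size},[${v.index}],[1.0])"))
}
```

Under this reading, a vector size of 1001 in the sample output would be consistent with a largest movieId of 1000 in the data; that is an inference from the output, not something the post states.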
Multi-hot
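import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType

// Turn a list of category indexes into a sparse 0/1 vector of the given length.
// Vectors.sparse requires strictly increasing indices, hence distinct + sorted.
val array2vec: UserDefinedFunction = udf { (a: Seq[Int], length: Int) =>
  val indices = a.distinct.sorted.toArray
  Vectors.sparse(length, indices, Array.fill[Double](indices.length)(1.0))
}

/**
 * Multi-hot encoding example function
 * @param samples movie samples dataframe
 */
def multiHotEncoderExample(samples: DataFrame): Unit = {
  // One row per (movie, genre): split the pipe-separated genres string and explode it
  val samplesWithGenre = samples.select(col("movieId"), col("title"),
    explode(split(col("genres"), "\\|").cast("array<string>")).as("genre"))
  // Map each genre string to a numeric index (by default the most frequent label gets index 0)
  val genreIndexer = new StringIndexer().setInputCol("genre").setOutputCol("genreIndex")
  val stringIndexerModel: StringIndexerModel = genreIndexer.fit(samplesWithGenre)
  val genreIndexSamples = stringIndexerModel.transform(samplesWithGenre)
    .withColumn("genreIndexInt", col("genreIndex").cast(IntegerType))
  /* Debug output:
  println("genreIndexSamples:")
  genreIndexSamples.printSchema()
  genreIndexSamples.show(10, false)
  println("genreIndexSamples.agg:")
  genreIndexSamples.agg(max(col("genreIndexInt"))).show(10, false)
  */
  // Vector length = largest genre index + 1
  val indexSize = genreIndexSamples.agg(max(col("genreIndexInt"))).head().getAs[Int](0) + 1
  // Collect all genre indexes of a movie into one array column
  val processedSamples = genreIndexSamples
    .groupBy(col("movieId")).agg(collect_list("genreIndexInt").as("genreIndexes"))
    .withColumn("indexSize", typedLit(indexSize))
  val finalSample = processedSamples.withColumn("vector", array2vec(col("genreIndexes"), col("indexSize")))
  finalSample.printSchema()
  finalSample.show(10, false)
}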
Notes:
Using StringIndexer
lit and typedLit
collect_list
Using agg
Spark aggregate functions
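The multi-hot steps that these notes touch on (explode, StringIndexer, groupBy with collect_list) can be imitated on ordinary Scala collections, which may help when reasoning about the result without a cluster. The sketch below is not Spark code. MultiHotSketch, indexByFrequency, and multiHot are hypothetical names, and the ordering rule (descending frequency, ties broken alphabetically) is an assumption modelled on StringIndexer's default frequency-based ordering.

```scala
object MultiHotSketch {
  // Assign indexes by descending frequency; ties are broken alphabetically (assumption)
  def indexByFrequency(labels: Seq[String]): Map[String, Int] =
    labels.groupBy(identity).toSeq
      .map { case (label, occurrences) => (label, occurrences.size) }
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  // movies: (movieId, pipe-separated genres)
  // result: (movieId, vector size, sorted active indexes)
  def multiHot(movies: Seq[(String, String)]): Seq[(String, Int, Seq[Int])] = {
    require(movies.nonEmpty, "movies must be non-empty")
    // "explode": one (movieId, genre) pair per genre
    val exploded = movies.flatMap { case (id, genres) =>
      genres.split("\\|").toSeq.map(genre => (id, genre))
    }
    val index = indexByFrequency(exploded.map(_._2))
    // Vector size = largest index + 1
    val size = index.values.max + 1
    // "groupBy + collect_list": gather the indexes that belong to each movie
    movies.map(_._1).distinct.map { id =>
      val indexes = exploded.filter(_._1 == id).map(pair => index(pair._2)).distinct.sorted
      (id, size, indexes)
    }
  }
}
```

For example, with the three movies ("1", "Adventure|Comedy"), ("2", "Comedy"), and ("3", "Comedy|Romance"), Comedy is the most frequent genre and gets index 0, so movie "1" maps to indexes 0 and 1 in a vector of size 3.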
Sample output:
one-hot
Raw Movie Samples:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genres: string (nullable = true)
+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
+-------+--------------------+--------------------+
only showing top 10 rows
OneHotEncoder Example:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genres: string (nullable = true)
|-- movieIdNumber: integer (nullable = true)
|-- movieIdVector: vector (nullable = true)
+-------+--------------------+--------------------+-------------+-----------------+
|movieId|               title|              genres|movieIdNumber|    movieIdVector|
+-------+--------------------+--------------------+-------------+-----------------+
|      1|    Toy Story (1995)|Adventure|Animati...|            1| (1001,[1],[1.0])|
|      2|      Jumanji (1995)|Adventure|Childre...|            2| (1001,[2],[1.0])|
|      3|Grumpier Old Men ...|      Comedy|Romance|            3| (1001,[3],[1.0])|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|            4| (1001,[4],[1.0])|
|      5|Father of the Bri...|              Comedy|            5| (1001,[5],[1.0])|
|      6|         Heat (1995)|Action|Crime|Thri...|            6| (1001,[6],[1.0])|
|      7|      Sabrina (1995)|      Comedy|Romance|            7| (1001,[7],[1.0])|
|      8| Tom and Huck (1995)|  Adventure|Children|            8| (1001,[8],[1.0])|
|      9| Sudden Death (1995)|              Action|            9| (1001,[9],[1.0])|
|     10|    GoldenEye (1995)|Action|Adventure|...|           10|(1001,[10],[1.0])|
+-------+--------------------+--------------------+-------------+-----------------+
multi-hot
MultiHotEncoder Example:
genreIndexSamples:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genre: string (nullable = true)
|-- genreIndex: double (nullable = false)
|-- genreIndexInt: integer (nullable = true)
+-------+-----------------------+---------+----------+-------------+
|movieId|title                  |genre    |genreIndex|genreIndexInt|
+-------+-----------------------+---------+----------+-------------+
|1      |Toy Story (1995)       |Adventure|6.0       |6            |
|1      |Toy Story (1995)       |Animation|15.0      |15           |
|1      |Toy Story (1995)       |Children |7.0       |7            |
|1      |Toy Story (1995)       |Comedy   |1.0       |1            |
|1      |Toy Story (1995)       |Fantasy  |10.0      |10           |
|2      |Jumanji (1995)         |Adventure|6.0       |6            |
|2      |Jumanji (1995)         |Children |7.0       |7            |
|2      |Jumanji (1995)         |Fantasy  |10.0      |10           |
|3      |Grumpier Old Men (1995)|Comedy   |1.0       |1            |
|3      |Grumpier Old Men (1995)|Romance  |2.0       |2            |
+-------+-----------------------+---------+----------+-------------+
genreIndexSamples.agg:
+------------------+
|max(genreIndexInt)|
+------------------+
|18                |
+------------------+
finalSample:
root
|-- movieId: string (nullable = true)
|-- genreIndexes: array (nullable = true)
| |-- element: integer (containsNull = true)
|-- indexSize: integer (nullable = false)
|-- vector: vector (nullable = true)
+-------+------------+---------+--------------------------------+
|movieId|genreIndexes|indexSize|vector                          |
+-------+------------+---------+--------------------------------+
|296    |[1, 5, 0, 3]|19       |(19,[0,1,3,5],[1.0,1.0,1.0,1.0])|
|467    |[1]         |19       |(19,[1],[1.0])                  |
|675    |[4, 0, 3]   |19       |(19,[0,3,4],[1.0,1.0,1.0])      |
|691    |[1, 2]      |19       |(19,[1,2],[1.0,1.0])            |
|829    |[1, 10, 14] |19       |(19,[1,10,14],[1.0,1.0,1.0])    |
|125    |[1]         |19       |(19,[1],[1.0])                  |
|451    |[0, 8, 2]   |19       |(19,[0,2,8],[1.0,1.0,1.0])      |
|800    |[0, 8, 16]  |19       |(19,[0,8,16],[1.0,1.0,1.0])     |
|853    |[0]         |19       |(19,[0],[1.0])                  |
|944    |[0]         |19       |(19,[0],[1.0])                  |
+-------+------------+---------+--------------------------------+
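The vector column uses the sparse notation (size,[indices],[values]). As a reading aid, the following plain-Scala sketch expands such a triple into a dense sequence. SparseVectorSketch and toDense are hypothetical helper names for this illustration only; in Spark itself the vector type offers its own conversion.

```scala
object SparseVectorSketch {
  // Expand a (size, indices, values) triple into a dense sequence of doubles
  def toDense(size: Int, indices: Seq[Int], values: Seq[Double]): Seq[Double] = {
    require(indices.length == values.length, "indices and values must have the same length")
    require(indices.forall(i => i >= 0 && i < size), "every index must lie in [0, size)")
    val dense = Array.fill(size)(0.0)
    // Write each stored value at its index; all other slots stay 0.0
    indices.zip(values).foreach { case (i, v) => dense(i) = v }
    dense.toSeq
  }

  def main(args: Array[String]): Unit =
    // Row for movieId 296 in the table above
    println(toDense(19, Seq(0, 1, 3, 5), Seq(1.0, 1.0, 1.0, 1.0)).mkString("[", ", ", "]"))
}
```

For movieId 296 this gives a sequence of 19 values with 1.0 at positions 0, 1, 3, and 5 and 0.0 everywhere else, which is the multi-hot encoding of its four genres.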
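Another multi-hot method (suitable when the number of labels is small)
The core idea is to build a word-to-index map with getWordsIndexMap and then map each keyword set through it. Because the whole vocabulary is collected to the driver and broadcast, this approach only fits a small label set.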
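import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

// Build a word -> index map over all keyword sets (words in sorted order) and broadcast it
def getWordsIndexMap(rdd: RDD[Set[String]], ss: SparkSession): Broadcast[Map[String, Int]] = {
  val allWords = rdd.flatMap(words => words).distinct().collect().sorted
  ss.sparkContext.broadcast(allWords.zipWithIndex.toMap)
}

// Returns the (id, index) lookup table and an RDD of (index, multi-hot vector, from)
def transformVec(rdd: RDD[(String, Set[String], String)], ss: SparkSession,
                 mp: Broadcast[Map[String, Int]]): (DataFrame, RDD[(Long, Vector, String)]) = {
  import ss.implicits._
  // Give every distinct id a unique numeric index
  val indexDF = rdd.map(_._1).distinct().zipWithUniqueId().toDF("id", "index")
  val outRDD = rdd.toDF("id", "keywords", "from")
    .join(indexDF, "id")
    .select("index", "keywords", "from")
    .rdd
    .map {
      // The keywords column comes back as a Seq, whatever its concrete class
      case Row(index: Long, keywords: Seq[_], from: String) =>
        val len = mp.value.size
        // The vocabulary is sorted, so sorted keywords map to increasing indices
        val indices = keywords.map(_.toString).toArray.sorted.map(word => mp.value(word))
        val values = Array.fill(indices.length)(1.0)
        (index, Vectors.sparse(len, indices, values), from)
    }
  (indexDF, outRDD)
}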
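Stripped of the Spark plumbing, the vocabulary-map approach reduces to two steps: take the union of all keyword sets, sort it, and number the words; then replace each keyword set by the sorted indexes of its words. The sketch below shows this on plain Scala collections. VocabularyMultiHotSketch, wordsIndexMap, and encode are hypothetical names, and skipping unknown words is a choice made for this sketch (a plain map lookup, as used in the code above, would throw on an unknown word instead).

```scala
object VocabularyMultiHotSketch {
  // Union all keyword sets and give each word its position in sorted order
  def wordsIndexMap(keywordSets: Seq[Set[String]]): Map[String, Int] =
    keywordSets.foldLeft(Set.empty[String])(_ ++ _).toSeq.sorted.zipWithIndex.toMap

  // Returns (vector size, sorted active indexes) for one keyword set.
  // Assumption for this sketch: words missing from the vocabulary are skipped.
  def encode(keywords: Set[String], vocab: Map[String, Int]): (Int, Seq[Int]) =
    (vocab.size, keywords.toSeq.flatMap(word => vocab.get(word)).sorted)

  def main(args: Array[String]): Unit = {
    val vocab = wordsIndexMap(Seq(Set("b", "a"), Set("c", "a")))
    // Print the vector size and the active indexes for one keyword set
    println(encode(Set("c", "a"), vocab))
  }
}
```

The pair returned by encode corresponds to the first two arguments of a sparse vector; the values array is simply 1.0 repeated once per active index.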