
PySpark @udf: a UDF called in a loop always uses the first iteration's variable

2022-08-03 07:41:00 WGS。

Problem description

I add two columns with a UDF defined via @udf. The UDF applied to each column should read a different broadcast variable. However, over several loop iterations, the global variable the UDF receives is always the one from the first iteration. The code is as follows:

from pyspark.sql import SparkSession
from pyspark.sql import functions as fn
from pyspark.sql.functions import udf

ss = SparkSession.builder.getOrCreate()  # the SparkSession
sc = ss.sparkContext                     # its SparkContext

tlowdata = ss.createDataFrame([
    {'suuid': 'DDD1', 'oaid': '00-01', 'y': 1},
    {'suuid': 'DOOD', 'oaid': '00-02', 'y': 0},
    {'suuid': '009-1234', 'oaid': 'default1', 'y': 0},
    {'suuid': 'DDD1', 'oaid': 'ttt', 'y': 0},
    {'suuid': 'www', 'oaid': 'fwao', 'y': 0},
    {'suuid': 'www', 'oaid': 'fff1', 'y': 0},
])
tlowdata.show()

@udf  # default return type is StringType
def tmp_udf(uid):
    # Reads the module-level broadcast variable tmplst
    return str(tmplst.value)

cols = ['suuid', 'oaid']
lst = [[1, 2, 3], [4, 5, 6]]
for i in range(len(lst)):
    # Broadcast a different list for each column
    tmplst = sc.broadcast(lst[i])
    print(tmplst.value)

    tlowdata = tlowdata.withColumn(cols[i] + '_flag', tmp_udf(fn.col(cols[i])))

    tmplst.unpersist()

tlowdata.show()

Output:
+--------+--------+---+
|    oaid|   suuid|  y|
+--------+--------+---+
|   00-01|    DDD1|  1|
|   00-02|    DOOD|  0|
|default1|009-1234|  0|
|     ttt|    DDD1|  0|
|    fwao|     www|  0|
|    fff1|     www|  0|
+--------+--------+---+

[1, 2, 3]
[4, 5, 6]
+--------+--------+---+----------+---------+
|    oaid|   suuid|  y|suuid_flag|oaid_flag|
+--------+--------+---+----------+---------+
|   00-01|    DDD1|  1| [1, 2, 3]|[1, 2, 3]|
|   00-02|    DOOD|  0| [1, 2, 3]|[1, 2, 3]|
|default1|009-1234|  0| [1, 2, 3]|[1, 2, 3]|
|     ttt|    DDD1|  0| [1, 2, 3]|[1, 2, 3]|
|    fwao|     www|  0| [1, 2, 3]|[1, 2, 3]|
|    fff1|     www|  0| [1, 2, 3]|[1, 2, 3]|
+--------+--------+---+----------+---------+

In theory, the oaid_flag column should contain [4, 5, 6]. But no matter how many iterations the loop runs, the UDF always receives the broadcast variable [1, 2, 3]. I began to suspect a bug in Spark 2.4…

Solution

Drop @udf. Inside the loop, create the UDF either from a lambda (lambda udf) or with UserDefinedFunction:

  • lambda udf
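# lambda udf: replaces the withColumn line inside the for loop,
# so a new UDF object is created on every iteration
tmp_udf = fn.udf(lambda x: str(tmplst.value))
tlowdata = tlowdata.withColumn(cols[i] + '_flag', tmp_udf(fn.col(cols[i])))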
  • UserDefinedFunction
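# UserDefinedFunction
from pyspark.sql.udf import UserDefinedFunction

# plain Python function, without the @udf decorator
def tmp_udf(uid):
    return str(tmplst.value)

# inside the for loop: wrap it in a new UserDefinedFunction on every iteration
tlowdata = tlowdata.withColumn(cols[i] + '_flag', UserDefinedFunction(lambda x: tmp_udf(x))(fn.col(cols[i])))

Output: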
+--------+--------+---+
|    oaid|   suuid|  y|
+--------+--------+---+
|   00-01|    DDD1|  1|
|   00-02|    DOOD|  0|
|default1|009-1234|  0|
|     ttt|    DDD1|  0|
|    fwao|     www|  0|
|    fff1|     www|  0|
+--------+--------+---+

[1, 2, 3]
[4, 5, 6]
+--------+--------+---+----------+---------+
|    oaid|   suuid|  y|suuid_flag|oaid_flag|
+--------+--------+---+----------+---------+
|   00-01|    DDD1|  1| [1, 2, 3]|[4, 5, 6]|
|   00-02|    DOOD|  0| [1, 2, 3]|[4, 5, 6]|
|default1|009-1234|  0| [1, 2, 3]|[4, 5, 6]|
|     ttt|    DDD1|  0| [1, 2, 3]|[4, 5, 6]|
|    fwao|     www|  0| [1, 2, 3]|[4, 5, 6]|
|    fff1|     www|  0| [1, 2, 3]|[4, 5, 6]|
+--------+--------+---+----------+---------+

I haven't figured out the cause yet. For now, the methods above work around the problem, and I'll update this post once I understand it.
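A likely explanation follows. It is an assumption based on the PySpark 2.x source (pyspark/sql/udf.py), not a confirmed diagnosis.

The object that @udf creates serializes the Python function with cloudpickle the first time the UDF is applied to a column. It then caches that serialized form on the object. cloudpickle ships the globals read by a function defined in __main__ by value. So the cached copy holds whatever tmplst was bound to at that moment, which is the first broadcast.

Every later tmp_udf(...) call in the loop reuses that cache. The lambda udf and UserDefinedFunction variants, by contrast, build a new UDF object on each iteration, which gets serialized afresh. If this reading is right, the behavior is expected rather than a Spark 2.4 bug.

The sketch below reproduces only that caching behavior in plain Python. ToyUdf, tmp_udf_body and the "row" argument are hypothetical stand-ins, not PySpark APIs.

```python
import types


class ToyUdf:
    """Hypothetical toy model, NOT pyspark.sql.udf.UserDefinedFunction.

    It mimics one assumed behavior of PySpark UDF objects: the wrapped
    function is "serialized" the first time the UDF is applied, and that
    copy is cached and reused on every later application. Serializing is
    modeled as snapshotting the current global name bindings, roughly what
    cloudpickle does for functions defined in __main__.
    """

    def __init__(self, func):
        self.func = func
        self._frozen = None  # stands in for the cached serialized function

    def _freeze(self):
        # Shallow copy of the globals dict: later reassignment of a global
        # name (e.g. tmplst = ...) does not affect this snapshot.
        snapshot = dict(self.func.__globals__)
        return types.FunctionType(
            self.func.__code__, snapshot, self.func.__name__,
            self.func.__defaults__, self.func.__closure__,
        )

    def __call__(self, value):
        if self._frozen is None:  # only the FIRST application freezes
            self._frozen = self._freeze()
        return self._frozen(value)


def tmp_udf_body(uid):
    # Reads a module-level global, like tmp_udf reads tmplst in the post
    return str(tmplst)


lst = [[1, 2, 3], [4, 5, 6]]

# Like @udf: ONE UDF object reused in every iteration.
reused = ToyUdf(tmp_udf_body)
reused_results = []
for values in lst:
    tmplst = values
    reused_results.append(reused("row"))

# Like the lambda udf / UserDefinedFunction workarounds:
# a NEW UDF object is created in every iteration.
fresh_results = []
for values in lst:
    tmplst = values
    fresh_results.append(ToyUdf(tmp_udf_body)("row"))

print(reused_results)  # ['[1, 2, 3]', '[1, 2, 3]']
print(fresh_results)   # ['[1, 2, 3]', '[4, 5, 6]']
```

Under the same assumption, a sturdier fix in real PySpark is to stop reading a reassigned global at all. Instead, pass the broadcast as an argument to a small helper that builds the UDF inside the loop, for example one that returns fn.udf(lambda x: str(bc.value)) for a given broadcast bc.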


Copyright notice
This article was written by [WGS。]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/215/202208030527206104.html