2016-04-25 15 views
7

Tất cả các data types in pyspark.sql.types are:Làm thế nào để trả về một "loại Tuple" trong một UDF trong PySpark?

__all__ = [ 
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", 
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", 
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] 

tôi phải viết một UDF (trong pyspark) trả về một mảng của các bộ. Tôi đưa ra đối số thứ hai cho nó là kiểu trả về của phương pháp udf là gì? Nó sẽ là một cái gì đó trên các dòng của ArrayType(TupleType()) ...

+0

Câu hỏi tiêu đề của bạn dường như không khớp với nội dung. Tài liệu không cho bạn biết cách đặt giá trị trả về là * "loại vùng chứa loại khác" *? – jonrsharpe

+0

@jonrsharpe Tôi đã thay đổi tiêu đề. Hy vọng rằng đó là đại diện của cơ thể bây giờ. – kamalbanga

Trả lời

11

Không có điều nào như là một TupleType trong Spark. Các loại sản phẩm được thể hiện là structs với các trường thuộc loại cụ thể. Ví dụ, nếu bạn muốn trả về một mảng các cặp (integer, string) bạn có thể sử dụng sơ đồ như thế này: sử dụng

from pyspark.sql.types import * 

schema = ArrayType(StructType([ 
    StructField("char", StringType(), False), 
    StructField("count", IntegerType(), False) 
])) 

Ví dụ:

from pyspark.sql.functions import udf 
from collections import Counter 

char_count_udf = udf(
    lambda s: Counter(s).most_common(), 
    schema 
) 

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"]) 

df.select("*", char_count_udf(df["value"])).show(2, False) 

## +---+-----+-------------------------+ 
## |id |value|PythonUDF#<lambda>(value)| 
## +---+-----+-------------------------+ 
## |1 |foo |[[o,2], [f,1]]   | 
## |2 |bar |[[r,1], [a,1], [b,1]] | 
## +---+-----+-------------------------+ 
+0

Câu trả lời của bạn đang hoạt động, nhưng trường hợp của tôi hơi phức tạp. Dữ liệu trả về của tôi là kiểu '[('a1', [('b1', 1), ('b2', 2)]), ('a2', [('b1', 1), ('b2 ', 2)])] 'và vì vậy tôi tạo một kiểu như' ArrayType (StructType ([StructField ("date", StringType(), False), ArrayType (StructType ([StructField ("hashId", StringType(), False), StructField ("TimeSpent-Front", FloatType(), False), StructField ("TimeSpent-Back", FloatType(), False)]))))) 'cung cấp đối tượng ** 'ArrayType' không có thuộc tính ' name '** ... – kamalbanga

+1

'StructType' yêu cầu một chuỗi' StructFields' do đó bạn không thể sử dụng 'ArrayTypes' một mình. Bạn cần 'StructField' lưu trữ' ArrayType'. Ngoài ra lời khuyên - nếu bạn thấy mình tạo ra các cấu trúc như thế này bạn có lẽ nên suy nghĩ lại mô hình dữ liệu. Các cấu trúc lồng nhau rất khó xử lý mà không có UDF và các UDF của Python không hiệu quả. – zero323

+0

Làm cách nào tôi có thể chỉ định lược đồ trong udf để trả về một danh sách. F.udf (lambda start_date, end_date: [0,1] nếu start_date pseudocode

4

Stackoverflow giữ đạo tôi cho câu hỏi này, vì vậy tôi đoán tôi sẽ thêm một số thông tin ở đây.

Trở loại đơn giản từ UDF:

from pyspark.sql.types import * 
from pyspark.sql import functions as F 

def get_df(): 
    d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)] 
    df = sqlContext.createDataFrame(d, ['x', 'y']) 
    return df 

df = get_df() 
df.show() 

# +---+---+ 
# | x| y| 
# +---+---+ 
# |0.0|0.0| 
# |0.0|3.0| 
# |1.0|6.0| 
# |1.0|9.0| 
# +---+---+ 

func = udf(lambda x: str(x), StringType()) 
df = df.withColumn('y_str', func('y')) 

func = udf(lambda x: int(x), IntegerType()) 
df = df.withColumn('y_int', func('y')) 

df.show() 

# +---+---+-----+-----+ 
# | x| y|y_str|y_int| 
# +---+---+-----+-----+ 
# |0.0|0.0| 0.0| 0| 
# |0.0|3.0| 3.0| 3| 
# |1.0|6.0| 6.0| 6| 
# |1.0|9.0| 9.0| 9| 
# +---+---+-----+-----+ 

df.printSchema() 

# root 
# |-- x: double (nullable = true) 
# |-- y: double (nullable = true) 
# |-- y_str: string (nullable = true) 
# |-- y_int: integer (nullable = true) 

Khi số nguyên không đủ:

df = get_df() 

func = udf(lambda x: [0]*int(x), ArrayType(IntegerType())) 
df = df.withColumn('list', func('y')) 

func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
      MapType(FloatType(), StringType())) 
df = df.withColumn('map', func('y')) 

df.show() 
# +---+---+--------------------+--------------------+ 
# | x| y|    list|     map| 
# +---+---+--------------------+--------------------+ 
# |0.0|0.0|     []|    Map()| 
# |0.0|3.0|   [0, 0, 0]|Map(2.0 -> 2, 0.0...| 
# |1.0|6.0| [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...| 
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...| 
# +---+---+--------------------+--------------------+ 

df.printSchema() 
# root 
# |-- x: double (nullable = true) 
# |-- y: double (nullable = true) 
# |-- list: array (nullable = true) 
# | |-- element: integer (containsNull = true) 
# |-- map: map (nullable = true) 
# | |-- key: float 
# | |-- value: string (valueContainsNull = true) 

Trở kiểu dữ liệu phức tạp từ UDF:

df = get_df() 
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]')) 
df.show() 

# +---+----------+ 
# | x|  y[]| 
# +---+----------+ 
# |0.0|[0.0, 3.0]| 
# |1.0|[9.0, 6.0]| 
# +---+----------+ 

schema = StructType([ 
    StructField("min", FloatType(), True), 
    StructField("size", IntegerType(), True), 
    StructField("edges", ArrayType(FloatType()), True), 
    StructField("val_to_index", MapType(FloatType(), IntegerType()), True) 
    # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)])) 

]) 

def func(values): 
    mn = min(values) 
    size = len(values) 
    lst = sorted(values)[::-1] 
    val_to_index = {x: i for i, x in enumerate(values)} 
    return (mn, size, lst, val_to_index) 

func = udf(func, schema) 
dff = df.select('*', func('y[]').alias('complex_type')) 
dff.show(10, False) 

# +---+----------+------------------------------------------------------+ 
# |x |y[]  |complex_type           | 
# +---+----------+------------------------------------------------------+ 
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]| 
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]| 
# +---+----------+------------------------------------------------------+ 

dff.printSchema() 

# +---+----------+------------------------------------------------------+ 
# |x |y[]  |complex_type           | 
# +---+----------+------------------------------------------------------+ 
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]| 
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]| 
# +---+----------+------------------------------------------------------+ 

Đi qua nhiều đối số cho một UDF:

df = get_df() 
func = udf(lambda arr: arr[0]*arr[1],FloatType()) 
df = df.withColumn('x*y', func(F.array('x', 'y'))) 

    # +---+---+---+ 
    # | x| y|x*y| 
    # +---+---+---+ 
    # |0.0|0.0|0.0| 
    # |0.0|3.0|0.0| 
    # |1.0|6.0|6.0| 
    # |1.0|9.0|9.0| 
    # +---+---+---+ 

Mã này hoàn toàn dành cho mục đích demo, tất cả các phép biến đổi trên đều có sẵn trong mã Spark và sẽ mang lại hiệu suất tốt hơn nhiều. Như @ zero323 trong nhận xét ở trên, các UDF nói chung nên tránh trong pyspark; các kiểu phức tạp trở về sẽ khiến bạn suy nghĩ về việc đơn giản hóa logic của mình.

Các vấn đề liên quan