2015-12-10 17 views
15

Tôi có một dataframe với schema như vậy:Làm cách nào để tổng hợp các giá trị vào bộ sưu tập sau groupBy?

[visitorId: string, trackingIds: array<string>, emailIds: array<string>] 

Tìm kiếm một cách để nhóm (hoặc có thể Rollup?) Dataframe này bằng cách visitorId nơi trackingIds và cột emailIds sẽ gắn với nhau. Vì vậy, ví dụ nếu df ban đầu của tôi trông giống như:

visitorId |trackingIds|emailIds 
+-----------+------------+-------- 
|a158|  [666b]  | [12] 
|7g21|  [c0b5]  | [45] 
|7g21|  [c0b4]  | [87] 
|a158|  [666b, 777c]| [] 

Tôi muốn df đầu ra của tôi trông như thế này

visitorId |trackingIds|emailIds 
+-----------+------------+-------- 
|a158|  [666b,666b,777c]|  [12,''] 
|7g21|  [c0b5,c0b4]  |  [45, 87] 

Cố gắng sử dụng groupByagg nhà khai thác nhưng không có nhiều may mắn.

Trả lời

17

Spark 2.x

Có thể nhưng khá tốn kém. Sử dụng dữ liệu mà bạn đã cung cấp:

case class Record(
    visitorId: String, trackingIds: Array[String], emailIds: Array[String]) 

val df = Seq(
    Record("a158", Array("666b"), Array("12")), 
    Record("7g21", Array("c0b5"), Array("45")), 
    Record("7g21", Array("c0b4"), Array("87")), 
    Record("a158", Array("666b", "777c"), Array.empty[String])).toDF 

và một hàm helper:

import org.apache.spark.sql.functions.udf 

val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten) 

chúng ta có thể điền vào chỗ trống với placeholders:

import org.apache.spark.sql.functions.{array, lit, when} 

val dfWithPlaceholders = df.withColumn(
    "emailIds", 
    when(size($"emailIds") === 0, array(lit(""))).otherwise($"emailIds")) 

collect_listsflatten:

import org.apache.spark.sql.functions.{array, collect_listn} 

val emailIds = flatten(collect_list($"emailIds")).alias("emailIds") 
val trackingIds = flatten(collect_list($"trackingIds")).alias("trackingIds") 

df 
    .groupBy($"visitorId") 
    .agg(trackingIds, emailIds) 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// +---------+------------------+--------+ 

Với tĩnh gõ Dataset:

df.as[Record] 
    .groupByKey(_.visitorId) 
    .mapGroups { case (key, vs) => 
    vs.map(v => (v.trackingIds, v.emailIds)).toArray.unzip match { 
     case (trackingIds, emailIds) => 
     Record(key, trackingIds.flatten, emailIds.flatten) 
    }} 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// +---------+------------------+--------+ 

Spark 1.x

Bạn thể chuyển đổi sang RDD nhóm

import org.apache.spark.sql.Row 

dfWithPlaceholders.rdd 
    .map { 
    case Row(id: String, 
     trcks: Seq[String @ unchecked], 
     emails: Seq[String @ unchecked]) => (id, (trcks, emails)) 
    } 
    .groupByKey 
    .map {case (key, vs) => vs.toArray.unzip match { 
    case (trackingIds, emailIds) => 
     Record(key, trackingIds.flatten, emailIds.flatten) 
    }} 
    .toDF 

// +---------+------------------+--------+ 
// |visitorId|  trackingIds|emailIds| 
// +---------+------------------+--------+ 
// |  7g21|  [c0b5, c0b4]|[45, 87]| 
// |  a158|[666b, 666b, 777c]| [12, ]| 
// +---------+------------------+--------+ 
+0

gì flatten phương pháp nào chính xác? – xXxpRoGrAmmErxXx

+0

Điều gì sẽ xảy ra nếu chúng tôi phải xóa các bản sao trong 'trackingIds'? – puru

6

@ câu trả lời zero323 là khá nhiều hoàn tất, nhưng Spark cho chúng tôi linh hoạt hơn. Làm thế nào về các giải pháp sau đây?

import org.apache.spark.sql.functions._ 
inventory 
    .select($"*", explode($"trackingIds") as "tracking_id") 
    .select($"*", explode($"emailIds") as "email_id") 
    .groupBy("visitorId") 
    .agg(
    collect_list("tracking_id") as "trackingIds", 
    collect_list("email_id") as "emailIds") 

Đó tuy nhiên lá ra tất cả các bộ sưu tập sản phẩm nào (vì vậy có một số phòng để cải thiện :))

+1

Trong giải pháp này, có thể áp dụng một orderBy() sau nhóm group và trước agg() không? Hoặc trong trường hợp này, trật tự sẽ không xác định? –

+0

Theo ý kiến ​​của tôi, bạn trả lời nó không phải là cuốn tiểu thuyết vì những lý do sau a) phát hiện không được dùng nữa trong spark.2.2.b) collect_list trên một tập dữ liệu rất lớn có thể làm hỏng quá trình điều khiển với OutOfMemoryError – xXxpRoGrAmmErxXx

+0

@xXxpRoGrAmmErxXx Xin đừng nhầm lẫn với toán tử 'explode' và hàm' explode'. Đối với b) có thể. –

0

Bạn có thể sử dụng Người dùng xác định chức năng được tổng hợp.

1) tạo UDAF tùy chỉnh bằng cách sử dụng lớp scala có tên là customAggregation.

package com.package.name 

import org.apache.spark.sql.Row 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 
import org.apache.spark.sql.types._ 
import scala.collection.JavaConverters._ 

class CustomAggregation() extends UserDefinedAggregateFunction { 

// Input Data Type Schema 
def inputSchema: StructType = StructType(Array(StructField("col5", ArrayType(StringType)))) 

// Intermediate Schema 
def bufferSchema = StructType(Array(
StructField("col5_collapsed", ArrayType(StringType)))) 

// Returned Data Type . 
def dataType: DataType = ArrayType(StringType) 

// Self-explaining 
def deterministic = true 

// This function is called whenever key changes 
def initialize(buffer: MutableAggregationBuffer) = { 
buffer(0) = Array.empty[String] // initialize array 
} 

// Iterate over each entry of a group 
def update(buffer: MutableAggregationBuffer, input: Row) = { 
buffer(0) = 
    if(!input.isNullAt(0)) 
    buffer.getList[String](0).toArray ++ input.getList[String](0).toArray 
    else 
    buffer.getList[String](0).toArray 
} 

    // Merge two partial aggregates 
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = { 
buffer1(0) = buffer1.getList[String](0).toArray ++ buffer2.getList[String](0).toArray 
} 

// Called after all the entries are exhausted. 
def evaluate(buffer: Row) = { 
    buffer.getList[String](0).asScala.toList.distinct 
} 
} 

2) Sau đó sử dụng UDAF trong mã của bạn như

//define UDAF 
val CustomAggregation = new CustomAggregation() 
DataFrame 
    .groupBy(col1,col2,col3) 
    .agg(CustomAggregation(DataFrame(col5))).show() 
Các vấn đề liên quan