Bạn có thể thử với một chút NumPy và RDDs. Lần đầu tiên một loạt các hàng nhập khẩu:
from operator import itemgetter
import numpy as np
from pyspark.statcounter import StatCounter
Hãy xác định một vài biến:
keys = ["key1", "key2", "key3"] # list of key column names
xs = ["x1", "x2", "x3"] # list of column names to compare
y = "y" # name of the reference column
Và một số người giúp đỡ:
def as_pair(keys, y, xs):
""" Given key names, y name, and xs names
return a tuple of key, array-of-values"""
key = itemgetter(*keys)
value = itemgetter(y, * xs) # Python 3 syntax
def as_pair_(row):
return key(row), np.array(value(row))
return as_pair_
def init(x):
""" Init function for combineByKey
Initialize new StatCounter and merge first value"""
return StatCounter().merge(x)
def center(means):
"""Center a row value given a
dictionary of mean arrays
"""
def center_(row):
key, value = row
return key, value - means[key]
return center_
def prod(arr):
return arr[0] * arr[1:]
def corr(stddev_prods):
"""Scale the row to get 1 stddev
given a dictionary of stddevs
"""
def corr_(row):
key, value = row
return key, value/stddev_prods[key]
return corr_
và chuyển đổi DataFrame
để RDD
các cặp:
pairs = df.rdd.map(as_pair(keys, y, xs))
Tiếp theo chúng ta hãy tính toán thống kê cho mỗi nhóm:
stats = (pairs
.combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
.collectAsMap())
means = {k: v.mean() for k, v in stats.items()}
Note: Với 5000 tính năng và 7000 nhóm có nên không có vấn đề với việc giữ cấu trúc này trong bộ nhớ. Với bộ dữ liệu lớn hơn, bạn có thể phải sử dụng RDD và join
nhưng điều này sẽ chậm hơn.
Trung tâm dữ liệu:
centered = pairs.map(center(means))
Tính hiệp phương sai:
covariance = (centered
.mapValues(prod)
.combineByKey(init, StatCounter.merge, StatCounter.mergeStats)
.mapValues(StatCounter.mean))
Và cuối cùng tương quan:
stddev_prods = {k: prod(v.stdev()) for k, v in stats.items()}
correlations = covariance.map(corr(stddev_prods))
Ví dụ dữ liệu:
df = sc.parallelize([
("a", "b", "c", 0.5, 0.5, 0.3, 1.0),
("a", "b", "c", 0.8, 0.8, 0.9, -2.0),
("a", "b", "c", 1.5, 1.5, 2.9, 3.6),
("d", "e", "f", -3.0, 4.0, 5.0, -10.0),
("d", "e", "f", 15.0, -1.0, -5.0, 10.0),
]).toDF(["key1", "key2", "key3", "y", "x1", "x2", "x3"])
Kết quả với DataFrame
:
df.groupBy(*keys).agg(*[corr(y, x) for x in xs]).show()
+----+----+----+-----------+------------------+------------------+
|key1|key2|key3|corr(y, x1)| corr(y, x2)| corr(y, x3)|
+----+----+----+-----------+------------------+------------------+
| d| e| f| -1.0| -1.0| 1.0|
| a| b| c| 1.0|0.9972300220940342|0.6513360726920862|
+----+----+----+-----------+------------------+------------------+
và phương pháp cung cấp ở trên:
correlations.collect()
[(('a', 'b', 'c'), array([ 1. , 0.99723002, 0.65133607])),
(('d', 'e', 'f'), array([-1., -1., 1.]))]
Giải pháp này, trong khi một chút liên quan, là khá đàn hồi và có thể dễ dàng điều chỉnh để xử lý phân phối dữ liệu khác nhau. Nó cũng có thể được tăng thêm với JIT.
Nhóm có thể thay đổi từ 100 hàng đến 5-10 triệu, số lượng nhóm sẽ là 7000. – Harish