2016-03-23 17 views
5

Tôi cần phải tạo một UDF được sử dụng trong python pyspark trong đó sử dụng một đối tượng java cho các phép tính bên trong của nó.Thực hiện một UDF java và gọi nó từ pyspark

Nếu nó là một con trăn đơn giản, tôi sẽ làm một cái gì đó như:

def f(x): 
    return 7 
fudf = pyspark.sql.functions.udf(f,pyspark.sql.types.IntegerType()) 

và gọi nó là sử dụng:

df = sqlContext.range(0,5) 
df2 = df.withColumn("a",fudf(df.id)).show() 

Tuy nhiên, việc thực hiện các chức năng tôi cần là trong java và không có trong python. Tôi cần bọc nó bằng cách nào đó để tôi có thể gọi nó theo cách tương tự từ python.

Lần thử đầu tiên của tôi là triển khai đối tượng java, sau đó bọc nó trong python trong pyspark và chuyển đổi thành UDF. Điều đó không thành công với lỗi tuần tự hóa. đang

Java: Mã

package com.test1.test2; 

public class TestClass1 { 
    Integer internalVal; 
    public TestClass1(Integer val1) { 
     internalVal = val1; 
    } 
    public Integer do_something(Integer val) { 
     return internalVal; 
    }  
} 

pyspark:

from py4j.java_gateway import java_import 
from pyspark.sql.functions import udf 
from pyspark.sql.types import IntegerType 
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
audf = udf(a,IntegerType()) 

lỗi:

--------------------------------------------------------------------------- 
Py4JError         Traceback (most recent call last) 
<ipython-input-2-9756772ab14f> in <module>() 
     4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
     5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
----> 6 audf = udf(a,IntegerType()) 

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType) 
    1595  [Row(slen=5), Row(slen=3)] 
    1596  """ 
-> 1597  return UserDefinedFunction(f, returnType) 
    1598 
    1599 blacklist = ['map', 'since', 'ignore_unicode_prefix'] 

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name) 
    1556   self.returnType = returnType 
    1557   self._broadcast = None 
-> 1558   self._judf = self._create_judf(name) 
    1559 
    1560  def _create_judf(self, name): 

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name) 
    1565   command = (func, None, ser, ser) 
    1566   sc = SparkContext.getOrCreate() 
-> 1567   pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) 
    1568   ctx = SQLContext.getOrCreate(sc) 
    1569   jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) 

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj) 
    2297  # the serialized command will be compressed by broadcast 
    2298  ser = CloudPickleSerializer() 
-> 2299  pickled_command = ser.dumps(command) 
    2300  if len(pickled_command) > (1 << 20): # 1M 
    2301   # The broadcast will have same life cycle as created PythonRDD 

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj) 
    426 
    427  def dumps(self, obj): 
--> 428   return cloudpickle.dumps(obj, 2) 
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol) 
    644 
    645  cp = CloudPickler(file,protocol) 
--> 646  cp.dump(obj) 
    647 
    648  return file.getvalue() 

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj) 
    105   self.inject_addons() 
    106   try: 
--> 107    return Pickler.dump(self, obj) 
    108   except RuntimeError as e: 
    109    if 'recursion' in e.args[0]: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj) 
    222   if self.proto >= 2: 
    223    self.write(PROTO + chr(self.proto)) 
--> 224   self.save(obj) 
    225   self.write(STOP) 
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    566   write(MARK) 
    567   for element in obj: 
--> 568    save(element) 
    569 
    570   if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name) 
    191   if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None: 
    192    #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) 
--> 193    self.save_function_tuple(obj) 
    194    return 
    195   else: 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func) 
    234   # create a skeleton function object and memoize it 
    235   save(_make_skel_func) 
--> 236   save((code, closure, base_globals)) 
    237   write(pickle.REDUCE) 
    238   self.memoize(func) 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    552   if n <= 3 and proto >= 2: 
    553    for element in obj: 
--> 554     save(element) 
    555    # Subtle. Same as in the big comment below. 
    556    if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj) 
    604 
    605   self.memoize(obj) 
--> 606   self._batch_appends(iter(obj)) 
    607 
    608  dispatch[ListType] = save_list 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items) 
    637     write(MARK) 
    638     for x in tmp: 
--> 639      save(x) 
    640     write(APPENDS) 
    641    elif n: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    304    reduce = getattr(obj, "__reduce_ex__", None) 
    305    if reduce: 
--> 306     rv = reduce(self.proto) 
    307    else: 
    308     reduce = getattr(obj, "__reduce__", None) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    811   answer = self.gateway_client.send_command(command) 
    812   return_value = get_return_value(
--> 813    answer, self.gateway_client, self.target_id, self.name) 
    814 
    815   for temp_arg in temp_args: 

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw) 
    43  def deco(*a, **kw): 
    44   try: 
---> 45    return f(*a, **kw) 
    46   except py4j.protocol.Py4JJavaError as e: 
    47    s = e.java_exception.toString() 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 
    310     raise Py4JError(
    311      "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n". 
--> 312      format(target_id, ".", name, value)) 
    313   else: 
    314    raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace: 
py4j.Py4JException: Method __getnewargs__([]) does not exist 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335) 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344) 
    at py4j.Gateway.invoke(Gateway.java:252) 
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) 
    at py4j.commands.CallCommand.execute(CallCommand.java:79) 
    at py4j.GatewayConnection.run(GatewayConnection.java:209) 
    at java.lang.Thread.run(Thread.java:745) 

EDIT: Tôi cũng đã cố gắng để làm cho lớp java serializable nhưng không có kết quả .

nỗ lực thứ hai của tôi là để xác định UDF trong java để bắt đầu với nhưng điều đó không thành công như tôi không chắc chắn làm thế nào để quấn nó một cách chính xác:

đang

java: gói com.test1.test2;

import org.apache.spark.sql.api.java.UDF1; 

public class TestClassUdf implements UDF1<Integer, Integer> { 

    Integer retval; 

    public TestClassUdf(Integer val) { 
     retval = val; 
    } 

    @Override 
    public Integer call(Integer arg0) throws Exception { 
     return retval; 
    } 
} 

nhưng làm cách nào để sử dụng? tôi đã cố gắng:

from py4j.java_gateway import java_import 
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf") 
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
dfint = sqlContext.range(0,15) 
df = dfint.withColumn("a",a(dfint.id)) 

nhưng tôi nhận được:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-5-514811090b5f> in <module>() 
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
     4 dfint = sqlContext.range(0,15) 
----> 5 df = dfint.withColumn("a",a(dfint.id)) 

TypeError: 'JavaObject' object is not callable 

và tôi cố gắng sử dụng a.call thay vì một:

df = dfint.withColumn("a",a.call(dfint.id)) 

nhưng có: ----- -------------------------------------------------- -------------------- LoạiError Traceback (cuộc gọi gần đây nhất) trong() 3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf (7) 4 dfint = sqlContext.range (0,15) ----> 5 df = dfint.withColumn ("a", a .call (dfint.id))

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    796  def __call__(self, *args): 
    797   if self.converters is not None and len(self.converters) > 0: 
--> 798    (new_args, temp_args) = self._get_args(args) 
    799   else: 
    800    new_args = args 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args) 
    783     for converter in self.gateway_client.converters: 
    784      if converter.can_convert(arg): 
--> 785       temp_arg = converter.convert(arg, self.gateway_client) 
    786       temp_args.append(temp_arg) 
    787       new_args.append(temp_arg) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client) 
    510   HashMap = JavaClass("java.util.HashMap", gateway_client) 
    511   java_map = HashMap() 
--> 512   for key in object.keys(): 
    513    java_map[key] = object[key] 
    514   return java_map 

TypeError: 'Column' object is not callable 

Bất kỳ trợ giúp nào sẽ được đánh giá cao.

Trả lời

3

Tôi đã làm việc này với sự trợ giúp của another question (and answer) of your own về UDAF.

Spark cung cấp phương thức udf() để gói Scala FunctionN, vì vậy chúng tôi có thể bọc chức năng Java trong Scala và sử dụng. Phương thức Java của bạn cần tĩnh hoặc trên một lớp mà implements Serializable.

package com.example 

import org.apache.spark.sql.UserDefinedFunction 
import org.apache.spark.sql.functions.udf 

class MyUdf extends Serializable { 
    def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod()) 
} 

Cách sử dụng trong PySpark:

def my_udf(): 
    from pyspark.sql.column import Column, _to_java_column, _to_seq 
    pcls = "com.example.MyUdf" 
    jc = sc._jvm.java.lang.Thread.currentThread() \ 
     .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply 
    return Column(jc(_to_seq(sc, [], _to_java_column))) 

rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}]) 
df1 = rdd1.toDF() 
df2 = df1.withColumn('mycol', my_udf()) 

Như với UDAF trong câu hỏi và câu trả lời khác của bạn, chúng tôi có thể vượt qua cột vào nó với return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column)))

Các vấn đề liên quan