Nó không phải là một giải pháp đầy đủ tính năng nhưng bạn có thể bắt đầu với một cái gì đó như thế này:
import org.apache.spark.ml.{UnaryTransformer}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
class NGramTokenizer(override val uid: String)
extends UnaryTransformer[String, Seq[String], NGramTokenizer] {
def this() = this(Identifiable.randomUID("ngramtokenizer"))
override protected def createTransformFunc: String => Seq[String] = {
getFeatures _
}
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType)
}
override protected def outputDataType: DataType = {
new ArrayType(StringType, true)
}
}
Kiểm tra nhanh:
val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
Bạn thậm chí có thể cố gắng khái quát nó để cái gì đó như này:
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import scala.reflect.runtime.universe._
class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
override val uid: String,
f: T => U
) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]] {
override protected def createTransformFunc: T => U = f
override protected def validateInputType(inputType: DataType): Unit =
require(inputType == schemaFor[T].dataType)
override protected def outputDataType: DataType = schemaFor[U].dataType
}
val transformer = new UnaryUDFTransformer("featurize", getFeatures)
.setInputCol("v")
.setOutputCol("vs")
Nếu bạn muốn sử dụng UDF không phải là chức năng bọc, bạn sẽ phải mở rộng Transformer
trực tiếp và ghi đè phương thức transform
. Thật không may phần lớn các lớp hữu ích là riêng tư nên nó có thể khá khó khăn.
Hoặc bạn có thể đăng ký UDF:
spark.udf.register("getFeatures", getFeatures _)
và sử dụng SQLTransformer
import org.apache.spark.ml.feature.SQLTransformer
val transformer = new SQLTransformer()
.setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
Tôi đã cố gắng lưu mô hình của mình nhưng có nội dung 'Thông báo: Viết đường ống sẽ không thành công trên Đường ống này vì nó chứa một giai đoạn không thực thi được Ghi. Giai đoạn không ghi được: ngramtokenizer_f784079e2124 của lớp loại' tôi có phải thực hiện giao diện Ghi được không? –
Đây là phần xấu mà tôi đã đề cập trước đây. Theo như tôi biết cách tiếp cận tốt nhất là triển khai 'DefaultParamsWritable' và' DefaultParamsReadable' nhưng nó sẽ không thể thực hiện được nếu không đặt ít nhất một phần mã của bạn trong gói ML. Bạn có thể thử với 'MLWritable' /' MLReadable'. – zero323