Tôi muốn đánh giá một khu rừng ngẫu nhiên đang được đào tạo về một số dữ liệu. Có bất kỳ tiện ích nào trong Apache Spark để làm tương tự hay tôi phải thực hiện xác thực chéo theo cách thủ công không?Làm cách nào để xác thực chéo mô hình RandomForest?
17
A
Trả lời
31
ML cung cấp CrossValidator
lớp có thể được sử dụng để thực hiện xác thực chéo và tìm kiếm thông số. Giả sử dữ liệu của bạn đã được xử lý trước bạn có thể thêm cross-validation như sau:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// [label: double, features: vector]
trainingData org.apache.spark.sql.DataFrame = ???
val nFolds: Int = ???
val NumTrees: Int = ???
val metric: String = ???
val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(NumTrees)
val pipeline = new Pipeline().setStages(Array(rf))
val paramGrid = new ParamGridBuilder().build() // No parameter search
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
// "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
.setMetricName(metric)
val cv = new CrossValidator()
// ml.Pipeline with ml.classification.RandomForestClassifier
.setEstimator(pipeline)
// ml.evaluation.MulticlassClassificationEvaluator
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(nFolds)
val model = cv.fit(trainingData) // trainingData: DataFrame
Sử dụng PySpark:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala
pipeline = Pipeline(stages=[rf])
crossval = CrossValidator(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=numFolds)
model = crossval.fit(trainingData)
1
Để xây dựng dựa trên câu trả lời tuyệt vời zero323 bằng cách sử dụng rừng ngẫu nhiên Phân loại, đây là một ví dụ tương tự cho ngẫu nhiên rừng Regressor:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.regression.RandomForestRegressor // CHANGED
import org.apache.spark.ml.evaluation.RegressionEvaluator // CHANGED
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}
val numFolds = ??? // Integer
val data = ??? // DataFrame
// Training (80%) and test data (20%)
val Array(train, test) = data.randomSplit(Array(0.8,0.2))
val featuresCols = data.columns
val va = new VectorAssembler()
va.setInputCols(featuresCols)
va.setOutputCol("rawFeatures")
val vi = new VectorIndexer()
vi.setInputCol("rawFeatures")
vi.setOutputCol("features")
vi.setMaxCategories(5)
val regressor = new RandomForestRegressor()
regressor.setLabelCol("events")
val metric = "rmse"
val evaluator = new RegressionEvaluator()
.setLabelCol("events")
.setPredictionCol("prediction")
// "rmse" (default): root mean squared error
// "mse": mean squared error
// "r2": R2 metric
// "mae": mean absolute error
.setMetricName(metric)
val paramGrid = new ParamGridBuilder().build()
val cv = new CrossValidator()
.setEstimator(regressor)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(numFolds)
val model = cv.fit(train) // train: DataFrame
val predictions = model.transform(test)
predictions.show
val rmse = evaluator.evaluate(predictions)
println(rmse)
Evaluator nguồn số liệu: https://spark.apache.org/docs/latest/api/scala/#org.apache.spark.ml.evaluation.RegressionEvaluator
Các vấn đề liên quan
- 1. Xác thực chéo mô hình CART
- 2. guess_proba cho một mô hình được xác thực chéo
- 3. Làm cách nào để bạn điền/xác thực Mô hình Chế độ xem của mình?
- 4. Làm cách nào để xác thực chỉ một phần của mô hình trong ASP .NET MVC?
- 5. Làm cách nào để kiểm tra lỗi xác thực Mô hình trong asp.net mvc?
- 6. Làm cách nào để có được hình chữ nhật r được xác thực chéo từ mô hình tuyến tính trong R?
- 7. Tạo các bộ để xác thực chéo
- 8. Làm thế nào để xuất ra RandomForest Classifier từ python?
- 9. Xác thực đối tượng mô hình Python
- 10. Sử dụng hình thức xác thực chéo miền
- 11. Làm cách nào để tạo trình tạo xác thực chéo tùy chỉnh trong việc tìm hiểu?
- 12. cách xác thực id mô hình được liên kết?
- 13. làm thế nào để sử dụng sed thay thế mô hình chuỗi với dấu chéo ngược
- 14. Làm cách nào để bạn thực hiện xác thực Mẫu mà không có mô hình trong CakePHP?
- 15. Làm thế nào để vô hiệu hóa xác thực mô hình MVC 4?
- 16. Xác thực mô hình xương sống
- 17. Làm thế nào để thực hiện xác thực chéo k-fold với tensorflow?
- 18. Làm thế nào để dịch xác thực mô hình bản ghi hoạt động
- 19. Làm cách nào để xác thực ngày trong đường ray?
- 20. Cách chính xác để xác thực đối tượng mô hình Django?
- 21. Làm cách nào để xác định sự khác biệt giữa mô hình kinh doanh và mô hình dữ liệu?
- 22. Các mối quan hệ chéo mô hình trong NSManagedObjectModel từ các mô hình đã hợp nhất?
- 23. R randomForest để phân loại
- 24. Làm cách nào để bao gồm mô hình với RedirectToAction?
- 25. mô hình chồng chéo phù hợp
- 26. Tại sao sklearn mô hình RandomForest chiếm nhiều không gian đĩa sau khi lưu?
- 27. Làm cách nào để hiển thị thông báo lỗi xác thực bằng Simple_form nhưng không có mô hình?
- 28. Làm cách nào để xác thực hai giá trị đó không bằng nhau trong một mô hình Rails?
- 29. Làm cách nào để xử lý xác thực bằng Devise khi sử dụng nhiều mô hình trong Rails 3.2 App
- 30. Làm cách nào để hiển thị lỗi xác thực từ mô hình được liên kết trong Rails?
Bạn có chắc chắn rằng điều này có hiệu quả đối với việc bỏ đi không? Cuộc gọi kFold() dưới mui xe không xuất hiện để xác định trở lại hai lần chiều dài N-1 và 1. Khi tôi chạy đoạn mã trên bằng một bộ hồi quy RegressionEvaluator và Lasso, tôi nhận được: Ngoại lệ trong chủ đề "chính" java.lang .IllegalArgumentException: yêu cầu không thành công: Không có gì được thêm vào bản tóm tắt này. – paradiso
Không, tôi khá chắc chắn là không. 'MLUtils.kFold' đang sử dụng' BernoulliCellSampler' để xác định chia nhỏ. Mặt khác, chi phí thực hiện việc bỏ qua xác thực chéo một lần trong Spark có lẽ là cao để làm cho nó khả thi trong thực tế. – zero323
Xin chào @ zero323, khi bạn đặt số liệu trong đối tượng Đánh giá của bạn như .setMetricName ("precision"). Câu hỏi của tôi là, làm cách nào để tôi có được số liệu được tính toán trong quá trình đào tạo? (Vui lòng tham khảo câu hỏi này: http://stackoverflow.com/questions/37778532/how-to-get-precision-recall-using-crossvalidator-for-training-naivebayes-model-u) – dbustosp