Tôi đang giải quyết vấn đề phân loại nhiều lớp và cố gắng sử dụng Mô hình tổng quát được tăng cường (gói gbm trong R). Vấn đề tôi gặp phải: chức năng train
của caret với method="gbm"
dường như không hoạt động với dữ liệu đa giác đúng cách. Một ví dụ đơn giản được trình bày dưới đây.Cách sử dụng dấu mũ với phương pháp gbm để phân loại nhiều lớp
library(gbm)
library(caret)
data(iris)
fitControl <- trainControl(method="repeatedcv",
number=5,
repeats=1,
verboseIter=TRUE)
set.seed(825)
gbmFit <- train(Species ~ ., data=iris,
method="gbm",
trControl=fitControl,
verbose=FALSE)
gbmFit
Đầu ra là
+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150
...
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero
Tuy nhiên, nếu tôi cố gắng sử dụng GBM mà không wrapper caret, tôi nhận được kết quả tốt đẹp.
set.seed(1365)
train <- createDataPartition(iris$Species, p=0.7, list=F)
train.iris <- iris[train,]
valid.iris <- iris[-train,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##!
confusionMatrix(gbm.pred, valid.iris$Species)$overall
FYI, mã trên đường đánh dấu bằng ##!
chuyển đổi một ma trận xác suất lớp trả về bởi predict.gbm
đến một yếu tố của hầu hết các lớp học có thể xảy ra. Đầu ra là
Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue
9.111111e-01 8.666667e-01 7.877883e-01 9.752470e-01 3.333333e-01 8.467252e-16 NaN
Bất kỳ đề xuất nào làm cho dấu mũ hoạt động đúng cách với gbm trên dữ liệu đa giác?
UPD:
sessionInfo()
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=C LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] splines stats graphics grDevices utils datasets methods base
other attached packages:
[1] e1071_1.6-1 class_7.3-5 gbm_2.0-8 survival_2.36-14 caret_5.15-61 reshape2_1.2.2 plyr_1.8
[8] lattice_0.20-13 foreach_1.4.0 cluster_1.14.3 compare_0.2-3
loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3 iterators_1.0.6 stringr_0.6.2 tools_2.15.3
Chỉ cần một câu hỏi, tại sao bạn đang sử dụng 2 giống khác nhau? 825 và 1365? – agstudy
Có quan trọng không? 825 - là một hạt giống từ một mã ví dụ tôi lấy mẫu [caret.r-forge.r-project.org] (http://caret.r-forge.r-project.org/training.html), 1365 - hạt giống Tôi đã sử dụng trong dự án của mình. – maruan