2013-03-23 21 views
13

Tôi đang giải quyết vấn đề phân loại nhiều lớp và cố gắng sử dụng Mô hình tổng quát được tăng cường (gói gbm trong R). Vấn đề tôi gặp phải: chức năng train của caret với method="gbm" dường như không hoạt động với dữ liệu đa giác đúng cách. Một ví dụ đơn giản được trình bày dưới đây.Cách sử dụng dấu mũ với phương pháp gbm để phân loại nhiều lớp

library(gbm) 
library(caret) 
data(iris) 
fitControl <- trainControl(method="repeatedcv", 
          number=5, 
          repeats=1, 
          verboseIter=TRUE) 
set.seed(825) 
gbmFit <- train(Species ~ ., data=iris, 
       method="gbm", 
       trControl=fitControl, 
       verbose=FALSE) 
gbmFit 

Đầu ra là

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
... 
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results 
Selecting tuning parameters 
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set 
Error in if (interaction.depth < 1) { : argument is of length zero 

Tuy nhiên, nếu tôi cố gắng sử dụng GBM mà không wrapper caret, tôi nhận được kết quả tốt đẹp.

set.seed(1365) 
train <- createDataPartition(iris$Species, p=0.7, list=F) 
train.iris <- iris[train,] 
valid.iris <- iris[-train,] 
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE) 
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response") 
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##! 
confusionMatrix(gbm.pred, valid.iris$Species)$overall 

FYI, mã trên đường đánh dấu bằng ##! chuyển đổi một ma trận xác suất lớp trả về bởi predict.gbm đến một yếu tố của hầu hết các lớp học có thể xảy ra. Đầu ra là

 Accuracy   Kappa AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue 
    9.111111e-01 8.666667e-01 7.877883e-01 9.752470e-01 3.333333e-01 8.467252e-16   NaN 

Bất kỳ đề xuất nào làm cho dấu mũ hoạt động đúng cách với gbm trên dữ liệu đa giác?

UPD:

sessionInfo() 
R version 2.15.3 (2013-03-01) 
Platform: x86_64-pc-linux-gnu (64-bit) 

locale: 
[1] LC_CTYPE=en_US.UTF-8  LC_NUMERIC=C    LC_TIME=en_US.UTF-8  LC_COLLATE=en_US.UTF-8  
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=C     LC_NAME=C     
[9] LC_ADDRESS=C    LC_TELEPHONE=C    LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C  

attached base packages: 
[1] splines stats  graphics grDevices utils  datasets methods base  

other attached packages: 
[1] e1071_1.6-1  class_7.3-5  gbm_2.0-8  survival_2.36-14 caret_5.15-61 reshape2_1.2.2 plyr_1.8   
[8] lattice_0.20-13 foreach_1.4.0 cluster_1.14.3 compare_0.2-3 

loaded via a namespace (and not attached): 
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3  iterators_1.0.6 stringr_0.6.2 tools_2.15.3 
+0

Chỉ cần một câu hỏi, tại sao bạn đang sử dụng 2 giống khác nhau? 825 và 1365? – agstudy

+1

Có quan trọng không? 825 - là một hạt giống từ một mã ví dụ tôi lấy mẫu [caret.r-forge.r-project.org] (http://caret.r-forge.r-project.org/training.html), 1365 - hạt giống Tôi đã sử dụng trong dự án của mình. – maruan

Trả lời

6

Đây là vấn đề mà tôi đang làm việc trên ngay bây giờ.

Sẽ hữu ích nếu bạn đăng kết quả của sessionInfo().

Ngoài ra, hãy tắt gbm mới nhất của https://code.google.com/p/gradientboostedmodels/ có thể giải quyết được sự cố.

Max

+0

Sự cố có liên quan đến https://code.google.com/p/gradientboostedmodels/issues/detail?id=12. Tôi có một công việc xung quanh, nhưng tôi muốn tránh nó vì nó chỉ là một vấn đề với dữ liệu đa thức. Tôi sẽ liên lạc với người duy trì một lần nữa để xem liệu có eta hay không. – topepo

+0

Dường như có một vấn đề nổi tiếng về việc tải lại các tài liệu sau khi cập nhật gbm với devtools https://github.com/hadley/devtools/issues/419 cũng ảnh hưởng đến điều này. –

3

Cập nhật: Caret thể làm phân loại đa lớp.

Bạn nên đảm bảo rằng nhãn lớp ở định dạng số-alpha (bắt đầu bằng chữ cái).

Ví dụ: nếu dữ liệu của bạn có nhãn "1", "2", "3" thì hãy thay đổi các nhãn này thành "Seg1", "Seg2" và "Seg3", nếu không thì sẽ bị lỗi.

2

Cập nhật: Mã gốc không chạy và xuất ra như sau

+ Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
- Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150 
... 
... 
... 
+ Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
- Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150 
Aggregating results 
Selecting tuning parameters 
Fitting n.trees = 50, interaction.depth = 2, shrinkage = 0.1 on full training set 
> gbmFit 
Stochastic Gradient Boosting 

150 samples 
    4 predictor 
    3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing 
Resampling: Cross-Validated (5 fold, repeated 1 times) 

Summary of sample sizes: 120, 120, 120, 120, 120 

Resampling results across tuning parameters: 

    interaction.depth n.trees Accuracy Kappa Accuracy SD 
    1     50  0.9400000 0.91 0.04346135 
    1     100  0.9400000 0.91 0.03651484 
    1     150  0.9333333 0.90 0.03333333 
    2     50  0.9533333 0.93 0.04472136 
    2     100  0.9533333 0.93 0.05055250 
    2     150  0.9466667 0.92 0.04472136 
    3     50  0.9333333 0.90 0.03333333 
    3     100  0.9466667 0.92 0.04472136 
    3     150  0.9400000 0.91 0.03651484 
    Kappa SD 
    0.06519202 
    0.05477226 
    0.05000000 
    0.06708204 
    0.07582875 
    0.06708204 
    0.05000000 
    0.06708204 
    0.05477226 

Tuning parameter 'shrinkage' was held constant at a value of 0.1 
Accuracy was used to select the optimal model using the 
largest value. 
The final values used for the model were n.trees = 
50, interaction.depth = 2 and shrinkage = 0.1. 
> summary(gbmFit) 
         var rel.inf 
Petal.Length Petal.Length 74.1266408 
Petal.Width Petal.Width 22.0668983 
Sepal.Width Sepal.Width 3.2209288 
Sepal.Length Sepal.Length 0.5855321 
Các vấn đề liên quan