2013-01-08 23 views
5

tôi đang làm một tìm kiếm lưới trên dữ liệu multilabel như sau:GridSearch for OneVsRestClassifier Multilabel?

#imports 
from sklearn.svm import SVC as classifier 
from sklearn.pipeline import Pipeline 
from sklearn.decomposition import RandomizedPCA 
from sklearn.cross_validation import StratifiedKFold 
from sklearn.grid_search import GridSearchCV 

#classifier pipeline 
clf_pipeline = clf_pipeline = OneVsRestClassifier(
       Pipeline([('reduce_dim', RandomizedPCA()), 
          ('clf', classifier()) 
          ] 
         )) 

C_range = 10.0 ** np.arange(-2, 9) 
gamma_range = 10.0 ** np.arange(-5, 4) 
n_components_range = (10, 100, 200) 
degree_range = (1, 2, 3, 4) 

param_grid = dict(estimator__clf__gamma=gamma_range, 
        estimator__clf__c=c_range, 
        estimator__clf__degree=degree_range, 
        estimator__reduce_dim__n_components=n_components_range) 

grid = GridSearchCV(clf_pipeline, param_grid, 
           cv=StratifiedKFold(y=Y, n_folds=3), n_jobs=1, 
           verbose=2) 
grid.fit(X, Y) 

tôi nhìn thấy traceback sau:

/Users/andrewwinterman/Documents/sparks-honey/classifier/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func, score_func, verbose, **fit_params) 
    107 
    108  if y is not None: 
--> 109   y_test = y[safe_mask(y, test)] 
    110   y_train = y[safe_mask(y, train)] 
    111   clf.fit(X_train, y_train, **fit_params) 

TypeError: only integer arrays with one element can be converted to an index 

Hình như GridSearchCV đối tượng cho nhiều nhãn. Tôi nên làm việc như thế nào? Tôi có cần lặp lại một cách rõ ràng thông qua các lớp duy nhất với label_binarizer, chạy tìm kiếm lưới trên mỗi ước tính phụ không?

+0

Bạn đang sử dụng 0.12.1 hoặc 0.13? Tôi nghĩ rằng vấn đề sẽ biến mất khi nâng cấp lên 0,13. –

+0

Tôi đang sử dụng chi nhánh dev là 0.13. Tôi sẽ thử lại lần nữa. – Maus

+0

Nó sẽ hoạt động trong bản phát hành 0.13 và bản chính hiện tại. Nếu không, vui lòng mở một vấn đề trên github. –

Trả lời

6

Tôi nghĩ rằng có một lỗi trong grid_search.py ​​

Bạn đã cố gắng để cung cấp cho y như NumPy mảng?

import numpy as np 
Y = np.asarray(Y) 
+0

không, tôi đã thực sự ngừng làm việc với scikit tìm hiểu. Bạn sẽ phải tự mình thử bất kỳ giải pháp được đề xuất nào. Nếu bạn có thể chứng minh một công trình, tôi sẽ chấp nhận nó :) – Maus

+1

vừa giải quyết được vấn đề của tôi, cảm ơn. –

+0

@ ZéRicardo, vui mừng khi biết điều đó :) – Thorn