2012-05-29 23 views
13

Tôi đang cố gắng sử dụng các ràng buộc java cho libsvm:libsvm thi java

http://www.csie.ntu.edu.tw/~cjlin/libsvm/ 

tôi đã thực hiện một ví dụ 'tầm thường' mà có thể dễ dàng tuyến tính tách trong y. Dữ liệu được định nghĩa là:

double[][] train = new double[1000][]; 
double[][] test = new double[10][]; 

for (int i = 0; i < train.length; i++){ 
    if (i+1 > (train.length/2)){  // 50% positive 
     double[] vals = {1,0,i+i}; 
     train[i] = vals; 
    } else { 
     double[] vals = {0,0,i-i-i-2}; // 50% negative 
     train[i] = vals; 
    }   
} 

Trường hợp 'tính năng' đầu tiên là lớp và tập huấn được xác định tương tự.

Đào tạo mô hình:

private svm_model svmTrain() { 
    svm_problem prob = new svm_problem(); 
    int dataCount = train.length; 
    prob.y = new double[dataCount]; 
    prob.l = dataCount; 
    prob.x = new svm_node[dataCount][];  

    for (int i = 0; i < dataCount; i++){    
     double[] features = train[i]; 
     prob.x[i] = new svm_node[features.length-1]; 
     for (int j = 1; j < features.length; j++){ 
      svm_node node = new svm_node(); 
      node.index = j; 
      node.value = features[j]; 
      prob.x[i][j-1] = node; 
     }   
     prob.y[i] = features[0]; 
    }    

    svm_parameter param = new svm_parameter(); 
    param.probability = 1; 
    param.gamma = 0.5; 
    param.nu = 0.5; 
    param.C = 1; 
    param.svm_type = svm_parameter.C_SVC; 
    param.kernel_type = svm_parameter.LINEAR;  
    param.cache_size = 20000; 
    param.eps = 0.001;  

    svm_model model = svm.svm_train(prob, param); 

    return model; 
} 

Sau đó, để đánh giá các mô hình tôi sử dụng:

public int evaluate(double[] features) { 
    svm_node node = new svm_node(); 
    for (int i = 1; i < features.length; i++){ 
     node.index = i; 
     node.value = features[i]; 
    } 
    svm_node[] nodes = new svm_node[1]; 
    nodes[0] = node; 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(_model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(_model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return (int)v; 
} 

Trong trường hợp mảng trôi qua là một điểm từ tập thử nghiệm.

Các kết quả được luôn trở về lớp 0. Với những kết quả chính xác con người:

(0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0) 
(0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0) 
(0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0) 
(0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0) 
(0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0) 
(0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0) 
(0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0) 
(0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0) 
(0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0) 
(0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0) 

Ai đó có thể giải thích tại sao phân loại này không hoạt động? Có một bước tôi đã làm rối tung, hoặc một bước tôi bị mất?

Cảm ơn

Trả lời

13

dường như với tôi rằng phương pháp đánh giá của bạn là sai. Nên một cái gì đó như thế này:

public double evaluate(double[] features, svm_model model) 
{ 
    svm_node[] nodes = new svm_node[features.length-1]; 
    for (int i = 1; i < features.length; i++) 
    { 
     svm_node node = new svm_node(); 
     node.index = i; 
     node.value = features[i]; 

     nodes[i-1] = node; 
    } 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return v; 
} 
+4

Bạn có thể giải thích sai sót trong mã câu hỏi là gì không? Tôi đang gặp vấn đề trong việc phát hiện lỗi! :( – Daniel

1

tôi đã thực hiện một phiên bản hơi refactored thực hiện java LibSVM của mà bạn có thể tìm thấy dễ dàng hơn để sử dụng: https://github.com/syeedibnfaiz/libsvm-java-kernel. Hãy xem lớp Demo.java để biết cách sử dụng nó.

2

Đây là một Rework của ví dụ trên mà tôi đã thử nghiệm sử dụng dữ liệu từ mã R sau: http://cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf

import libsvm.*; 

public class libsvmTest { 

    public static void main(String [] args) { 

     double[][] xtrain = ... 
     double[][] xtest = ... 
     double[][] ytrain = ... 
     double[][] ytest = ... 

     svm_model m = svmTrain(xtrain,ytrain); 

     double[] ypred = svmPredict(xtest, m); 

     for (int i = 0; i < xtest.length; i++){ 
      System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")"); 
     } 

    } 

    static svm_model svmTrain(double[][] xtrain, double[][] ytrain) { 
     svm_problem prob = new svm_problem(); 
     int recordCount = xtrain.length; 
     int featureCount = xtrain[0].length; 
     prob.y = new double[recordCount]; 
     prob.l = recordCount; 
     prob.x = new svm_node[recordCount][featureCount];  

     for (int i = 0; i < recordCount; i++){    
      double[] features = xtrain[i]; 
      prob.x[i] = new svm_node[features.length]; 
      for (int j = 0; j < features.length; j++){ 
       svm_node node = new svm_node(); 
       node.index = j; 
       node.value = features[j]; 
       prob.x[i][j] = node; 
      }   
      prob.y[i] = ytrain[i][0]; 
     }    

     svm_parameter param = new svm_parameter(); 
     param.probability = 1; 
     param.gamma = 0.5; 
     param.nu = 0.5; 
     param.C = 100; 
     param.svm_type = svm_parameter.C_SVC; 
     param.kernel_type = svm_parameter.LINEAR;  
     param.cache_size = 20000; 
     param.eps = 0.001;  

     svm_model model = svm.svm_train(prob, param); 

     return model; 
    } 

    static double[] svmPredict(double[][] xtest, svm_model model) 
    { 

     double[] yPred = new double[xtest.length]; 

     for(int k = 0; k < xtest.length; k++){ 

     double[] fVector = xtest[k]; 

     svm_node[] nodes = new svm_node[fVector.length]; 
     for (int i = 0; i < fVector.length; i++) 
     { 
      svm_node node = new svm_node(); 
      node.index = i; 
      node.value = fVector[i]; 
      nodes[i] = node; 
     } 

     int totalClasses = 2;  
     int[] labels = new int[totalClasses]; 
     svm.svm_get_labels(model,labels); 

     double[] prob_estimates = new double[totalClasses]; 
     yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates); 

     } 

     return yPred; 
    } 


} 

Đây là kết quả:

(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
+0

Cảm ơn bạn rất nhiều về mã hữu ích Tại sao bạn sử dụng param.probability = 1 ;? Và thứ hai, bạn có biết làm thế nào người ta có thể đặt trọng số nếu người ta có các lớp không cân bằng? – machinery

+0

Không prob_estimates mất phạm vi khi bạn gọi svm.svm_predict_probability()? – user1040535

+0

Đây chỉ đơn giản là một bài đăng để giúp bắt đầu với LIBSVM, từ đó, người dùng của nó phải xác định những gì phù hợp với vấn đề. Đối với câu hỏi liên quan đến điều này, tôi đề nghị bạn truy cập vào trang web của các maintaners của gói này: https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q06:_Probability_outputs –

Các vấn đề liên quan