2014-05-09 21 views
14

Tôi có hai vấn đề với sự hiểu biết kết quả của cây quyết định từ scikit-tìm hiểu. Ví dụ: đây là một trong những cây quyết định của tôi:làm thế nào để giải thích cây quyết định từ scikit-tìm hiểu

enter image description here Câu hỏi của tôi là cách tôi có thể sử dụng cây?

Câu hỏi đầu tiên là: nếu một mẫu thỏa mãn điều kiện, sau đó nó đi vào LEFT chi nhánh (nếu có), nếu không nó đi QUYỀN. Trong trường hợp của tôi, nếu một mẫu có X [7]> 63521.3984. Sau đó, mẫu sẽ chuyển đến hộp màu xanh lục. Chính xác?

Câu hỏi thứ hai là: khi mẫu đạt đến nút lá, làm thế nào tôi có thể biết được danh mục của nó? Trong ví dụ này, tôi có ba loại để phân loại. Trong hộp màu đỏ, có 91, 212, và 113 mẫu thỏa mãn điều kiện, tương ứng. Nhưng làm cách nào tôi có thể quyết định danh mục? Tôi biết có chức năng clf.predict (mẫu) để cho biết danh mục. Tôi có thể làm điều đó từ đồ thị không ??? Rất cám ơn.

+1

Ngoài sự tò mò, bạn đã vẽ cây quyết định như thế nào? – Matt

+4

Xuất lần đầu tiên cây sang định dạng JSON (xem [link] này (http://www.garysieling.com/blog/rending-scikit-decision-trees-d3-js)) và sau đó vẽ cây bằng cách sử dụng d3.js . Hoặc bạn có thể trực tiếp sử dụng hàm được nhúng: 'tree.export_graphviz (clf, out_file = your_out_file, feature_names = your_feature_names)' Hy vọng nó hoạt động, @Matt –

Trả lời

21

Dòng value trong mỗi hộp cho bạn biết số lượng mẫu tại nút đó rơi vào từng danh mục theo thứ tự. Đó là lý do tại sao, trong mỗi hộp, các số trong số value thêm tối đa số được hiển thị trong sample. Ví dụ: trong hộp màu đỏ của bạn, 91 + 212 + 113 = 416. Vì vậy, điều này có nghĩa là nếu bạn đạt đến nút này, có 91 điểm dữ liệu trong danh mục 1, 212 trong loại 2 và 113 trong danh mục 3.

Nếu bạn dự đoán kết quả cho một điểm dữ liệu mới đạt tới lá đó trong cây quyết định, bạn sẽ dự đoán loại 2, vì đó là danh mục phổ biến nhất cho các mẫu tại nút đó.

+0

Tôi muốn biết giá trị nào thuộc về lớp nào. 'DecisionTreeClassifier.classes' giữ thông tin này. – ezdazuzena

+0

(Câu trả lời hữu ích) Để làm rõ bằng cách sử dụng chỉ mục python: một mẫu đích trong hộp màu đỏ sẽ được dự đoán (đếm 212) làm loại 1, thay vì loại 0 (91) hoặc loại 2 (113) :-)) –

0

Theo sách "Học scikit-learning: Học máy bằng Python", Cây quyết định thể hiện một loạt các quyết định dựa trên dữ liệu đào tạo.

! (http://i.imgur.com/vM9fJLy.png)

Để phân loại một ví dụ, chúng ta nên trả lời câu hỏi tại mỗi nút. Ví dụ: Giới tính là < = 0,5? (chúng ta đang nói về một người phụ nữ?). Nếu câu trả lời là có, bạn đi đến nút con bên trái trong cây; nếu không bạn đi đến nút con bên phải. Bạn tiếp tục trả lời câu hỏi (là cô ấy ở lớp thứ ba ?, là cô ấy ở lớp đầu tiên ?, và cô ấy dưới 13 tuổi?), Cho đến khi bạn đạt đến một chiếc lá. Khi bạn ở đó, dự đoán tương ứng với lớp đích có hầu hết các trường hợp.

2

Câu hỏi đầu tiên: Vâng, logic của bạn là chính xác. Nút bên trái là True và nút bên phải là False. Điều này phản trực giác; true nói chung sẽ có nghĩa là một giá trị nhỏ hơn.

Câu hỏi thứ hai: Vấn đề này được giải quyết tốt nhất bằng cách trực quan hóa cây dưới dạng đồ thị với pydotplus. Thuộc tính 'class_names' của tree.export_graphviz() sẽ thêm một khai báo lớp vào lớp đa số của mỗi nút. Mã được thực hiện trong iPython.

from sklearn.datasets import load_iris 
from sklearn import tree 
iris = load_iris() 
clf2 = tree.DecisionTreeClassifier() 
clf2 = clf2.fit(iris.data, iris.target) 

with open("iris.dot", 'w') as f: 
    f = tree.export_graphviz(clf, out_file=f) 

import os 
os.unlink('iris.dot') 

import pydotplus 
dot_data = tree.export_graphviz(clf2, out_file=None) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 
graph2.write_pdf("iris.pdf") 

from IPython.display import Image 
dot_data = tree.export_graphviz(clf2, out_file=None, 
        feature_names=iris.feature_names, 
        class_names=iris.target_names, 
        filled=True, rounded=True, # leaves_parallel=True, 
        special_characters=True) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 

## Color of nodes 
nodes = graph2.get_node_list() 

for node in nodes: 
    if node.get_label(): 
     values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]; 
     color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],} 
     values = color[values.index(max(values))]; # print(values) 
     color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color) 
     node.set_fillcolor(color) 
# 

Image(graph2.create_png()) 

enter image description here

Đối với việc xác định lớp nhìn vào chiếc lá, ví dụ bạn không có lá với một lớp duy nhất, như các thiết lập iris dữ liệu nào. Điều này là phổ biến và có thể yêu cầu mô hình quá phù hợp để đạt được kết quả như vậy. Phân phối các lớp riêng biệt là kết quả tốt nhất cho nhiều mô hình được xác thực chéo.

Tận hưởng mã!

Các vấn đề liên quan