2011-10-29 31 views

Trả lời

17

Nói chung, matplotlib thường không chứa các chức năng vẽ biểu đồ hoạt động trên nhiều đối tượng trục (subplot, trong trường hợp này). Kỳ vọng là bạn sẽ viết một hàm đơn giản để kết hợp mọi thứ với nhau theo ý bạn muốn.

Tôi không hoàn toàn chắc chắn dữ liệu của mình trông như thế nào, nhưng thật đơn giản khi chỉ xây dựng một hàm để thực hiện việc này từ đầu. Nếu bạn luôn làm việc với các mảng có cấu trúc hoặc các mảng rec, thì bạn có thể đơn giản hóa điều này một lần chạm. (Ví dụ: Luôn luôn có một tên gắn liền với mỗi loạt dữ liệu, vì vậy bạn có thể bỏ qua phải chỉ định tên.)

Như một ví dụ:

import itertools 
import numpy as np 
import matplotlib.pyplot as plt 

def main(): 
    np.random.seed(1977) 
    numvars, numdata = 4, 10 
    data = 10 * np.random.random((numvars, numdata)) 
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'], 
      linestyle='none', marker='o', color='black', mfc='none') 
    fig.suptitle('Simple Scatterplot Matrix') 
    plt.show() 

def scatterplot_matrix(data, names, **kwargs): 
    """Plots a scatterplot matrix of subplots. Each row of "data" is plotted 
    against other rows, resulting in a nrows by nrows grid of subplots with the 
    diagonal subplots labeled with "names". Additional keyword arguments are 
    passed on to matplotlib's "plot" command. Returns the matplotlib figure 
    object containg the subplot grid.""" 
    numvars, numdata = data.shape 
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8)) 
    fig.subplots_adjust(hspace=0.05, wspace=0.05) 

    for ax in axes.flat: 
     # Hide all ticks and labels 
     ax.xaxis.set_visible(False) 
     ax.yaxis.set_visible(False) 

     # Set up ticks only on one side for the "edge" subplots... 
     if ax.is_first_col(): 
      ax.yaxis.set_ticks_position('left') 
     if ax.is_last_col(): 
      ax.yaxis.set_ticks_position('right') 
     if ax.is_first_row(): 
      ax.xaxis.set_ticks_position('top') 
     if ax.is_last_row(): 
      ax.xaxis.set_ticks_position('bottom') 

    # Plot the data. 
    for i, j in zip(*np.triu_indices_from(axes, k=1)): 
     for x, y in [(i,j), (j,i)]: 
      axes[x,y].plot(data[x], data[y], **kwargs) 

    # Label the diagonal subplots... 
    for i, label in enumerate(names): 
     axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction', 
       ha='center', va='center') 

    # Turn on the proper x or y axes ticks. 
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))): 
     axes[j,i].xaxis.set_visible(True) 
     axes[i,j].yaxis.set_visible(True) 

    return fig 

main() 

enter image description here

+1

Chà, nhiều chức năng mới! Có, không quá khó khăn khi bạn có chủ của mô-đun ... nhưng không đơn giản như gọi 'cặp' như trong R. :) – hatmatrix

+0

Đúng! R có nhiều chức năng chuyên biệt hơn, trong kinh nghiệm (giới hạn!) Của tôi với nó. Matplotlib có một cách tiếp cận DIY hơn một chút. (Hoặc chắc chắn ít hơn rất nhiều chức năng lập kế hoạch thống kê chuyên ngành, ở mức nào.) –

+0

Chắc chắn tôi cảm thấy theo cách này. Tôi đang gắn bó với bộ ba Python (cho bây giờ) với hy vọng mặc dù nó cung cấp những lợi thế khác ... – hatmatrix

9

Cám ơn chia sẻ mã của bạn! Bạn đã tìm ra tất cả những thứ khó khăn cho chúng tôi. Khi tôi đang làm việc với nó, tôi nhận thấy một vài điều nhỏ mà trông không hoàn toàn đúng.

  1. [FIX # 1] tics trục không được xếp hàng như tôi mong chờ (ví dụ, trong ví dụ của bạn ở trên, bạn sẽ có thể vẽ một đường dọc và ngang qua bất kỳ điểm nào trên tất cả các lô và các đường thẳng phải đi qua điểm tương ứng trong các ô khác, nhưng khi nó nằm ngay bây giờ điều này không xảy ra.

  2. [FIX # 2] Nếu bạn có số lẻ các biến bạn đang vẽ, phía dưới bên phải các trục góc không kéo xtics chính xác hoặc ytics. Nó chỉ để nó như là mặc định 0..1 ve.

  3. Không phải là một sửa chữa, nhưng tôi làm cho nó tùy chọn để nhập rõ ràng names, để nó đặt mặc định xi cho biến i ở vị trí chéo.

Dưới đây bạn sẽ tìm thấy phiên bản cập nhật mã của bạn giải quyết hai điểm này, nếu không hãy giữ nguyên vẻ đẹp mã của bạn.

import itertools 
import numpy as np 
import matplotlib.pyplot as plt 

def scatterplot_matrix(data, names=[], **kwargs): 
    """ 
    Plots a scatterplot matrix of subplots. Each row of "data" is plotted 
    against other rows, resulting in a nrows by nrows grid of subplots with the 
    diagonal subplots labeled with "names". Additional keyword arguments are 
    passed on to matplotlib's "plot" command. Returns the matplotlib figure 
    object containg the subplot grid. 
    """ 
    numvars, numdata = data.shape 
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8)) 
    fig.subplots_adjust(hspace=0.0, wspace=0.0) 

    for ax in axes.flat: 
     # Hide all ticks and labels 
     ax.xaxis.set_visible(False) 
     ax.yaxis.set_visible(False) 

     # Set up ticks only on one side for the "edge" subplots... 
     if ax.is_first_col(): 
      ax.yaxis.set_ticks_position('left') 
     if ax.is_last_col(): 
      ax.yaxis.set_ticks_position('right') 
     if ax.is_first_row(): 
      ax.xaxis.set_ticks_position('top') 
     if ax.is_last_row(): 
      ax.xaxis.set_ticks_position('bottom') 

    # Plot the data. 
    for i, j in zip(*np.triu_indices_from(axes, k=1)): 
     for x, y in [(i,j), (j,i)]: 
      # FIX #1: this needed to be changed from ...(data[x], data[y],...) 
      axes[x,y].plot(data[y], data[x], **kwargs) 

    # Label the diagonal subplots... 
    if not names: 
     names = ['x'+str(i) for i in range(numvars)] 

    for i, label in enumerate(names): 
     axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction', 
       ha='center', va='center') 

    # Turn on the proper x or y axes ticks. 
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))): 
     axes[j,i].xaxis.set_visible(True) 
     axes[i,j].yaxis.set_visible(True) 

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the 
    # correct axes limits, so we pull them from other axes 
    if numvars%2: 
     xlimits = axes[0,-1].get_xlim() 
     ylimits = axes[-1,0].get_ylim() 
     axes[-1,-1].set_xlim(xlimits) 
     axes[-1,-1].set_ylim(ylimits) 

    return fig 

if __name__=='__main__': 
    np.random.seed(1977) 
    numvars, numdata = 4, 10 
    data = 10 * np.random.random((numvars, numdata)) 
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'], 
      linestyle='none', marker='o', color='black', mfc='none') 
    fig.suptitle('Simple Scatterplot Matrix') 
    plt.show() 

Cảm ơn bạn đã chia sẻ điều này với chúng tôi. Tôi đã sử dụng nó nhiều lần! Ồ, và tôi đã sắp xếp lại phần main() của mã để nó có thể là mã mẫu chính thức hoặc không được gọi nếu mã đó đang được nhập vào một đoạn mã khác.

+0

Cảm ơn, tôi đã gặp sự cố với mã @Joe Kington cho đến khi tôi thấy câu trả lời của bạn. Nó tiết kiệm cho tôi một số thời gian gỡ lỗi :) – chutsu

+0

Bất kỳ ý tưởng, làm thế nào tôi có thể làm cho chức năng này nhanh hơn, tôi cần phải tạo ra một ma trận âm mưu phân tán lớn khoảng 100 vars và phương pháp này là rất chậm. – MARK

3

Trong khi đọc câu hỏi tôi dự kiến ​​sẽ thấy câu trả lời bao gồm rpy. Tôi nghĩ đây là một lựa chọn tốt để tận dụng hai ngôn ngữ đẹp. Vì vậy, ở đây là:

import rpy 
import numpy as np 

def main(): 
    np.random.seed(1977) 
    numvars, numdata = 4, 10 
    data = 10 * np.random.random((numvars, numdata)) 
    mpg = data[0,:] 
    disp = data[1,:] 
    drat = data[2,:] 
    wt = data[3,:] 
    rpy.set_default_mode(rpy.NO_CONVERSION) 

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt) 

    # Figure saved as eps 
    rpy.r.postscript('pairsPlot.eps') 
    rpy.r.pairs(R_data, 
     main="Simple Scatterplot Matrix Via RPy") 
    rpy.r.dev_off() 

    # Figure saved as png 
    rpy.r.png('pairsPlot.png') 
    rpy.r.pairs(R_data, 
     main="Simple Scatterplot Matrix Via RPy") 
    rpy.r.dev_off() 

    rpy.set_default_mode(rpy.BASIC_CONVERSION) 


if __name__ == '__main__': main() 

Tôi không thể đăng hình để hiển thị kết quả :(xin lỗi!

78

Đối với những người không muốn để xác định chức năng của mình, có một libarary phân tích dữ liệu lớn trong Python, được gọi là Pandas, nơi người ta có thể tìm thấy scatter_matrix() phương pháp:

from pandas.tools.plotting import scatter_matrix 
df = DataFrame(randn(1000, 4), columns=['a', 'b', 'c', 'd']) 
scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde') 

enter image description here

+2

Xin chào, làm thế nào chỉ một phần của các ô phụ có lưới trong đó? Điều đó có thể được sửa đổi (hoặc là tất cả hay không)? Cảm ơn – user2808117

+4

+1 Điều đó sẽ dạy tôi tìm kiếm một tính năng Python trước khi tìm xem nó đã có trong gấu trúc chưa. Bước 1: Luôn luôn hỏi, nó đã tồn tại trong gấu trúc chưa? 'pd.scatter_matrix (df); plt.show() '. Đáng kinh ngạc. – Jarad

+0

Đặt một kde trong ma trận phân tán matplotlib là môn thể thao khắc nghiệt. Tôi yêu gấu trúc. –

10

Bạn có thể cũng sử dụng Seaborn's pairplot function:

import seaborn as sns 
sns.set() 
df = sns.load_dataset("iris") 
sns.pairplot(df, hue="species") 
+0

phần khó chịu về bẩm sinh là nó tập trung xung quanh gấu trúc DataFrames. Nếu bạn có một mảng NumPy, cách giải quyết này cảm thấy khó chịu, và nếu bạn đã có một DataFrame gấu trúc, tại sao không chỉ sử dụng phương thức scatter_matrix trong xây dựng của gấu trúc? – Sebastian

Các vấn đề liên quan