2011-12-29 39 views
27

Khi vẽ một dấu chấm bằng cách sử dụng matplotlib, tôi muốn bù đắp các điểm dữ liệu trùng lặp để giữ cho chúng hiển thị tất cả. Ví dụ, nếu tôi cóMatplotlib: tránh các điểm dữ liệu chồng lên nhau trong một ô "tán xạ/chấm/beeswarm"

CategoryA: 0,0,3,0,5 
CategoryB: 5,10,5,5,10 

Tôi muốn mỗi CategoryA "0" datapoints được đặt cạnh nhau, chứ không phải là ngay trên đầu trang của mỗi khác, trong khi vẫn còn khác biệt với CategoryB.

Trong R (ggplot2) có tùy chọn "jitter" thực hiện việc này. Có một lựa chọn tương tự trong matplotlib, hoặc là có cách tiếp cận khác mà có thể dẫn đến một kết quả tương tự?

Edit: để làm rõ, the "beeswarm" plot in R về cơ bản là những gì tôi có trong tâm trí, và pybeeswarm là một khởi đầu nhưng hữu ích tại một matplotlib phiên bản/Python.

Edit: thêm rằng sanh ở biển của Swarmplot, được giới thiệu trong phiên bản 0.7, là một thực hiện tuyệt vời của những gì tôi muốn.

+0

Trong một [dot cốt truyện] (http://en.wikipedia.org/wiki/Dot_plot_ (thống kê)) những điểm này đã được tách ra trong cột của chúng – joaquin

+1

Định nghĩa wiki của "dấu chấm" không phải là những gì tôi đang cố gắng mô tả, nhưng tôi chưa bao giờ nghe về một thuật ngữ khác ngoài "dấu chấm" cho nó. Đó là khoảng một âm mưu phân tán nhưng với các nhãn x tùy ý (không nhất thiết là số). Vì vậy, trong ví dụ tôi mô tả trong câu hỏi, sẽ có một cột giá trị cho "CategoryA", cột thứ hai cho "CategoryB", v.v. (_Edit_: Định nghĩa wikipedia của "Cleveland dot plot" là tương tự như những gì tôi đang tìm kiếm, mặc dù vẫn không chính xác như nhau.) – iayork

Trả lời

6

Không biết của một mpl thay thế trực tiếp ở đây bạn có một đề nghị rất thô sơ:

from matplotlib import pyplot as plt 
from itertools import groupby 

CA = [0,4,0,3,0,5] 
CB = [0,0,4,4,2,2,2,2,3,0,5] 

x = [] 
y = [] 
for indx, klass in enumerate([CA, CB]): 
    klass = groupby(sorted(klass)) 
    for item, objt in klass: 
     objt = list(objt) 
     points = len(objt) 
     pos = 1 + indx + (1 - points)/50. 
     for item in objt: 
      x.append(pos) 
      y.append(item) 
      pos += 0.04 

plt.plot(x, y, 'o') 
plt.xlim((0,3)) 

plt.show() 

enter image description here

7

tôi đã sử dụng numpy.random để "phân tán/beeswarm" các dữ liệu cùng trục X nhưng xung quanh một điểm cố định cho mỗi thể loại, và sau đó về cơ bản làm pyplot.scatter() cho mỗi thể loại:

import matplotlib.pyplot as plt 
import numpy as np 

#random data for category A, B, with B "taller" 
yA, yB = np.random.randn(100), 5.0+np.random.randn(1000) 

xA, xB = np.random.normal(1, 0.1, len(yA)), 
     np.random.normal(3, 0.1, len(yB)) 

plt.scatter(xA, yA) 
plt.scatter(xB, yB) 
plt.show() 

X-scattered data

29

Mở rộng câu trả lời của @ user2467675, dưới đây là cách tôi đã làm nó:

def rand_jitter(arr): 
    stdev = .01*(max(arr)-min(arr)) 
    return arr + np.random.randn(len(arr)) * stdev 

def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, hold=None, **kwargs): 
    return scatter(rand_jitter(x), rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, verts=verts, hold=hold, **kwargs) 

Biến stdev đảm bảo rằng các jitter là đủ để được nhìn thấy trên quy mô khác nhau, nhưng nó giả định rằng các giới hạn của các trục là 0 và giá trị tối đa.

Sau đó, bạn có thể gọi jitter thay vì scatter.

+0

Tôi thực sự thích tính toán tự động của bạn về quy mô của jitter. Hoạt động tốt cho tôi. –

+0

Tính năng này có hoạt động nếu 'arr' chỉ chứa 0 số không (tức là stdev = 0)? – Dataman

5

Một cách để tiếp cận vấn đề là để suy nghĩ của mỗi 'hàng' trong phân tán/dot âm mưu của bạn/beeswarm như một thùng trong một biểu đồ:

data = np.random.randn(100) 

width = 0.8  # the maximum width of each 'row' in the scatter plot 
xpos = 0  # the centre position of the scatter plot in x 

counts, edges = np.histogram(data, bins=20) 

centres = (edges[:-1] + edges[1:])/2. 
yvals = centres.repeat(counts) 

max_offset = width/counts.max() 
offsets = np.hstack((np.arange(cc) - 0.5 * (cc - 1)) for cc in counts) 
xvals = xpos + (offsets * max_offset) 

fig, ax = plt.subplots(1, 1) 
ax.scatter(xvals, yvals, s=30, c='b') 

Điều này rõ ràng liên quan đến binning dữ liệu, vì vậy bạn có thể mất một số độ chính xác.Nếu bạn có dữ liệu rời rạc, bạn có thể thay thế:

counts, edges = np.histogram(data, bins=20) 
centres = (edges[:-1] + edges[1:])/2. 

với:

centres, counts = np.unique(data, return_counts=True) 

Một phương pháp khác mà giữ gìn chính xác tọa độ y, ngay cả đối với dữ liệu liên tục, là sử dụng một kernel density estimate để mở rộng biên độ của jitter ngẫu nhiên trong trục x:

from scipy.stats import gaussian_kde 

kde = gaussian_kde(data) 
density = kde(data)  # estimate the local density at each datapoint 

# generate some random jitter between 0 and 1 
jitter = np.random.rand(*data.shape) - 0.5 

# scale the jitter by the KDE estimate and add it to the centre x-coordinate 
xvals = 1 + (density * jitter * width * 2) 

ax.scatter(xvals, data, s=30, c='g') 
for sp in ['top', 'bottom', 'right']: 
    ax.spines[sp].set_visible(False) 
ax.tick_params(top=False, bottom=False, right=False) 

ax.set_xticks([0, 1]) 
ax.set_xticklabels(['Histogram', 'KDE'], fontsize='x-large') 
fig.tight_layout() 

Lần gặp gỡ thứ hai này hod dựa trên cách hoạt động của violin plots. Nó vẫn không thể đảm bảo rằng không có điểm nào trùng nhau, nhưng tôi thấy rằng trong thực tế nó có xu hướng cho kết quả khá đẹp, miễn là có một số điểm phong nha (> 20), và phân phối có thể được xấp xỉ một cách hợp lý bởi một tổng hợp Gaussians.

enter image description here

3

sanh ở biển cung cấp biểu đồ giống như phân loại dot-lô qua sns.swarmplot() và jittered phân loại dot-lô qua sns.stripplot():

import seaborn as sns 

sns.set(style='ticks', context='talk') 
iris = sns.load_dataset('iris') 

sns.swarmplot('species', 'sepal_length', data=iris) 
sns.despine() 

enter image description here

sns.stripplot('species', 'sepal_length', data=iris, jitter=0.2) 
sns.despine() 

enter image description here

1

swarmplot sanh ở biển của vẻ như phù hợp với apt nhất cho những gì bạn có trong tâm trí, nhưng bạn cũng có thể jitter với regplot sanh ở biển của:

import seaborn as sns 
iris = sns.load_dataset('iris') 

sns.regplot(x='sepal_length', 
      y='sepal_width', 
      data=iris, 
      fit_reg=False, # do not fit a regression line 
      x_jitter=0.1, # could also dynamically set this with range of data 
      y_jitter=0.1, 
      scatter_kws={'alpha': 0.5}) # set transparency to 50% 
Các vấn đề liên quan