2016-08-22 14 views
5

Tôi muốn đếm số ma trận bằng nhau mà tôi gặp phải sau khi tách một ma trận lớn.numpy - đếm các mảng bằng nhau

mat1 = np.zeros((4, 8)) 

split4x4 = np.split(mat1, 4) 

Bây giờ tôi muốn biết có bao nhiêu ma trận bằng nhau trong split4x4, nhưng collections.Counter(split4x4) sẽ phát sinh lỗi. Có một cách tích hợp trong numpy để làm điều này?

+0

tôi là người nghiệp dư vì vậy đây nghe có vẻ ngớ ngẩn, nhưng np.split() sẽ theo mặc định chia mảng trong mảnh bằng mà bạn chỉ định (cho ví dụ: 4 trong ví dụ trên) và nếu nó không thể nó ném một lỗi. Vì vậy, tại sao bạn cần phải tìm hiểu thông tin đó, đó sẽ không chỉ là 4? –

Trả lời

1

này có thể được thực hiện một cách đầy đủ vectorized sử dụng numpy_indexed gói (từ chối trách nhiệm: Tôi là tác giả của nó):

import numpy_indexed as npi 
unique_rows, row_counts = npi.count(mat1) 

này cần được nhanh hơn đáng kể so với sử dụng collections.Counter.

1

Có lẽ cách đơn giản nhất là sử dụng np.unique và san bằng các mảng tách để so sánh chúng như tuple:

import numpy as np 
# Generate some sample data: 
a = np.random.uniform(size=(8,3)) 
# With repetition: 
a = np.r_[a,a] 
# Split a in 4 arrays 
s = np.asarray(np.split(a, 4)) 
s = [tuple(e.flatten()) for e in s] 
np.unique(s, return_counts=True) 

Ghi chú: return_counts đối số của np.unique mới trong phiên bản 1.9.0.

Một giải pháp NumPy tinh khiết khác lấy cảm hứng từ that post

# Generate some sample data: 
In: a = np.random.uniform(size=(8,3)) 
# With some repetition 
In: a = r_[a,a] 
In: a.shape 
Out: (16,3) 
# Split a in 4 arrays 
In: s = np.asarray(np.split(a, 4)) 
In: print s 
Out: [[[ 0.78284847 0.28883662 0.53369866] 
     [ 0.48249722 0.02922249 0.0355066 ] 
     [ 0.05346797 0.35640319 0.91879326] 
     [ 0.1645498 0.15131476 0.1717498 ]] 

     [[ 0.98696629 0.8102581 0.84696276] 
     [ 0.12612661 0.45144896 0.34802173] 
     [ 0.33667377 0.79371788 0.81511075] 
     [ 0.81892789 0.41917167 0.81450135]] 

     [[ 0.78284847 0.28883662 0.53369866] 
     [ 0.48249722 0.02922249 0.0355066 ] 
     [ 0.05346797 0.35640319 0.91879326] 
     [ 0.1645498 0.15131476 0.1717498 ]] 

     [[ 0.98696629 0.8102581 0.84696276] 
     [ 0.12612661 0.45144896 0.34802173] 
     [ 0.33667377 0.79371788 0.81511075] 
     [ 0.81892789 0.41917167 0.81450135]]] 
In: s.shape 
Out: (4, 4, 3) 
# Flatten the array: 
In: s = asarray([e.flatten() for e in s]) 
In: s.shape 
Out: (4, 12) 
# Sort the rows using lexsort: 
In: idx = np.lexsort(s.T) 
In: s_sorted = s[idx] 
# Create a mask to get unique rows 
In: row_mask = np.append([True],np.any(np.diff(s_sorted,axis=0),1)) 
# Get unique rows: 
In: out = s_sorted[row_mask] 
# and count: 
In: for e in out: 
     count = (e == s).all(axis=1).sum() 
     print e.reshape(4,3), count 
Out:[[ 0.78284847 0.28883662 0.53369866] 
    [ 0.48249722 0.02922249 0.0355066 ] 
    [ 0.05346797 0.35640319 0.91879326] 
    [ 0.1645498 0.15131476 0.1717498 ]] 2 
    [[ 0.98696629 0.8102581 0.84696276] 
    [ 0.12612661 0.45144896 0.34802173] 
    [ 0.33667377 0.79371788 0.81511075] 
    [ 0.81892789 0.41917167 0.81450135]] 2 
+0

bạn đang sử dụng python 3 trong ví dụ đầu tiên? Nguyên nhân tôi nhận được từ 'a = r_ [a, a]' 'TênError: name 'r_' không được xác định' – andandandand

+0

@andandandand Không. Đó là lỗi của tôi, tôi đã quên 'np' ngay trước' r_', một cách đơn giản để xây dựng các mảng nhanh chóng (xem: http://docs.scipy.org/doc/numpy/reference/generated/numpy.r_ .html). Tôi vừa sửa lại câu trả lời của mình. – bougui

Các vấn đề liên quan