2014-08-29 24 views
6

Tôi có ma trận vuông lớn (xấp xỉ 14.000 x 14.000) được biểu thị dưới dạng Numpy ndarray. Tôi muốn trích xuất một số lượng lớn các hàng và cột - các chỉ số mà tôi biết trước, mặc dù thực tế tất cả các hàng và cột không phải là tất cả các số không - để có được một ma trận vuông mới (xấp xỉ 10.000 x 10.000).Cách nhanh nhất để trích xuất các hàng và cột đã cho từ một Numpy ndarray là gì?

Cách nhanh nhất tôi đã tìm thấy để làm điều này là:

> timeit A[np.ix_(indices, indices)] 
1 loops, best of 3: 6.19 s per loop 

Tuy nhiên, đây là chậm hơn nhiều so với thời gian cần thiết để làm phép nhân ma trận:

> timeit np.multiply(A, A) 
1 loops, best of 3: 982 ms per loop 

Điều này có vẻ kỳ lạ, vì cả phép tách hàng và cột và phép nhân ma trận cần phân bổ một mảng mới (sẽ còn lớn hơn cho kết quả phép nhân ma trận so với phép trích xuất), nhưng phép nhân ma trận cũng cần thực hiện tính toán bổ sung.

Do đó, câu hỏi: có cách nào hiệu quả hơn để thực hiện việc trích xuất, cụ thể, ít nhất là nhanh như phép nhân ma trận?

+0

Tôi nghi ngờ nhân ma trận lý do nhanh là vì nó sử dụng tất cả các phần tử của mảng theo cách nghiêm ngặt và được tối ưu hóa để làm như vậy. Nếu bạn phải vượt qua các chỉ số tùy ý (nghĩa là, không phải là một lát hình chữ nhật), bạn sẽ không nhận được tốc độ tối đa. – BrenBarn

+1

vâng, sử dụng numpy 1.9 – seberg

+3

'np.multiply (A, A)' là * phần tử * nhân. Sử dụng 'np.dot (A, A)' để nhân ma trận. –

Trả lời

1

Nếu tôi cố gắng tái tạo sự cố của bạn, tôi không thấy hiệu ứng quyết liệt như vậy. Tôi nhận thấy rằng tùy thuộc vào số lượng chỉ mục bạn chọn, việc lập chỉ mục thậm chí có thể nhanh hơn so với phép nhân.

>>> import numpy as np 
>>> np.__version__ 
Out[1]: '1.9.0' 
>>> N = 14000 
>>> A = np.random.random(size=[N, N]) 

>>> indices = np.sort(np.random.choice(np.arange(N), 0.9*N, replace=False)) 
>>> timeit A[np.ix_(indices, indices)] 
1 loops, best of 3: 1.02 s per loop 
>>> timeit A.take(indices, axis=0).take(indices, axis=1) 
1 loops, best of 3: 1.37 s per loop 
>>> timeit np.multiply(A,A) 
1 loops, best of 3: 748 ms per loop 

>>> indices = np.sort(np.random.choice(np.arange(N), 0.7*N, replace=False)) 
>>> timeit A[np.ix_(indices, indices)] 
1 loops, best of 3: 633 ms per loop 
>>> timeit A.take(indices, axis=0).take(indices, axis=1) 
1 loops, best of 3: 946 ms per loop 
>>> timeit np.multiply(A,A) 
1 loops, best of 3: 728 ms per loop 
Các vấn đề liên quan