2011-12-01 52 views
5

Tôi xử lý ma trận khá lớn trong Python/Scipy. Tôi cần phải trích xuất các hàng từ ma trận lớn (được nạp vào coo_matrix) và sử dụng chúng như các phần tử đường chéo. Hiện nay tôi làm điều đó trong thời trang sau:Tạo ma trận đường chéo thưa từ hàng của ma trận thưa thớt

import numpy as np 
from scipy import sparse 

def computation(A): 
    for i in range(A.shape[0]): 
    diag_elems = np.array(A[i,:].todense()) 
    ith_diag = sparse.spdiags(diag_elems,0,A.shape[1],A.shape[1], format = "csc") 
    #... 

#create some random matrix 
A = (sparse.rand(1000,100000,0.02,format="csc")*5).astype(np.ubyte) 
#get timings 
profile.run('computation(A)') 

Những gì tôi nhìn thấy từ profile đầu ra là hầu hết thời gian được tiêu thụ bởi get_csr_submatrix chức năng khi giải nén diag_elems. Điều đó làm cho tôi nghĩ rằng tôi sử dụng hoặc không hiệu quả đại diện thưa thớt của dữ liệu ban đầu hoặc sai cách chiết xuất hàng từ một ma trận thưa thớt. Bạn có thể đề xuất một cách tốt hơn để trích xuất một hàng từ một ma trận thưa thớt và đại diện cho nó trong một hình thức đường chéo?

EDIT

Các biến thể sau đây loại bỏ nút cổ chai từ quá trình chiết hàng (thông báo rằng đơn giản thay đổi 'csc'-csr là không đủ, A[i,:] phải được thay thế bằng A.getrow(i) cũng). Tuy nhiên, câu hỏi chính là làm thế nào để bỏ qua hiện thực hóa (.todense()) và tạo ma trận đường chéo từ biểu diễn thưa thớt của hàng.

import numpy as np 
from scipy import sparse 

def computation(A): 
    for i in range(A.shape[0]): 
    diag_elems = np.array(A.getrow(i).todense()) 
    ith_diag = sparse.spdiags(diag_elems,0,A.shape[1],A.shape[1], format = "csc") 
    #... 

#create some random matrix 
A = (sparse.rand(1000,100000,0.02,format="csr")*5).astype(np.ubyte) 
#get timings 
profile.run('computation(A)') 

Nếu tôi tạo ma trận đường chéo từ 1 hàng ma trận CSR trực tiếp, như sau:

diag_elems = A.getrow(i) 
ith_diag = sparse.spdiags(diag_elems,0,A.shape[1],A.shape[1]) 

sau đó tôi không thể xác định format="csc" tranh luận, cũng không phải chuyển đổi ith_diags sang định dạng CSC:

Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
    File "/usr/local/lib/python2.6/profile.py", line 70, in run 
    prof = prof.run(statement) 
    File "/usr/local/lib/python2.6/profile.py", line 456, in run 
    return self.runctx(cmd, dict, dict) 
    File "/usr/local/lib/python2.6/profile.py", line 462, in runctx 
    exec cmd in globals, locals 
    File "<string>", line 1, in <module> 
    File "<stdin>", line 4, in computation 
    File "/usr/local/lib/python2.6/site-packages/scipy/sparse/construct.py", line 56, in spdiags 
    return dia_matrix((data, diags), shape=(m,n)).asformat(format) 
    File "/usr/local/lib/python2.6/site-packages/scipy/sparse/base.py", line 211, in asformat 
    return getattr(self,'to' + format)() 
    File "/usr/local/lib/python2.6/site-packages/scipy/sparse/dia.py", line 173, in tocsc 
    return self.tocoo().tocsc() 
    File "/usr/local/lib/python2.6/site-packages/scipy/sparse/coo.py", line 263, in tocsc 
    data = np.empty(self.nnz, dtype=upcast(self.dtype)) 
    File "/usr/local/lib/python2.6/site-packages/scipy/sparse/sputils.py", line 47, in upcast 
    raise TypeError,'no supported conversion for types: %s' % args 
TypeError: no supported conversion for types: object` 
+1

bạn có thử 'format =" csr "' thay thế không? – cyborg

+0

Với 'csr' cho dữ liệu ban đầu và 'A [i,:]' được thay thế bằng 'A.getrow (i)', tôi đã đạt được tốc độ đáng kể. Nhưng những gì tôi đang tìm kiếm là bỏ qua việc vật chất hóa hàng tạo ra ma trận đường chéo. Bất kỳ ý tưởng? – savenkov

Trả lời

3

Đây là những gì tôi đã đưa ra:

def computation(A): 
    for i in range(A.shape[0]): 
     idx_begin = A.indptr[i] 
     idx_end = A.indptr[i+1] 
     row_nnz = idx_end - idx_begin 
     diag_elems = A.data[idx_begin:idx_end] 
     diag_indices = A.indices[idx_begin:idx_end] 
     ith_diag = sparse.csc_matrix((diag_elems, (diag_indices, diag_indices)),shape=(A.shape[1], A.shape[1])) 
     ith_diag.eliminate_zeros() 

Trình thu thập thông tin bằng Python cho biết 1.464 giây so với 5.574 giây trước đó. Nó tận dụng các mảng dày đặc bên dưới (indptr, index, data) để xác định các ma trận thưa thớt. Đây là khóa học tai nạn của tôi: A.indptr [i]: A.indptr [i + 1] xác định phần tử nào trong mảng dày đặc tương ứng với các giá trị khác 0 trong hàng i. A.data là một mảng dày đặc 1d không khác giá trị của A và A.indptr là cột nơi các giá trị đó đi.

Tôi sẽ làm một số thử nghiệm khác để làm cho rất chắc chắn điều này không giống như trước đây. Tôi chỉ kiểm tra một vài trường hợp.

+0

Kevin, thật tuyệt! – savenkov

+0

BTW, row_nnz không được sử dụng – savenkov

Các vấn đề liên quan