2015-11-18 16 views
5

Tôi đã viết một số mã trong Python hoạt động tốt nhưng rất chậm; Tôi nghĩ là do các vòng lặp. Tôi hy vọng người ta có thể tăng tốc các hoạt động sau bằng cách sử dụng các lệnh numpy. Hãy để tôi xác định mục tiêu.vectơ numpy thay vì cho vòng

Giả sử tôi có một mảng có khối lượng 2D all_CMs của thứ nguyên row x col. Ví dụ, xem xét một mảng 6 x 11 (xem hình bên dưới).

  1. tôi muốn để tính toán giá trị trung bình cho tất cả các hàng, ví dụ: tổng ⱼ aᵢⱼ kết quả trong một mảng. Điều này, tất nhiên có thể dễ dàng thực hiện. (Tôi gọi là giá trị này CM_tilde)

  2. Bây giờ, cho mỗi hàng tôi muốn để tính toán giá trị trung bình của một số giá trị được chọn, cụ thể là tất cả các giá trị dưới một ngưỡng nhất định bằng cách tính toán số tiền của họ và chia cho số lượng tất cả các cột (N). Nếu giá trị cao hơn ngưỡng được xác định này, giá trị CM_tilde (giá trị trung bình của toàn bộ hàng) được thêm vào. Giá trị này được gọi là CM

  3. Sau đó, giá trị CM được trừ từ mỗi phần tử trong hàng

Thêm vào đó tôi muốn có một mảng NumPy hoặc danh sách mà tất cả những giá trị CM được liệt kê .

Hình:

figure

Các mã sau đây được làm việc nhưng rất chậm (đặc biệt nếu các mảng việc lớn)

CM_tilde = np.mean(data, axis=1) 
N = data.shape[1] 
data_cm = np.zeros((data.shape[0], data.shape[1], data.shape[2])) 
all_CMs = np.zeros((data.shape[0], data.shape[2])) 
for frame in range(data.shape[2]): 
    for row in range(data.shape[0]): 
     CM=0 
     for col in range(data.shape[1]): 
      if data[row, col, frame] < (CM_tilde[row, frame]+threshold): 
       CM += data[row, col, frame] 
      else: 
       CM += CM_tilde[row, frame] 
     CM = CM/N 
     all_CMs[row, frame] = CM 
     # calculate CM corrected value 
     for col in range(data.shape[1]): 
      data_cm[row, col, frame] = data[row, col, frame] - CM 
    print "frame: ", frame 
return data_cm, all_CMs 

Bất kỳ ý tưởng?

+0

Trong bước 2, về cơ bản bạn có thay thế bất kỳ giá trị nào vượt quá ngưỡng trên CM_tilde không và sau đó * tính giá trị trung bình trên toàn bộ hàng, bao gồm cả giá trị được thay thế? – Evert

+0

Bắt đầu bằng cách sử dụng 'np.where' để thay thế vòng lặp bên trong của bạn. Sau đó, bằng cách sử dụng phát sóng, bạn có thể xóa 2 vòng lặp bên ngoài. Xem tài liệu cho [where] (http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.where.html) – mtadd

Trả lời

12

Đó là khá dễ dàng để vectorize những gì bạn đang thực hiện:

import numpy as np 

#generate dummy data 
nrows=6 
ncols=11 
nframes=3 
threshold=0.3 
data=np.random.rand(nrows,ncols,nframes) 

CM_tilde = np.mean(data, axis=1) 
N = data.shape[1] 

all_CMs2 = np.mean(np.where(data < (CM_tilde[:,None,:]+threshold),data,CM_tilde[:,None,:]),axis=1) 
data_cm2 = data - all_CMs2[:,None,:] 

So sánh điều này với bản gốc của bạn:

In [684]: (data_cm==data_cm2).all() 
Out[684]: True 

In [685]: (all_CMs==all_CMs2).all() 
Out[685]: True 

Logic là chúng ta làm việc với mảng kích thước [nrows,ncols,nframes] cùng một lúc. Bí quyết chính là sử dụng phát sóng python, bằng cách chuyển CM_tilde kích thước [nrows,nframes] thành CM_tilde[:,None,:] kích thước [nrows,1,nframes]. Sau đó Python sẽ sử dụng các giá trị giống nhau cho mỗi cột, vì đó là một kích thước đơn lẻ của sửa đổi CM_tilde này.

Bằng cách sử dụng np.where chúng tôi chọn (dựa trên threshold) cho dù chúng tôi muốn nhận được giá trị tương ứng của data hoặc giá trị phát sóng là CM_tilde. Sử dụng mới np.mean cho phép chúng tôi tính all_CMs2.

Trong bước cuối cùng, chúng tôi đã sử dụng phát sóng bằng cách trực tiếp trừ số all_CMs2 mới này từ các phần tử tương ứng của data.

Nó có thể giúp trong việc vector hóa mã theo cách này bằng cách xem xét các chỉ mục tiềm ẩn của các biến tạm thời của bạn. Điều tôi ngụ ý là biến tạm thời của bạn CM tồn tại bên trong vòng lặp trên [nrows,nframes] và giá trị của nó được đặt lại với mỗi lần lặp lại. Điều này có nghĩa là CM có hiệu lực với số lượng CM[row,frame] (sau đó được gán rõ ràng cho mảng 2d all_CMs) và từ đây dễ dàng thấy rằng bạn có thể xây dựng nó bằng cách cộng một số lượng CMtmp[row,col,frames] thích hợp dọc theo thứ nguyên cột của nó. Nếu nó giúp, bạn có thể đặt tên cho phần np.where(...)CMtmp cho mục đích này và sau đó tính np.mean(CMtmp,axis=1) từ đó. Cùng một kết quả, rõ ràng, nhưng có lẽ minh bạch hơn.

+0

Cảm ơn bạn rất nhiều; đây là nhanh hơn nhiều so với các vòng – pallago

+1

10001 là một giá trị tốt đẹp cho đại diện, Nó sẽ là một sự xấu hổ nếu ai đó downvotes này. –

+0

@BhargavRao \ o/cảm ơn bạn, thưa bạn! :) Hoặc, cảm ơn bạn đã không downvoting: D –

1

Đây là vectơ hóa chức năng của bạn. Tôi đã làm việc từ trong ra ngoài và nhận xét các phiên bản cũ hơn khi tôi đi cùng. Vì vậy, vòng lặp đầu tiên mà tôi đã vector hóa có ### nhãn nhận xét.

Nó không sạch sẽ và hợp lý như câu trả lời @Andras's, nhưng hy vọng đó là hướng dẫn, đưa ra ý tưởng về cách bạn có thể giải quyết vấn đề này theo từng bước.

def foo2(data, threshold): 
    CM_tilde = np.mean(data, axis=1) 
    N = data.shape[1] 
    #data_cm = np.zeros((data.shape[0], data.shape[1], data.shape[2])) 
    ##all_CMs = np.zeros((data.shape[0], data.shape[2])) 
    bmask = data < (CM_tilde[:,None,:] + threshold) 
    CM = np.zeros_like(data) 
    CM[:] = CM_tilde[:,None,:] 
    CM[bmask] = data[bmask] 
    CM = CM.sum(axis=1) 
    CM = CM/N 
    all_CMs = CM.copy() 
    """ 
    for frame in range(data.shape[2]): 
     for row in range(data.shape[0]): 
      ###print(frame, row) 
      ###mask = data[row, :, frame] < (CM_tilde[row, frame]+threshold) 
      ###print(mask) 
      ##mask = bmask[row,:,frame] 
      ##CM = data[row, mask, frame].sum() 
      ##CM += (CM_tilde[row, frame]*(~mask)).sum() 

      ##CM = CM/N 
      ##all_CMs[row, frame] = CM 
      ## calculate CM corrected value 
      #for col in range(data.shape[1]): 
      # data_cm[row, col, frame] = data[row, col, frame] - CM[row,frame] 
     print "frame: ", frame 
    """ 
    data_cm = data - CM[:,None,:] 
    return data_cm, all_CMs 

Kết quả phù hợp cho trường hợp thử nghiệm nhỏ này, hơn bất kỳ thứ gì đã giúp tôi có được kích thước phù hợp.

threshold = .1 
data = np.arange(4*3*2,dtype=float).reshape(4,3,2) 
Các vấn đề liên quan