2017-05-17 26 views
7

Tôi có đoạn code sauTheano chuyển đổi hàng khôn ngoan một cách hiệu quả

output = T.switch(cond, a, b) 

nơi cond là một tensor (N,1) bool, trong khi ab(N, M) tensors số với M là khá lớn. Điều kiện hoạt động theo cách khôn ngoan.

Tôi có thể dễ dàng thực hiện chuyển đổi bằng cách chạy T.repeat() trên cond, nhưng điều này khá chậm. Có cách nào tôi có thể thực hiện hiệu quả các bool trong cond quyết định xem có nên trả lại a hoặc b không?

Trả lời

3

Có cách nào tôi có thể thực hiện hiệu quả các bool trong cond quyết định xem a hoặc b có được trả lại không?

Vâng, bạn có thể làm

cond * a + (1-cond) * b 

cond sẽ được phát sóng vào (N, M) hình dạng.

Điều này phải gần với giới hạn lý thuyết, là băng thông bộ nhớ: thao tác này cần đọc về các yếu tố N*M và viết N*M.

Thay vào đó, chúng tôi đọc 2*N*M nhưng xóa logic có điều kiện.

(Tôi không có Theano trước mặt mình, vì vậy tôi không chắc liệu nó có nhanh hơn T.switch hay không, nhưng nó phải là tốt như nó được. Ngoài ra, tôi sẽ thử đúc cond vào cùng một dtype như ab)


Nếu bạn muốn cập nhật a tại chỗ, bạn có thể làm điều đó bằng T.set_subtensor:

a = np.random.uniform(size=(N, M)).astype(np.float32) 
b = np.random.uniform(size=(N, M)).astype(np.float32) 

a = theano.shared(a) 
b = theano.shared(b) 

c = T.vector() # mostly 0, presumably (1-cond) 

nz = T.nonzero(c) 

s = T.set_subtensor(a[nz], b[nz]) 
fn = theano.function([c], [], updates=[(a, s)]) 

... 

fn(1-cond) 

nó có thể hoặc không có thể nhanh hơn t han cách tiếp cận đầu tiên, tùy thuộc vào N, M và các yếu tố khác.

+0

Cảm ơn câu trả lời, tôi sẽ dùng thử! Những suy nghĩ thú vị về giới hạn lý thuyết. Tôi đoán tôi có thể tránh những lần đọc lớn và viết bằng cách khai thác thường xuyên nhất 'a' sẽ là giá trị đúng để trả lại và nó là tốt cho phương pháp để sửa đổi' a' trực tiếp. Giả sử chỉ có 5% thời gian 'b' nên được trả lại cho một hàng nhất định, không thể có được hiệu năng tốt hơn bằng cách sửa đổi' a' trực tiếp chỉ trên các hàng cần sửa đổi? – pir

+0

@pir Bạn có đang tối ưu hóa cho CPU hoặc GPU không? N, N và dtype điển hình là gì? – MaxB

+0

@pir cũng là phần này của NN hay cái gì đó cần gradient? – MaxB

Các vấn đề liên quan