2016-05-05 29 views
13

Tôi đang gặp sự cố khi hiểu khái niệm cơ bản với lưu lượng tensorflow. Việc lập chỉ mục hoạt động như thế nào đối với các hoạt động đọc/ghi tensor? Để làm điều này cụ thể, làm thế nào có thể các ví dụ sau đây NumPy được dịch sang tensorflow (sử dụng tensors cho các mảng, chỉ số và các giá trị được gán):cách lập chỉ mục tensorflow hoạt động

x = np.zeros((3, 4)) 
row_indices = np.array([1, 1, 2]) 
col_indices = np.array([0, 2, 3]) 
x[row_indices, col_indices] = 2 
x 

với sản lượng:

array([[ 0., 0., 0., 0.], 
     [ 2., 0., 2., 0.], 
     [ 0., 0., 0., 2.]]) 

.. . và ...

x[row_indices, col_indices] = np.array([5, 4, 3]) 
x 

với sản lượng:

array([[ 0., 0., 0., 0.], 
     [ 5., 0., 4., 0.], 
     [ 0., 0., 0., 3.]]) 

... và cuối cùng ...

y = x[row_indices, col_indices] 
y 

với sản lượng:

array([ 5., 4., 3.]) 
+0

Nhiều việc bạn có thể thực hiện với tính năng gọn gàng không được hỗ trợ trong lưu lượng. Xem câu hỏi này http://stackoverflow.com/questions/33736795/tensorflow-numpy-like-tensor-indexing Có lẽ bạn có thể giải thích thêm những gì bạn đang cố gắng thực hiện bằng cách cập nhật các giá trị ma trận đó và ai đó có thể trả lời cách đạt được kết quả đó tensorflow. – Aaron

+0

Cảm ơn bạn đã liên kết.Tôi muốn thử một số hình ảnh tăng thêm (chưa có sẵn trong thư viện tensorflow) để hy vọng cải thiện tính tổng quát của mạng phân loại hình ảnh được học. Tôi không cần gradiants để chảy qua các ops và tôi có thể làm điều đó một cách dễ dàng trên CPU nhưng đó sẽ trở thành một nút cổ chai hiệu suất rất lớn. Tôi nghĩ tôi có thể làm những gì tôi cần với câu trả lời của Yaroslav. – Keith

Trả lời

9

Có github vấn đề #206 để hỗ trợ này độc đáo, trong khi đó bạn phải nghỉ mát để tiết công việc ở quanh

Ví dụ đầu tiên có thể được thực hiện với tf.select kết hợp hai tensors hình giống nhau bằng cách chọn từng phần tử từ một hoặc khác

tf.reset_default_graph() 
row_indices = tf.constant([1, 1, 2]) 
col_indices = tf.constant([0, 2, 3]) 
x = tf.zeros((3, 4)) 
sess = tf.InteractiveSession() 

# get list of ((row1, col1), (row2, col2), ..) 
coords = tf.transpose(tf.pack([row_indices, col_indices])) 

# get tensor with 1's at positions (row1, col1),... 
binary_mask = tf.sparse_to_dense(coords, x.get_shape(), 1) 

# convert 1/0 to True/False 
binary_mask = tf.cast(binary_mask, tf.bool) 

twos = 2*tf.ones(x.get_shape()) 

# make new x out of old values or 2, depending on mask 
x = tf.select(binary_mask, twos, x) 

print x.eval() 

cho

[[ 0. 0. 0. 0.] 
[ 2. 0. 2. 0.] 
[ 0. 0. 0. 2.]] 

thứ hai có thể được thực hiện với scatter_update, trừ scatter_update chỉ hỗ trợ trên chỉ số tuyến tính và hoạt động trên các biến. Vì vậy, bạn có thể tạo một biến tạm thời và sử dụng định dạng lại như thế này. (Để tránh biến mà bạn có thể sử dụng dynamic_stitch, xem cuối cùng)

# get linear indices 
linear_indices = row_indices*x.get_shape()[1]+col_indices 

# turn 'x' into 1d variable since "scatter_update" supports linear indexing only 
x_flat = tf.Variable(tf.reshape(x, [-1])) 

# no automatic promotion, so make updates float32 to match x 
updates = tf.constant([5, 4, 3], dtype=tf.float32) 

sess.run(tf.initialize_all_variables()) 
sess.run(tf.scatter_update(x_flat, linear_indices, updates)) 

# convert back into original shape 
x = tf.reshape(x_flat, x.get_shape()) 

print x.eval() 

cho

[[ 0. 0. 0. 0.] 
[ 5. 0. 4. 0.] 
[ 0. 0. 0. 3.]] 

Cuối cùng ví dụ thứ ba đã được hỗ trợ với gather_nd, bạn viết

print tf.gather_nd(x, coords).eval() 

Để có được

[ 5. 4. 3.] 

Chỉnh sửa, ngày 06 tháng 5

Bản cập nhật x[cols,rows]=newvals thể được thực hiện mà không sử dụng biến (trong đó chiếm bộ nhớ giữa các cuộc gọi phiên chạy) bằng cách sử dụng select với sparse_to_dense mà sẽ đưa vector của các giá trị thưa thớt, hoặc dựa vào dynamic_stitch

sess = tf.InteractiveSession() 
x = tf.zeros((3, 4)) 
row_indices = tf.constant([1, 1, 2]) 
col_indices = tf.constant([0, 2, 3]) 

# no automatic promotion, so specify float type 
replacement_vals = tf.constant([5, 4, 3], dtype=tf.float32) 

# convert to linear indexing in row-major form 
linear_indices = row_indices*x.get_shape()[1]+col_indices 
x_flat = tf.reshape(x, [-1]) 

# use dynamic stitch, it merges the array by taking value either 
# from array1[index1] or array2[index2], if indices conflict, 
# the later one is used 
unchanged_indices = tf.range(tf.size(x_flat)) 
changed_indices = linear_indices 
x_flat = tf.dynamic_stitch([unchanged_indices, changed_indices], 
          [x_flat, replacement_vals]) 
x = tf.reshape(x_flat, x.get_shape()) 
print x.eval() 
+0

Những ví dụ này rất hữu ích. Cảm ơn bạn! – Keith

+2

đồng nghiệp đề xuất sử dụng 'dynamic_stitch' thay vì Biến, cập nhật công thức nấu ăn –

Các vấn đề liên quan