2016-09-25 11 views
5

Nếu A là một biến TensorFlow như vậyTensorFlow nhận các yếu tố của mỗi hàng cho các cột cụ thể

A = tf.Variable([[1, 2], [3, 4]]) 

index là một biến

index = tf.Variable([0, 1]) 

Tôi muốn sử dụng chỉ số này để chọn cột trong mỗi hàng. Trong trường hợp này, mục 0 từ hàng đầu tiên và mục 1 từ hàng thứ hai.

Nếu A là một mảng NumPy sau đó để có được các cột của hàng tương ứng nêu tại chỉ số chúng ta có thể làm

x = A[np.arange(A.shape[0]), index] 

và kết quả sẽ là

[1, 4] 

hoạt động tương đương TensorFlow là gì/hoạt động cho điều này? Tôi biết TensorFlow không hỗ trợ nhiều hoạt động lập chỉ mục. Điều gì sẽ là công việc xung quanh nếu nó không thể được thực hiện trực tiếp?

Trả lời

2

Sau khi trải nghiệm một thời gian dài. Tôi đã tìm thấy hai hàm có thể hữu ích.

Một là tf.gather_nd() mà có thể hữu ích nếu bạn có thể tạo ra một tensor dạng [[0, 0], [1, 1]] và do đó bạn có thể làm

index = tf.constant([[0, 0], [1, 1]])

tf.gather_nd(A, index)

Nếu bạn không thể tạo ra một vector của biểu mẫu [[0, 0], [1, 1]] (Tôi không thể sản xuất điều này vì số lượng hàng trong trường hợp của tôi phụ thuộc vào trình giữ chỗ) vì lý do nào đó thì công việc xung quanh tôi thấy là sử dụng tf.py_func(). Dưới đây là một mã số ví dụ về cách này có thể được thực hiện

import tensorflow as tf 
import numpy as np 

def index_along_every_row(array, index): 
    N, _ = array.shape 
    return array[np.arange(N), index] 

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32) 
index = tf.Variable([0, 1], dtype=tf.int32) 
a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0] 
session = tf.InteractiveSession() 

a.initializer.run() 
index.initializer.run() 
a_slice = a_slice_op.eval() 

a_slice sẽ là một mảng NumPy [1, 4]

0

Chúng ta có thể làm như vậy bằng sự kết hợp này của map_fngather_nd.

def get_element(a, indices): 
    """ 
    Outputs (ith element of indices) from (ith row of a) 
    """ 
    return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]), 
            (a, indices), 
            dtype = tf.float32) 

Đây là ví dụ về cách sử dụng.

A = tf.constant(np.array([[1,2,3], 
          [4,5,6], 
          [7,8,9]], dtype = np.float32)) 

idx = tf.constant(np.array([[2],[1],[0]])) 
elems = get_element(A, idx) 

with tf.Session() as sess: 
    e = sess.run(elems) 

print(e) 

Tôi không biết điều này có chậm hơn nhiều so với các câu trả lời khác hay không.

Có lợi thế là bạn không cần chỉ định số hàng của A trước, miễn là aindices có cùng số hàng khi chạy.

Note đầu ra của trên sẽ được cấp bậc 1. Nếu bạn muốn nó để có thứ hạng 2, thay thế gather_nd bởi gather

2

Bạn có thể sử dụng one hot phương pháp để tạo ra một mảng one_hot và sử dụng nó như một boolean mặt nạ để chọn các chỉ mục bạn muốn.

A = tf.Variable([[1, 2], [3, 4]]) 
index = tf.Variable([0, 1]) 

one_hot_mask = tf.one_hot(index, A.shape[1], on_value = True, off_value = False, dtype = tf.bool) 
output = tf.boolean_mask(A, one_hot_mask) 
+0

Trong khi điều này có thể giúp với câu hỏi, mỗi khi kết nối qua đời, không có lời giải thích về những gì đang xảy ra. Bạn nên hình thành câu trả lời của riêng bạn, giải thích những gì đang xảy ra, và không dựa vào các liên kết không được phép trên SO như là nguồn thông tin duy nhất. https://stackoverflow.com/help/how-to-answer – Rob

1

Bạn có thể mở rộng chỉ số cột của bạn với chỉ số hàng và sau đó sử dụng gather_nd:

import tensorflow as tf 

A = tf.constant([[1, 2], [3, 4]]) 
indices = tf.constant([1, 0]) 

# prepare row indices 
row_indices = tf.range(tf.shape(indices)[0]) 

# zip row indices with column indices 
full_indices = tf.stack([row_indices, indices], axis=1) 

# retrieve values by indices 
S = tf.gather_nd(A, full_indices) 

session = tf.InteractiveSession() 
session.run(S) 
Các vấn đề liên quan