2016-04-21 61 views
9

Tôi đang cố gắng thu thập các lát của một tensor về kích thước cuối cùng cho một phần kết nối giữa các lớp. Bởi vì hình dạng tensor đầu ra là [batch_size, h, w, depth], tôi muốn chọn lát dựa trên kích thước cuối cùng, chẳng hạn nhưTrong Tensorflow, làm thế nào để sử dụng tf.gather() cho kích thước cuối cùng?

# L is intermediate tensor 
partL = L[:, :, :, [0,2,3,8]] 

Tuy nhiên, tf.gather(L, [0, 2,3,8]) dường như chỉ có tác dụng đối với kích thước đầu tiên (phải không?) Bất cứ ai có thể cho tôi biết làm thế nào để làm nó?

Trả lời

9

Có một lỗi theo dõi để hỗ trợ này use-case ở đây: https://github.com/tensorflow/tensorflow/issues/206

Đối với bây giờ bạn có thể:

  1. transpose ma trận của bạn để chiều để thu thập là lần đầu tiên (transpose là tốn kém)

  2. định hình lại tensor thành 1d (định dạng lại giá rẻ) và biến chỉ mục cột thu thập của bạn thành danh sách các chỉ mục phần tử riêng lẻ tại chỉ mục tuyến tính, sau đó định lại lại

  3. sử dụng gather_nd. Vẫn sẽ cần phải biến các chỉ mục cột của bạn thành danh sách các chỉ mục phần tử riêng lẻ.
+3

Lưu ý rằng tf.gather có một tham số trục như của TensorFlow 1.3. – rryan

0

Thực hiện 2. từ @Yaroslav Bulatov của:

#Your indices 
indices = [0, 2, 3, 8] 

#Remember for final reshaping 
n_indices = tf.shape(indices)[0] 

flattened_L = tf.reshape(L, [-1]) 

#Walk strided over the flattened array 
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1) 
flattened_indices = tf.reshape(tf.reshape(indices, [-1])+offset, [-1]) 
selected_rows = tf.gather(flattened_L, flattened_indices) 

#Final reshape 
partL = tf.reshape(selected_rows, tf.concat(0, [tf.shape(L)[:-1], [n_indices]])) 

tín dụng để How to select rows from a 3-D Tensor in TensorFlow?

5

Với gather_nd bây giờ bạn có thể làm điều này như sau:

cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0) 
result = tf.gather_nd(matrix, cat_idx) 

Ngoài ra, theo báo cáo của người dùng Nova trong một chuỗi được tham chiếu bởi @Yaroslav Bulatov's:

x = tf.constant([[1, 2, 3], 
       [4, 5, 6], 
       [7, 8, 9]]) 
idx = tf.constant([1, 0, 2]) 
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx 
y = tf.gather(tf.reshape(x, [-1]), # flatten input 
       idx_flattened) # use flattened indices 

with tf.Session(''): 
    print y.eval() # [2 4 9] 

Gist đang san bằng tensor và sử dụng địa chỉ 1D có cấu trúc với tf.gather (...).

+1

Tôi không chắc chắn ví dụ đầu tiên của bạn hoạt động. Giả sử 'tf.shape (x) [0]' là 1, thì 'cat_idx' sẽ là' [0, 0, 2, 3, 8] ', không phải là thứ bạn muốn sử dụng với' tf.gather_nd' . Trong thực tế, trong trường hợp này nó sẽ ném một lỗi vì chiều dài của chiều sâu bên trong của 'chỉ số' (đối số thứ 2 đến' tập hợp_nd') không thể lớn hơn thứ hạng 'params' (đối số thứ nhất cho' gather_nd'). – kaufmanu

+0

Tôi đã đăng phiên bản đã sửa (sử dụng 'tf.stack') bên dưới. –

0

Tensor không có hình dạng thuộc tính, nhưng phương thức get_shape(). Dưới đây là Runnable bởi Python 2,7

import tensorflow as tf 
import numpy as np 
x = tf.constant([[1, 2, 3], 
       [4, 5, 6], 
       [7, 8, 9]]) 
idx = tf.constant([1, 0, 2]) 
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx 
y = tf.gather(tf.reshape(x, [-1]), # flatten input 
       idx_flattened) # use flattened indices 

with tf.Session(''): 
    print y.eval() # [2 4 9] 
3

Tuy nhiên, một giải pháp sử dụng tf.unstack (...), tf.gather (...) và tf.stack (..)

Code:

import tensorflow as tf 
import numpy as np 

shape = [2, 2, 2, 10] 
L = np.arange(np.prod(shape)) 
L = np.reshape(L, shape) 

indices = [0, 2, 3, 8] 
axis = -1 # last dimension 

def gather_axis(params, indices, axis=0): 
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis) 

print(L) 
with tf.Session() as sess: 
    partL = sess.run(gather_axis(L, indices, axis)) 
    print(partL) 

Kết quả:

L = 
[[[[ 0 1 2 3 4 5 6 7 8 9] 
    [10 11 12 13 14 15 16 17 18 19]] 

    [[20 21 22 23 24 25 26 27 28 29] 
    [30 31 32 33 34 35 36 37 38 39]]] 


[[[40 41 42 43 44 45 46 47 48 49] 
    [50 51 52 53 54 55 56 57 58 59]] 

    [[60 61 62 63 64 65 66 67 68 69] 
    [70 71 72 73 74 75 76 77 78 79]]]] 

partL = 
[[[[ 0 2 3 8] 
    [10 12 13 18]] 

    [[20 22 23 28] 
    [30 32 33 38]]] 


[[[40 42 43 48] 
    [50 52 53 58]] 

    [[60 62 63 68] 
    [70 72 73 78]]]] 
2

Một phiên bản đúng của @ câu trả lời Andrei của sẽ đọc

cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1) 
result = tf.gather_nd(matrix, cat_idx) 
1

Bạn có thể thử cách này, ví dụ (trong hầu hết trường hợp trong NLP tại ít nhất),

Các tham số là hình dạng [batch_size, depth] và các chỉ số là [i, j, k, n, m] của chiều dài là batch_size. Sau đó, gather_nd có thể hữu ích.

parameters = tf.constant([ 
          [11, 12, 13], 
          [21, 22, 23], 
          [31, 32, 33], 
          [41, 42, 43]])  
targets = tf.constant([2, 1, 0, 1])  
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])  
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number 
items = tf.gather_nd(parameters, indices) 
# which is what we want: [13, 22, 31, 42] 

Đoạn mã này đầu tiên tìm kích thước nắm tay qua lô và sau đó tìm mục dọc theo thứ nguyên đó theo số mục tiêu.

Các vấn đề liên quan