2017-06-18 13 views
16

này nó mã:ValueError: Đang cố gắng để chia sẻ biến rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel

X = tf.placeholder(tf.float32, [batch_size, seq_len_1, 1], name='X') 
labels = tf.placeholder(tf.float32, [None, alpha_size], name='labels') 

rnn_cell = tf.contrib.rnn.BasicLSTMCell(512) 
m_rnn_cell = tf.contrib.rnn.MultiRNNCell([rnn_cell] * 3, state_is_tuple=True) 
pre_prediction, state = tf.nn.dynamic_rnn(m_rnn_cell, X, dtype=tf.float32) 

này là đầy đủ lỗi:

ValueError: Trying to share variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (1024, 2048) and found shape (513, 2048).

Tôi đang sử dụng một phiên bản GPU của tensorflow.

Trả lời

25

Tôi gặp sự cố tương tự khi tôi nâng cấp lên v1.2 (tensorflow-gpu). Thay vì sử dụng [rnn_cell]*3, tôi đã tạo 3 rnn_cells (stacked_rnn) bởi một vòng lặp (để chúng không chia sẻ các biến) và cho ăn MultiRNNCell với stacked_rnn và sự cố sẽ biến mất. Tôi không chắc đó là cách đúng đắn để làm điều đó.

stacked_rnn = [] 
for iiLyr in range(3): 
    stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=512, state_is_tuple=True)) 
MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True) 
3

Tôi đoán đó là do các ô RNN của bạn trên mỗi 3 lớp của bạn có cùng hình dạng đầu vào và đầu ra.

Trên lớp 1, thứ nguyên đầu vào là 513 = 1 (kích thước x của bạn) + 512 (kích thước của lớp ẩn) cho mỗi dấu thời gian trên mỗi lô.

Trên lớp 2 và 3, kích thước đầu vào là 1024 = 512 (đầu ra từ lớp trước) + 512 (đầu ra từ dấu thời gian trước đó).

Cách bạn xếp chồng lên MultiRNNCell có thể ngụ ý rằng 3 ô chia sẻ cùng một hình dạng đầu vào và đầu ra.

tôi chồng lên MultiRNNCell bằng cách tuyên bố hai loại riêng biệt của các tế bào để ngăn chặn chúng từ việc chia sẻ hình dạng đầu vào

rnn_cell1 = tf.contrib.rnn.BasicLSTMCell(512) 
run_cell2 = tf.contrib.rnn.BasicLSTMCell(512) 
stack_rnn = [rnn_cell1] 
for i in range(1, 3): 
    stack_rnn.append(rnn_cell2) 
m_rnn_cell = tf.contrib.rnn.MultiRNNCell(stack_rnn, state_is_tuple = True) 

Sau đó, tôi có thể đào tạo dữ liệu của tôi mà không lỗi này. Tôi không chắc liệu phỏng đoán của tôi có chính xác không, nhưng nó có hiệu quả với tôi hay không. Hi vọng nó sẽ giúp ích cho bạn.

12

Một TensorFlow hướng dẫn chính thức khuyến cáo cách này nhiều định nghĩa mạng LSTM:

def lstm_cell(): 
    return tf.contrib.rnn.BasicLSTMCell(lstm_size) 
stacked_lstm = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell() for _ in range(number_of_layers)]) 

Bạn có thể tìm thấy nó ở đây: https://www.tensorflow.org/tutorials/recurrent

Trên thực tế nó nó gần như là cách tiếp cận tương tự mà Wasi Ahmad và Maosi Chen gợi ý ở trên nhưng có thể ở dạng thanh lịch hơn một chút.

Các vấn đề liên quan