2017-01-12 25 views
5

Tôi đang cố gắng sử dụng tf.while_loop() để xử lý các đầu vào có độ dài biến đổi. Tuy nhiên, tôi chỉ có thể sử dụng nó với độ dài cố định. Mã không còn hoạt động sau khi tôi thay đổi hình dạng = (4) thành hình = (Không). tf.dynamic_rnn dường như xử lý các đầu vào có độ dài biến đổi. Tôi không chắc chắn làm thế nào tf.dynamic_rnn đạt được điều này với tf.while_loop().Làm thế nào để sử dụng tf.while_loop() cho các đầu vào có độ dài biến đổi trong lưu lượng tensorflow?

import tensorflow as tf 
import numpy as np 
from tensorflow.python.ops import tensor_array_ops 
from tensorflow.python.ops import array_ops 

with tf.Graph().as_default(), tf.Session() as sess: 
    initial_m = tf.Variable(0.0, name='m') 

    inputs = tf.placeholder(dtype='float32', shape=(4)) 
    #The code no longer works after I change shape=(4) to shape=(None) 
    #inputs = tf.placeholder(dtype='float32', shape=(None)) 

    time_steps = tf.shape(inputs)[0] 

    initial_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps) 
    initial_t = tf.constant(0, dtype='int32') 

    def should_continue(t, *args): 
    return t < time_steps 

    def iteration(t, m, outputs_): 
    cur = tf.gather(inputs, t) 
    m = m * 0.5 + cur * 0.5 
    outputs_ = outputs_.write(t, m) 
    return t + 1, m, outputs_ 

    t, m, outputs = tf.while_loop(
    should_continue, iteration, 
    [initial_t, initial_m, initial_outputs]) 

    outputs = outputs.pack() 
    init = tf.global_variables_initializer() 
    sess.run([init]) 
    print sess.run([outputs], feed_dict={inputs: np.asarray([1,1,1,1])}) 

đầu ra (trước khi thay đổi):

[array([ 0.5 , 0.75 , 0.875 , 0.9375], dtype=float32)] 

đầu ra (sau khi thay đổi):

Traceback (most recent call last): 
    File "simple.py", line 26, in <module> 
    [initial_t, initial_m, initial_outputs]) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2636, in while_loop 
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2469, in BuildLoop 
    pred, body, original_loop_vars, loop_vars, shape_invariants) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2450, in _BuildLoop 
    _EnforceShapeInvariant(m_var, n_var) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 586, in _EnforceShapeInvariant 
    % (merge_var.name, m_shape, n_shape)) 
ValueError: The shape for while/Merge_1:0 is not an invariant for the loop. It enters the loop with shape(), but has shape <unknown> after one iteration. Provide shape invariants using either the `shape_invariants` argument of tf.while_loop or set_shape() on the loop variables. 

Trả lời

6

Nó hoạt động nếu bạn loại bỏ hình từ tất cả các biến:

import tensorflow as tf 
import numpy as np 

config = tf.ConfigProto(graph_options=tf.GraphOptions(
    optimizer_options=tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0))) 
tf.reset_default_graph() 
sess = tf.Session("", config=config) 
#initial_m = tf.Variable(0.0, name='m') 

#The code no longer works after I change shape=(4) to shape=(None) 
inputs = tf.placeholder(dtype='float32', shape=(None)) 
time_steps = tf.shape(inputs)[0] 
initial_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps) 
initial_t = tf.placeholder(dtype='int32') 
initial_m = tf.placeholder(dtype=tf.float32) 

def should_continue(t, *args): 
    return t < time_steps 

def iteration(t, m, outputs_): 
    cur = tf.gather(inputs, t) 
    m = m * 0.5 + cur * 0.5 
    outputs_ = outputs_.write(t, m) 
    return t + 1, m, outputs_ 

t, m, outputs = tf.while_loop(should_continue, iteration, 
           [initial_t, initial_m, initial_outputs]) 

outputs = outputs.stack() 
init = tf.global_variables_initializer() 
sess.run([init]) 
print(sess.run([outputs], 
       feed_dict={inputs: np.asarray([1, 1, 1, 1]), initial_t: 0, 
          initial_m: 0.})) 
+0

Nó hoạt động. Cảm ơn! –

+0

BTW, có thể tránh sử dụng trình giữ chỗ không? initial_t và initial_m là các số không. Tôi có thể tránh tiêm số không cho họ mỗi lần không? –

+0

Bạn chỉ cần một cái gì đó với hình dạng không rõ. Có vẻ như không thể liên tục với hình dạng không rõ. Nhưng bạn có thể sử dụng biến hoặc giữ chỗ –

Các vấn đề liên quan