2016-07-04 11 views
16

Tôi muốn khởi tạo một số biến số trên mạng của mình với các giá trị tối đa. Vì lợi ích của ví dụ xem xét:Làm thế nào để khởi tạo một biến với tf.get_variable và một giá trị numpy trong TensorFlow?

init=np.random.rand(1,2) 
tf.get_variable('var_name',initializer=init) 

khi tôi làm điều đó tôi nhận được một lỗi:

ValueError: Shape of a new variable (var_name) must be fully defined, but instead was <unknown>. 

tại sao nó mà tôi đang nhận được lỗi đó?

Để cố gắng sửa chữa nó Tôi đã cố gắng thực hiện:

tf.get_variable('var_name',initializer=init, shape=[1,2]) 

mà mang lại một lỗi thậm chí lạ:

TypeError: 'numpy.ndarray' object is not callable 

Tôi cố gắng đọc the docs and examples nhưng nó không thực sự giúp đỡ.

Không thể khởi tạo biến với mảng có nhiều mảng với phương thức get_variable trong TensorFlow?

Trả lời

25

Các công trình sau đây:

init = tf.constant(np.random.rand(1, 2)) 
tf.get_variable('var_name', initializer=init) 

Các tài liệu cho get_variable là một chút thiếu thực sự. Chỉ để bạn tham khảo, đối số initializer phải là đối tượng TensorFlow Tensor (có thể được tạo bằng cách gọi tf.constant trên giá trị numpy trong trường hợp của bạn) hoặc 'có thể gọi' có hai đối số, shapedtype, hình dạng và loại dữ liệu của giá trị mà nó phải trả về. Một lần nữa, trong trường hợp của bạn, bạn có thể viết như sau trong trường hợp bạn muốn sử dụng cơ chế 'callable':

init = lambda shape, dtype: np.random.rand(*shape) 
tf.tf.get_variable('var_name', initializer=init, shape=[1, 2]) 
+2

[này] (http://stackoverflow.com/questions/111234/what-is-a-callable-in-python) là một câu trả lời tuyệt vời cho câu hỏi của bạn. – keveman

+0

Một 'callable' là một chức năng hoặc một cái gì đó có thể được gọi là một hàm. – hpaulj

+0

'tf.get_variable ('var_name', initializer = np.random.rand (1, 2))' dường như hoạt động ngay bây giờ trên r0.10. – ldavid

6

@keveman trả lời tốt, và cho bổ sung, có việc sử dụng tf.get_variable (' var_name ', initializer = init), tài liệu tensorflow đã đưa ra một ví dụ toàn diện.

import numpy as np 
import tensorflow as tf 

value = [0, 1, 2, 3, 4, 5, 6, 7] 
# value = np.array(value) 
# value = value.reshape([2, 4]) 
init = tf.constant_initializer(value) 

print('fitting shape:') 
tf.reset_default_graph() 
with tf.Session() : 
    x = tf.get_variable('x', shape = [2, 4], initializer = init) 
    x.initializer.run() 
    print(x.eval()) 

    fitting shape : 
[[0. 1. 2. 3.] 
[4. 5. 6. 7.]] 

print('larger shape:') 
tf.reset_default_graph() 
with tf.Session() : 
    x = tf.get_variable('x', shape = [3, 4], initializer = init) 
    x.initializer.run() 
    print(x.eval()) 

    larger shape : 
[[0. 1. 2. 3.] 
[4. 5. 6. 7.] 
[7. 7. 7. 7.]] 

print('smaller shape:') 
tf.reset_default_graph() 
with tf.Session() : 
    x = tf.get_variable('x', shape = [2, 3], initializer = init) 

    * <b>`ValueError`</b > : Too many elements provided.Needed at most 6, but received 8 

https://www.tensorflow.org/api_docs/python/tf/constant_initializer

2

Nếu biến đã được tạo (tức là từ một số chức năng phức tạp), chỉ cần sử dụng load.

https://www.tensorflow.org/api_docs/python/tf/Variable#load

x_var = tf.Variable(tf.zeros((1, 2), tf.float32)) 
x_val = np.random.rand(1,2).astype(np.float32) 

sess = tf.Session() 
x_var.load(x_val, session=sess) 

# test 
assert np.all(sess.run(x_var) == x_val) 
Các vấn đề liên quan