42

Tôi đang cố gắng gán một giá trị mới cho một biến tensorflow trong python.Cách gán giá trị cho biến TensorFlow?

import tensorflow as tf 
import numpy as np 

x = tf.Variable(0) 
init = tf.initialize_all_variables() 
sess = tf.InteractiveSession() 
sess.run(init) 

print(x.eval()) 

x.assign(1) 
print(x.eval()) 

Nhưng đầu ra tôi nhận được là

0 
0 

Vì vậy, giá trị vẫn không thay đổi. Tôi đang thiếu gì?

Trả lời

72

Tuyên bố x.assign(1) không thực sự gán giá trị 1-x, nhưng thay vì tạo ra một tf.Operation mà bạn phải rõ ràng chạy để cập nhật các biến * Một cuộc gọi đến Operation.run() hoặc Session.run() có thể được sử dụng để chạy các hoạt động.:

assign_op = x.assign(1) 
sess.run(assign_op) # or `assign_op.op.run()` 
print(x.eval()) 
# ==> 1 

(* trong thực tế, nó trả về một tf.Tensor, tương ứng với giá trị cập nhật của biến, để làm cho nó dễ dàng hơn để tập chuỗi.)

+0

Cảm ơn! assign_op.run() đưa ra một lỗi: AttributeError: đối tượng 'Tensor' không có thuộc tính 'run'. Nhưng sess.run (assign_op) chạy hoàn toàn tốt. – abora

+0

Trong ví dụ này, là dữ liệu mà 'Biến'' x' được lưu trữ trong bộ nhớ trước khi hoạt động' assign'/tensable tensor được ghi đè hoặc là một tensor mới được tạo để lưu trữ giá trị cập nhật? – dannygoldstein

+3

Việc thực hiện hiện tại 'assign()' sẽ ghi đè giá trị hiện tại. – mrry

-4

Có một cách tiếp cận dễ dàng hơn:

x = tf.Variable(0) 
x = x + 1 
print x.eval() 
+2

o.p. đã kiểm tra việc sử dụng 'tf.assign', chứ không phải bổ sung. – vega

6

Trước hết bạn có thể gán giá trị cho biến/hằng số chỉ bằng cách ăn các giá trị vào chúng giống như cách bạn làm điều đó với placeholders. Vì vậy, điều này hoàn toàn hợp pháp để làm:

import tensorflow as tf 
x = tf.Variable(0) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x, feed_dict={x: 3}) 

Về sự nhầm lẫn của bạn với nhà điều hành tf.assign(). Trong TF không có gì được thực thi trước khi bạn chạy nó bên trong phiên. Vì vậy, bạn luôn phải làm một cái gì đó như thế này: op_name = tf.some_function_that_create_op(params) và sau đó bên trong phiên bạn chạy sess.run(op_name). Sử dụng gán như một ví dụ bạn sẽ làm điều gì đó như thế này:

import tensorflow as tf 
x = tf.Variable(0) 
y = tf.assign(x, 1) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x) 
    print sess.run(y) 
    print sess.run(x) 
+1

Lưu ý rằng việc nạp giá trị thông qua 'feed_dict' không gán giá trị đó cho biến. –

2

Ngoài ra, nó phải được lưu ý rằng nếu bạn đang sử dụng your_tensor.assign() thì tf.global_variables_initializer không cần phải được gọi một cách rõ ràng kể từ khi hoạt động assign nào đó cho bạn trong nền.

Ví dụ:

In [212]: w = tf.Variable(12) 
In [213]: w_new = w.assign(34) 

In [214]: with tf.Session() as sess: 
    ...:  sess.run(w_new) 
    ...:  print(w_new.eval()) 

# output 
34 

Tuy nhiên, điều này sẽ không khởi tạo tất cả các biến, nhưng nó sẽ chỉ khởi tạo biến mà assign được thực thi trên.

2

Bạn cũng có thể chỉ định giá trị mới cho tf.Variable mà không thêm hoạt động vào biểu đồ: tf.Variable.load(value, session). Hàm này cũng có thể giúp bạn tiết kiệm thêm phần giữ chỗ khi gán giá trị từ bên ngoài biểu đồ và nó hữu ích trong trường hợp biểu đồ được hoàn thành.

import tensorflow as tf 
x = tf.Variable(0) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print(sess.run(x)) # Prints 0. 
x.load(1, sess) 
print(sess.run(x)) # Prints 1. 
Các vấn đề liên quan