Cách tiêu chuẩn để lưu biến trong TensorFlow là sử dụng đối tượng tf.train.Saver
. Theo mặc định nó tiết kiệm tất cả các biến trong vấn đề của bạn (ví dụ, kết quả của tf.all_variables()
), nhưng bạn có thể tiết kiệm biến có chọn lọc bằng cách thông qua các đối số tùy chọn var_list
đến tf.train.Saver
constructor:
weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)
# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...
# ...then call the following methods as appropriate:
weights_saver.save(sess) # Save the current value of the weights.
biases_saver.save(sess) # Save the current value of the biases.
Lưu ý rằng nếu bạn chuyển từ điển tới hàm tạo tf.train.Saver
(chẳng hạn như weights
và/hoặc biases
từ điển từ câu hỏi của bạn), TensorFlow sẽ sử dụng khóa từ điển (ví dụ: 'wc1_0'
) làm tên cho biến tương ứng trong bất kỳ tệp điểm kiểm tra nào tạo ra hoặc tiêu thụ .
Theo mặc định hoặc nếu bạn chuyển danh sách các đối tượng tf.Variable
cho hàm tạo, TensorFlow sẽ sử dụng thuộc tính tf.Variable.name
thay thế.
Truyền từ điển cung cấp cho bạn khả năng chia sẻ các điểm kiểm tra giữa các mô hình cung cấp các thuộc tính khác nhau cho mỗi biến số Variable.name
. Chi tiết này chỉ quan trọng nếu bạn muốn sử dụng các điểm kiểm tra được tạo với một mô hình khác.
Có một số cách khác nhau để thực hiện. Bạn có thể sử dụng trình tiết kiệm lưu lượng hoặc sử dụng định dạng yêu thích của bạn như h5 hoặc npy. – jean
cảm ơn, tôi hiểu rồi. – luohao