2016-07-06 22 views
11

Tôi muốn xem các biến được lưu trong trạm kiểm soát lưu lượng cùng với giá trị của chúng. Làm cách nào để tìm tên biến được lưu trong trạm kiểm soát lưu lượng?Làm thế nào để tìm tên biến được lưu trong một trạm kiểm soát lưu lượng?

EDIT:

tôi đã sử dụng tf.train.NewCheckpointReader đó được giải thích here. Nhưng, nó không được đưa ra trong tài liệu của tensorflow. Còn cách nào khác không?

'

import tensorflow as tf 
    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0") 
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32, 
        name="v1") 
    init_all_op = tf.initialize_all_variables() 
    save = tf.train.Saver({"v0": v0, "v1": v1}) 
    checkpoint_path = os.path.join(model_dir, "model.ckpt")  

    with tf.Session() as sess: 
     sess.run(init_all_op) 
     # Saves a checkpoint.  
     save.save(sess, checkpoint_path) 

     # Creates a reader. 
     reader = tf.train.NewCheckpointReader(checkpoint_path) 
     print('reder:\n', reader) 

     # Verifies that the tensors exist. 
     print('is exist v0?', reader.has_tensor("v0")) 
     print('is exist v1?', reader.has_tensor("v1")) 

     # Verifies that debug string contains the right strings. 
     debug_string = reader.debug_string() 
     print('\n All Variables: \n', debug_string) 

     # Verifies get_variable_to_shape_map() returns the correct information. 
     var_map = reader.get_variable_to_shape_map() 
     print('\n All Variables information :\n', var_map) 

     # Verifies get_tensor() returns the tensor value. 
     v0_tensor = reader.get_tensor("v0") 
     v1_tensor = reader.get_tensor("v1") 
     print('\n returns the v0 tensor value:\n', v0_tensor) 
     print('\n returns the v1 tensor value:\n', v1_tensor) 

'

+0

Tôi thấy rằng bạn đã chấp nhận câu trả lời. Vì vậy, mã mà bạn đã viết để chạy hàm 'print_tensors_in_checkpoint_file là gì?' Tôi đã cố gắng sử dụng nó nhưng bất cứ khi nào tôi làm 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' python nói rằng module' tensorflow.python' không có thuộc tính 'công cụ'. Tôi nghĩ rằng nó sẽ vô cùng hữu ích nếu bạn cung cấp một kịch bản ví dụ nhỏ về cách chạy hàm này (vì tệp đó không cung cấp một ví dụ), đặc biệt vì bạn đã chấp nhận câu trả lời vì vậy tôi giả sử một cái gì đó làm việc cho bạn. – Pinocchio

Trả lời

4

Bạn có thể sử dụng công cụ inspect_checkpoint.py.

+2

Tôi đã cố gắng sử dụng nó nhưng bất cứ khi nào tôi làm 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' python nói rằng module' tensorflow.python' không có thuộc tính 'tools'. Tôi nghĩ rằng ti sẽ vô cùng hữu ích nếu bạn cung cấp một kịch bản ví dụ nhỏ về cách chạy hàm này (vì tệp đó không cung cấp ví dụ) hoặc là – Pinocchio

19

Ví dụ sử dụng:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80] 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='') 

# List contents of v0 tensor. 
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0') 

# List contents of v1 tensor. 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1') 

Cập nhật:all_tensors luận được thêm vào print_tensors_in_checkpoint_file từ Tensorflow 0.12.0-rc0, do đó bạn có thể cần thêm all_tensors=False hoặc all_tensors=True nếu cần thiết.

phương pháp thay thế:

from tensorflow.python import pywrap_tensorflow 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
for key in var_to_shape_map: 
    print("tensor_name: ", key) 
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names 

Hy vọng nó giúp.

+0

thực sự hữu ích, cảm ơn! – allen

1

Thêm vào câu trả lời ở trên:

Nếu mô hình được lưu bằng định dạng V2

model-10000.data-00000-of-00001 
model-10000.index 
model-10000.meta 

tên đầu vào trạm kiểm soát của bạn chỉ nên là tiền tố

print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True) 

nguồn: bởi @LingjiaDeng tại https://github.com/tensorflow/tensorflow/issues/7696

Các vấn đề liên quan