2016-08-14 25 views
7

Làm thế nào tôi có thể đọc các biến và trạng thái của chúng từ một trạm kiểm soát?Tensorflow. Liệt kê các biến trong điểm kiểm tra

Tôi đang làm việc với bộ mã hóa tự động và điểm kiểm tra của tôi chứa trạng thái hoàn chỉnh của mạng, tức là bộ mã hóa, bộ giải mã, trình tối ưu hóa, v.v. Tôi muốn đánh lừa với mã hóa và do đó sẽ chỉ cần bộ phận giải mã của mạng trong chế độ đánh giá của tôi.

Câu hỏi tương tự theo cách trừu tượng hơn: làm cách nào tôi có thể đọc chỉ các biến cụ thể từ điểm kiểm tra hiện có để tái sử dụng trong mô hình khác?

Tôi có nên đặt tên biến của mình tương ứng không? Hoặc là có một cách để có được một cái gì đó như:

w_init = read_from_state(state_location, var_name) 

def read_from_state(state_location, var_name): 
    # the magic goes here 
    pass 

Trả lời

14

list_variables phương pháp trong checkpoint_utils.py cho phép bạn xem tất cả các biến lưu.

Tuy nhiên, đối với trường hợp sử dụng của bạn, có thể khôi phục dễ dàng hơn bằng Trình tiết kiệm. Nếu bạn biết tên của các biến khi bạn lưu điểm kiểm tra, bạn có thể tạo trình tiết kiệm mới và yêu cầu nó khởi tạo các tên đó thành các đối tượng mới Variable (có thể với các tên khác nhau). Điều này được sử dụng trong ví dụ CIFAR để chọn khôi phục subset of variables. Xem Choosing which Variables to Save and Restore trong Howto

0

Một cách khác, điều đó sẽ in tất cả tensors trạm kiểm soát (hoặc chỉ là một, nếu được chỉ định) cùng với nội dung của họ:

from tensorflow.python.tools import inspect_checkpoint as inch 
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True) 
""" 
Args: 
    file_name: Name of the checkpoint file. 
    tensor_name: Name of the tensor in the checkpoint file to print. 
    all_tensors: Boolean indicating whether to print all tensors. 
""" 

Nó sẽ luôn in nội dung của tensor.

Và, trong khi chúng tôi đang ở đó, ở đây là làm thế nào để sử dụng checkpoint_utils, được đề xuất bởi câu trả lời trước:

from tensorflow.contrib.framework.python.framework import checkpoint_utils 
    var_list = checkpoint_utils.list_variables('path/to/ckpt') 
    for v in var_list: print(v) 
Các vấn đề liên quan