2015-12-28 26 views
8

Tôi cố gắng để thực hiện một đề nghị từ câu trả lời: Tensorflow: how to save/restore a model?tensorflow: Tiết kiệm và khôi phục phiên

tôi có một đối tượng mà kết thúc tốt đẹp một mô hình tensorflow trong một phong cách sklearn.

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

Khi tôi chạy:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

tôi nhận được kết quả:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

Tuy nhiên, khi tôi cố gắng để khôi phục lại các thông số (ngay cả khi không giết chết các đối tượng): tfl.fit(train_X, train_Y , load = True)

Tôi nhận được kết quả lạ. Trước hết, giá trị được nạp không tương ứng với giá trị đã lưu.

loading a session 
loaded b1: 0.1   <------- Loaded another value than saved 
Epoch: 50 cost = 30.8483 R^2 = -0.7670 
    b1: 0.137484 

Cách phù hợp để tải và có lẽ trước tiên kiểm tra các biến đã lưu là gì?

+0

tài liệu hướng dẫn lưu lượng không có các ví dụ khá cơ bản, bạn phải đào sâu trong các thư mục mẫu và ý thức về chủ yếu là tự mình – diffeomorphism

Trả lời

10

TL; DR: Bạn nên cố gắng làm lại lớp học này để self.create_network() được gọi (i) chỉ một lần và (ii) trước khi xây dựng tf.train.Saver().

Có hai vấn đề nhỏ ở đây, đó là do cấu trúc mã và hành vi mặc định của tf.train.Saver constructor. Khi bạn xây dựng một trình tiết kiệm không có đối số (như trong mã của bạn), nó sẽ thu thập tập các biến hiện tại trong chương trình của bạn và thêm các op vào biểu đồ để lưu và khôi phục chúng. Trong mã của bạn, khi bạn gọi tflasso(), nó sẽ tạo trình tiết kiệm và sẽ không có biến nào (vì create_network() chưa được gọi). Do đó, điểm kiểm tra phải trống.

Vấn đề thứ hai là — theo mặc định — định dạng của điểm kiểm tra đã lưu là bản đồ từ giá trị hiện tại name property of a variable. Nếu bạn tạo hai biến có cùng tên, họ sẽ tự động "uniquified" bởi TensorFlow:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

Hậu quả của việc này là, khi bạn gọi self.create_network() trong cuộc gọi thứ hai để tfl.fit(), các biến sẽ tất cả đều có các tên khác nhau từ các tên được lưu trữ trong trạm kiểm soát — hoặc có thể là nếu trình tiết kiệm đã được xây dựng sau mạng. (Bạn có thể tránh hành vi này bằng cách thông qua một name Variable từ điển để các nhà xây dựng tiết kiệm, nhưng điều này thường là khá vụng về.)

Có hai cách giải quyết chính:

  1. Trong mỗi cuộc gọi đến tflasso.fit(), tạo toàn bộ mô hình được lặp lại, bằng cách xác định tf.Graph mới, thì trong biểu đồ đó tạo mạng và tạo một tf.train.Saver.

  2. ĐƯỢC KHUYẾN Tạo mạng, sau đó các tf.train.Saver trong constructor tflasso, và tái sử dụng biểu đồ này trên mỗi cuộc gọi đến tflasso.fit().Lưu ý rằng bạn có thể cần phải làm thêm một số công việc để tổ chức lại mọi thứ (đặc biệt, tôi không chắc chắn bạn làm gì với self.Xself.xlen) nhưng có thể đạt được điều này với placeholders và cho ăn.

+0

cảm ơn bạn! 'Xlen' được sử dụng trong' self._create_network() 'để thiết lập kích thước đầu vào của' X' (placeholder init: 'self.vars.xx = tf.placeholder (" float ", shape = [None, self.xlen ]) '). Từ những gì bạn nói, cách ưu tiên là chuyển 'xlen' thành bộ khởi tạo. –

+0

Có cách nào để đặt lại các biến tf cũ/độc đáo khi khởi tạo lại đối tượng không? –

+1

Để thực hiện điều đó, bạn cần tạo một 'tf.Graph' mới và đặt nó làm mặc định trước khi bạn (i) tạo mạng và (ii) tạo một' Tiết kiệm '. Nếu bạn quấn phần thân của 'tflasso.fit()' vào một 'bằng tf.Graph(). As_default():' block và di chuyển cấu trúc 'Saver' bên trong khối đó, các tên phải giống nhau mỗi lần bạn gọi 'fit()'. – mrry

Các vấn đề liên quan