2017-09-26 23 views
8
def train(): 
# Model 
model = Model() 

# Loss, Optimizer 
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step') 
loss_fn = model.loss() 
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step) 

# Summaries 
summary_op = summaries(model, loss_fn) 

with tf.Session(config=TrainConfig.session_conf) as sess: 

    # Initialized, Load state 
    sess.run(tf.global_variables_initializer()) 
    model.load_state(sess, TrainConfig.CKPT_PATH) 

    writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph) 

    # Input source 
    data = Data(TrainConfig.DATA_PATH) 

    loss = Diff() 
    for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP): 

      mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step) 

      mixed_spec = to_spectrogram(mixed_wav) 
      mixed_mag = get_magnitude(mixed_spec) 

      src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav) 
      src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec) 

      src1_batch, _ = model.spec_to_batch(src1_mag) 
      src2_batch, _ = model.spec_to_batch(src2_mag) 
      mixed_batch, _ = model.spec_to_batch(mixed_mag) 

      # Initializae our callback. 
      #early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5) 


      l, _, summary = sess.run([loss_fn, optimizer, summary_op], 
            feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch, 
               model.y_src2: src2_batch}) 

      loss.update(l) 
      print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value)) 

      writer.add_summary(summary, global_step=step) 

      # Save state 
      if step % TrainConfig.CKPT_STEP == 0: 
       tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step) 

    writer.close() 

Tôi có mã mạng neural này tách nhạc khỏi giọng nói trong tệp .wav. làm cách nào tôi có thể giới thiệu thuật toán dừng sớm để dừng phần đào tạo? Tôi thấy một số dự án nói về một ValidationMonitor. Ai đó có thể giúp tôi?cách triển khai dừng sớm trong lưu lượng truy cập

Trả lời

0

Xác thựcMonitor được đánh dấu là không dùng nữa. nó không được khuyến khích. nhưng bạn vẫn có thể sử dụng nó. đây là một ví dụ về cách tạo:

validation_monitor = monitors.ValidationMonitor(
     input_fn=functools.partial(input_fn, subset="evaluation"), 
     eval_steps=128, 
     every_n_steps=88, 
     early_stopping_metric="accuracy", 
     early_stopping_rounds = 1000 
    ) 

và bạn có thể thực hiện một mình, ở đây tôi thực hiện của tôi:

  if (loss_value < self.best_loss): 
      self.stopping_step = 0 
      self.best_loss = loss_value 
      else: 
      self.stopping_step += 1 
      if self.stopping_step >= FLAGS.early_stopping_step: 
      self.should_stop = True 
      print("Early stopping is trigger at step: {} loss:{}".format(global_step,loss_value)) 
      run_context.request_stop() 
Các vấn đề liên quan