2017-03-17

Frozen model from Keras doesn't predict after restoring

I'm using Keras to build and train my model. The model looks like this:

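from keras.layers import Input, LSTM, Dropout, Dense

# input_size (number of timesteps) and output_size (number of classes) are defined elsewhere
inputs = Input(shape=(input_size, 3), dtype='float32', name='input')
lstm1 = LSTM(128, return_sequences=True)(inputs)
dropout1 = Dropout(0.5)(lstm1)
lstm2 = LSTM(128)(dropout1)
dropout2 = Dropout(0.5)(lstm2)
outputs = Dense(output_size, activation='softmax', name='output')(dropout2)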

Right before checkpointing, my model predicts the classes very well (class distribution after softmax):

[[ 0.00117011 0.00631532 0.10080294 0.84386677 0.04784485]] 

However, after running the following code:

all_saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
print(save_path + '/model_predeploy.chkp')
all_saver.save(sess, save_path + '/model_predeploy.chkp', meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, save_path, "model.pb", False)

and freezing it using:

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/model.pb \
  --input_checkpoint=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/model_predeploy.chkp \
  --output_graph=/Users/denisermolin/Work/Projects/MotionRecognitionTraining/model/graph/output.pb \
  --output_node_names=Softmax \
  --input_binary=true

and then loading it:

import numpy as np
import tensorflow as tf

# load_graph() is a helper defined elsewhere (not shown) that loads the frozen output.pb into a tf.Graph;
# args.frozen_model_filename comes from the script's argument parsing (also not shown)
graph = load_graph(args.frozen_model_filename)

# We can verify that we can access the list of operations in the graph
for op in graph.get_operations():
    print(op.name)

# We access the input and output nodes
x = graph.get_tensor_by_name('input:0')
y = graph.get_tensor_by_name('Softmax:0')

data = [
    7.4768066E-4,-0.02217102,0.07727051,7.4768066E-4,-0.02217102,0.07727051,7.4768066E-4,-0.02217102,0.07727051,0.004989624,-0.020874023,0.09140015,0.004989624,-0.020874023,0.09140015,0.010604858,-0.010665894,0.025527954,0.010299683,0.018035889,-0.052749634,-0.012786865,0.017837524,-0.020828247,-0.045898438,0.007095337,0.01550293,-0.06680298,0.013702393,0.02687073,
    -0.061767578,0.026550291,-1.373291E-4,-0.036621094,0.041778564,-0.011276245,-0.042678833,0.054336548,0.036697388,-0.07182312,0.036483765,0.081726074,-0.08639526,0.041793823,0.07392883,-0.051788326,0.07649231,0.092178345,-0.056396484,0.0771637,0.11044311,-0.08444214,0.06201172,0.0920105,-0.12609863,0.06137085,0.104537964,-0.14356995,0.079071045,0.11187744,
    -0.17756653,0.08576965,0.16818237,-0.2379303,0.07879639,0.19819641,-0.2631073,0.13290405,0.19137573,-0.23666382,0.21955872,0.16033936,-0.23666382,0.21955872,0.16033936,-0.22547913,0.23838806,0.27246094,-0.26376343,0.19580078,0.33566284,-0.26376343,0.19580078,0.33566284,-0.4733429,0.19911194,-0.0050811768,-0.48905945,0.14544678,-0.21205139,
    -0.48905945,0.14544678,-0.21205139,-0.37893677,0.15655518,-0.1382904,-0.27426147,0.16381836,-0.052841187,-0.21949767,0.18780518,-0.045913696,-0.28207397,0.17993164,-0.1550293,-0.37120056,0.13322449,-0.4617462,-0.3773346,0.17321777,-0.7678375,-0.20349121,0.12588501,-0.7908478,-4.8828125E-4,0.116516106,-0.57121277,-0.090042114,0.08895874,-0.3849945,
    -0.232193,-0.028884886,-0.4724579,-0.19163513,-0.06340027,-0.5598297,-0.068481445,-0.025268555,-0.54397583,-0.03288269,-0.12750244,-0.48367307,0.0057525635,-0.030532837,-0.45234683,0.099868774,-0.0070648193,-0.57225037,0.21514893,0.05860901,-0.5052185,0.3602295,0.14176941,-0.4087372,0.57940674,0.16700745,-0.35438538,0.75743103,0.2631073,-0.5294647,
    0.75743103,0.2631073,-0.5294647,0.74624634,0.2193451,-0.70674133,0.91960144,0.29077148,-0.7026367,0.91960144,0.29077148,-0.7026367,0.81611633,0.34953308,-0.50927734,0.8429718,0.41278076,-0.38298035,0.84576416,0.4597778,-0.15159607,0.9177856,0.47735596,0.099731445,0.9820862,0.57232666,0.20970154,0.9269562,0.5357971,0.45666504,
    0.7898865,0.48097226,0.5698242,0.5332794,0.4213867,0.6626892,0.5032654,0.4464111,0.59614563,0.5827484,0.4588318,0.8383636,0.60975647,0.46882626,1.050766,0.58917236,0.52201843,0.9510345,0.48217773,0.502121,0.8063202,0.24050903,0.42752075,0.81951904,0.10655212,0.43006897,0.7798157,0.15496826,0.5040283,0.7533417,
    0.18733215,0.55770874,0.63716125,0.22062683,0.5880585,0.503067,0.06762695,0.49337766,0.6584778,-0.14086914,0.4414215,0.615036,-0.14086914,0.4414215,0.615036,-0.03614807,0.6751251,0.06636047,-0.03614807,0.6751251,0.06636047,0.17774963,0.741272,-0.09466553,0.21842958,0.7971039,-0.050811768,0.06843567,0.7729645,-0.34933472,
    -0.2092285,0.5443878,-0.5428009,-0.43028256,0.37249756,-0.5168762,-0.23457338,0.3491211,-0.45985416,0.15863037,0.49960327,-0.5370636,0.31782532,0.5680084,-0.8007355,0.1651001,0.5300598,-0.87919617,-0.086135864,0.49140927,-0.6066437,-0.20877077,0.4261017,-0.55911255,-0.33840942,0.34194946,-0.7007904,-0.36250305,0.27163696,-0.76208496,
] + [0.0] * 156  # the listing ends with 156 zeros (52 all-zero timesteps)
data = np.reshape(data, [1, 142, 3])

# We launch a Session
with tf.Session(graph=graph) as sess:
    # Note: we didn't initialize/restore anything, everything is stored in the graph_def
    y_out = sess.run(y, feed_dict={
        x: data
    })
    print(y_out)

gives me a uniform distribution over all labels:

[[ 0.20328824 0.19835895 0.19692752 0.20159255 0.19983278]] 
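
As a quick sanity check, the two softmax outputs above can be compared by how far they sit from the uniform distribution 1/K over K = 5 classes. A near-uniform output is often what a network produces when its weights are still at their initial values rather than trained ones. The helper names below are hypothetical; only the probabilities are taken from the question.

```python
import math
from typing import Sequence


def max_deviation_from_uniform(probs: Sequence[float]) -> float:
    """Largest absolute gap between any class probability and 1/K."""
    k = len(probs)
    return max(abs(p - 1.0 / k) for p in probs)


def normalized_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy divided by its maximum, log(K); 1.0 means perfectly uniform."""
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(len(probs))


# Softmax outputs copied from the question (K = 5 classes).
before_save = [0.00117011, 0.00631532, 0.10080294, 0.84386677, 0.04784485]
after_restore = [0.20328824, 0.19835895, 0.19692752, 0.20159255, 0.19983278]

for label, probs in [("before save", before_save), ("after restore", after_restore)]:
    print(label, max_deviation_from_uniform(probs), normalized_entropy(probs))
```

The restored model's output stays within about 0.0033 of 0.2 for every class and its normalized entropy is essentially 1.0, which matches "the model behaves as if untrained" rather than "the model is slightly off".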

Am I doing something wrong? I'm using TensorFlow 0.12 because I couldn't get 1.0 running on Android, but that's another story. Everything was built, trained and exported with 0.12.


Check [this answer](http://stackoverflow.com/questions/37704362/tensorflow-freeze-graph-script-failing-on-model-defined-with-keras?rq=1).

Answer

all_saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())  # re-initializes ALL variables, overwriting the trained weights
print(save_path + '/model_predeploy.chkp')
all_saver.save(sess, save_path + '/model_predeploy.chkp', meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, save_path, "model.pb", False)

In line 2 you re-initialize all variables from scratch (not just the uninitialized ones). This means your trained model is gone at that point, and what you save is a model with nothing but random/constant weights (depending on your initializers).
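
The same holds for the frozen graph: freeze_graph restores the variables from --input_checkpoint and converts them into constants inside output.pb, so if the checkpoint holds re-initialized values, so does the frozen model, no matter how correctly it is loaded afterwards. Below is a plain-Python toy of that idea; the dict-based graph, freeze() and the node names are hypothetical and only illustrate the concept, they are not how TensorFlow represents a GraphDef.

```python
# Toy model, NOT the TensorFlow API: a "graph" maps node names to (op_type, value).
# freeze() mimics the idea behind freeze_graph: every Variable node is replaced
# by a Const node holding the value found in the checkpoint.
from typing import Dict, Tuple

Node = Tuple[str, float]  # (op_type, value); the value of a "Variable" node is unused


def freeze(graph: Dict[str, Node], checkpoint: Dict[str, float]) -> Dict[str, Node]:
    """Return a new graph with Variables baked in as Consts from `checkpoint`."""
    frozen: Dict[str, Node] = {}
    for name, (op_type, value) in graph.items():
        if op_type == "Variable":
            # Whatever the checkpoint holds (trained or freshly initialized) gets frozen.
            frozen[name] = ("Const", checkpoint[name])
        else:
            frozen[name] = (op_type, value)
    return frozen


# Hypothetical one-weight graph: "w" stands in for the LSTM/Dense weights.
graph = {"w": ("Variable", 0.0), "scale": ("Const", 2.0)}

trained_ckpt = {"w": 0.731}  # checkpoint written right after training
reinit_ckpt = {"w": 0.0}     # checkpoint written after re-running the initializer

print(freeze(graph, trained_ckpt)["w"])  # ('Const', 0.731)
print(freeze(graph, reinit_ckpt)["w"])   # ('Const', 0.0)
```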

Demo script:

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import tensorflow as tf 
import numpy as np 

var = tf.get_variable('demo', dtype=tf.float32, shape=[], 
         initializer=tf.zeros_initializer()) 

sess = tf.Session() 

sess.run(tf.assign(var, 42))

print(var.eval(session=sess)) 

This prints 42.

sess.run(tf.global_variables_initializer()) 

print(var.eval(session=sess)) 

This prints 0, because the variable has been re-initialized to 0.

So initialize your variables before you train your model, and don't re-initialize them before writing them out.
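
The recommended ordering (initialize once, train, save without re-initializing) can be modelled in plain Python next to the buggy one. ToySession and its methods are hypothetical stand-ins for a TF session, tf.global_variables_initializer() and tf.train.Saver, used only to make the sequence explicit; this is a sketch of the semantics, not the TensorFlow API.

```python
# Toy stand-ins (hypothetical, NOT the TensorFlow API) mimicking TF 1.x semantics:
# variables live in a session, the global initializer resets ALL of them, and a
# saver snapshots whatever values the session holds at save time.
from typing import Dict


class ToySession:
    def __init__(self, initial_values: Dict[str, float]):
        self._initial = dict(initial_values)
        self.values: Dict[str, float] = {}

    def run_global_initializer(self) -> None:
        # Like tf.global_variables_initializer(): every variable goes back to its initial value.
        self.values = dict(self._initial)

    def train_step(self, name: str, delta: float) -> None:
        self.values[name] += delta

    def save(self) -> Dict[str, float]:
        # Like Saver.save(): a snapshot of the current values.
        return dict(self.values)


def correct_order() -> Dict[str, float]:
    sess = ToySession({"w": 0.0})
    sess.run_global_initializer()  # 1. initialize once, before training
    sess.train_step("w", 42.0)     # 2. train
    return sess.save()             # 3. save WITHOUT re-initializing


def buggy_order() -> Dict[str, float]:
    sess = ToySession({"w": 0.0})
    sess.run_global_initializer()
    sess.train_step("w", 42.0)
    sess.run_global_initializer()  # re-initialization wipes the trained value
    return sess.save()


print(correct_order())  # {'w': 42.0}
print(buggy_order())    # {'w': 0.0}
```

Applied to the question's save code, this corresponds to removing the sess.run(tf.global_variables_initializer()) call that sits between training and all_saver.save(...).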


Yep, figured it out, thx
