2016-04-17 24 views
8

Các mô hình python được xuất khẩu chính xác như thế nào để sử dụng trong C++?Xuất đồ thị Tensorflow từ Python để sử dụng trong C++

Tôi đang cố gắng để làm một cái gì đó tương tự như hướng dẫn này: https://www.tensorflow.org/versions/r0.8/tutorials/image_recognition/index.html

Tôi đang cố gắng để nhập khẩu mô hình TF của riêng tôi trong C++ API thay cho sự ra đời một. Tôi đã điều chỉnh kích thước đầu vào và đường dẫn, nhưng các lỗi lạ vẫn tiếp tục xuất hiện. Tôi đã dành cả ngày đọc tràn ngăn xếp và các diễn đàn khác nhưng vô ích.

Tôi đã thử hai phương pháp để xuất đồ thị.

Phương pháp 1: đoạn văn.

...loading inputs, setting up the model, etc.... 

sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 


for i in range(num_steps): 
    x_batch, y_batch = batch(50) 
    if i%10 == 0: 
     train_accuracy = accuracy.eval(feed_dict={ 
     x:x_batch, y_: y_batch, keep_prob: 1.0}) 
     print("step %d, training accuracy %g"%(i, train_accuracy)) 
    train_step.run(feed_dict={x: x_batch, y_: y_batch, keep_prob: 0.5}) 

print("test accuracy %g"%accuracy.eval(feed_dict={ 
    x: features_test, y_: labels_test, keep_prob: 1.0})) 

saver = tf.train.Saver(tf.all_variables()) 
checkpoint = 
    '/home/sander/tensorflow/tensorflow/examples/cat_face/data/model.ckpt' 
    saver.save(sess, checkpoint) 

    tf.train.export_meta_graph(filename= 
    '/home/sander/tensorflow/tensorflow/examples/cat_face/data/cat_graph.pb', 
    meta_info_def=None, 
    graph_def=sess.graph_def, 
    saver_def=saver.restore(sess, checkpoint), 
    collection_list=None, as_text=False) 

Phương pháp 1 mang lại các lỗi sau khi cố gắng chạy chương trình:

[libprotobuf ERROR 
google/protobuf/src/google/protobuf/wire_format_lite.cc:532] String field 
'tensorflow.NodeDef.op' contains invalid UTF-8 data when parsing a protocol 
buffer. Use the 'bytes' type if you intend to send raw bytes. 
E tensorflow/examples/cat_face/main.cc:281] Not found: Failed to load 
compute graph at 'tensorflow/examples/cat_face/data/cat_graph.pb' 

Tôi cũng đã thử phương pháp khác xuất khẩu đồ thị:

Cách 2: write_graph:

tf.train.write_graph(sess.graph_def, 
'/home/sander/tensorflow/tensorflow/examples/cat_face/data/', 
'cat_graph.pb', as_text=False) 

Phiên bản này thực sự có vẻ như tải thứ gì đó, nhưng tôi gặp lỗi về các biến không được khởi tạo:

Running model failed: Failed precondition: Attempting to use uninitialized 
value weight1 
[[Node: weight1/read = Identity[T=DT_FLOAT, _class=["loc:@weight1"], 
_device="/job:localhost/replica:0/task:0/cpu:0"](weight1)]] 
+3

Có "Phương pháp 3: Sử dụng freeze_graph". Điều đó tránh phải sử dụng các Biến và chạy các lệnh khôi phục - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py –

+0

Ah, tôi đã thấy điều đó. Nhưng tôi đấu tranh để tìm hiểu làm thế nào để điền vào các đối số của nó, giống như tôi không biết những gì để điền vào cho mỗi đối số trong export_meta_graph. Bạn có biết một số mã ví dụ cho điều này không? – Sander

+1

Có một ví dụ ở đây: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py –

Trả lời

0

Lúc đầu, bạn cần phải vẽ đồ thị độ nét nộp bằng cách sử dụng lệnh sau

with tf.Session() as sess: 
//Build network here 
tf.train.write_graph(sess.graph.as_graph_def(), "C:\\output\\", "mymodel.pb") 

Sau đó, lưu mô hình của bạn bằng cách sử dụng tiết kiệm

saver = tf.train.Saver(tf.global_variables()) 
saver.save(sess, "C:\\output\\mymodel.ckpt") 

Sau đó, bạn sẽ có 2 tệp ở đầu ra của bạn, mymodel.ckpt, mymodel.pb

Tải xuống freeze_graph.py từ here và chạy lệnh sau trong C: \ output \. Thay đổi tên nút đầu ra nếu nó khác với bạn.

python freeze_graph.py --input_graph mymodel.pb --input_checkpoint mymodel.ckpt --output_node_names softmax/Reshape_1 --output_graph mymodelforc.pb

Bạn có thể sử dụng trực tiếp từ mymodelforc.pb C.

bạn có thể sử dụng mã C sau đây để tải file proto

#include "tensorflow/core/public/session.h" 
#include "tensorflow/core/platform/env.h" 
#include "tensorflow/cc/ops/image_ops.h" 

Session* session; 
NewSession(SessionOptions(), &session); 

GraphDef graph_def; 
ReadBinaryProto(Env::Default(), "C:\\output\\mymodelforc.pb", &graph_def); 

session->Create(graph_def); 

Bây giờ bạn có thể sử dụng phiên cho suy luận.

Bạn có thể áp dụng thông số suy luận như sau:

// Same dimension and type as input of your network 
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, height, width, channel })); 
std::vector<tensorflow::Tensor> finalOutput; 

// Fill input tensor with your input data 

std::string InputName = "input"; // Your input placeholder's name 
std::string OutputName = "softmax/Reshape_1"; // Your output placeholder's name 

session->Run({ { InputName, input_tensor } }, { OutputName }, {}, &finalOutput); 

// finalOutput will contain the inference output that you search for 
+0

Deniz phiên bản tensorflow nào? Lý do tôi yêu cầu là hàm 'saver.save' Trong gói tôi đang sử dụng có vẻ như tạo ra một tệp' .ckpt.meta'. Tôi giả định đó là điều tương tự mà 'saver.export_meta_graph' tạo ra ... Nghiên cứu internet gần đây dường như chỉ ra rằng đây là sự khác biệt giữa R11 và R12, nhưng bạn đã viết điều này gần đây đến nỗi tôi tự hỏi bạn đang sử dụng phiên bản nào. – Geronimo

+0

Tôi đã xác minh mã này với nguồn hàng hiện tại được cập nhật. Tuy nhiên, các tệp ckpt.meta không cần thiết cho C++ chạy khi bạn đã xuất bằng cách sử dụng "write_graph". Tôi sẽ cập nhật mã để tránh nhầm lẫn –

Các vấn đề liên quan