2016-08-14 24 views
61

Từ những gì tôi đã thu thập cho đến nay, có một số cách khác nhau để bán đồ thị TensorFlow vào một tệp và sau đó tải nó vào một chương trình khác, nhưng tôi đã không thể để tìm các ví dụ/thông tin rõ ràng về cách chúng hoạt động.TensorFlow lưu vào/tải một biểu đồ từ một tệp

  1. Save the của mô hình biến thành một file checkpoint (.ckpt) sử dụng một tf.train.Saver() và khôi phục chúng sau này (source)
  2. Lưu một mô hình vào một tập tin .pb và tải nó trở lại: Những gì tôi đã biết đây là trong việc sử dụng tf.train.write_graph()tf.import_graph_def() (source)
  3. tải trong một mô hình từ một tập tin .pb, đào tạo lại nó, và đổ nó vào một tập tin .pb mới sử dụng Bazel (source)
  4. Freeze biểu đồ để lưu đồ thị và trọng lượng cùng nhau (source)
  5. Sử dụng as_graph_def() để lưu các mô hình, và cho trọng lượng/biến, bản đồ chúng vào các hằng số (source)

Tuy nhiên, tôi đã không thể làm sáng tỏ một số câu hỏi liên quan đến các phương pháp khác nhau:

  1. Về tệp điểm kiểm tra, chúng có chỉ lưu trọng số được đào tạo của một mô hình không? Các tệp điểm kiểm tra có thể được tải vào một chương trình mới và được sử dụng để chạy mô hình hay chúng chỉ đơn giản là cách lưu các trọng số trong một mô hình tại một thời điểm/giai đoạn nhất định?
  2. Về tf.train.write_graph(), cũng là các trọng số/biến được lưu không?
  3. Về Bazel, nó có thể lưu vào/tải từ tệp .pb để đào tạo lại không? Có một lệnh Bazel đơn giản chỉ để đổ một đồ thị vào một .pb?
  4. Về việc đóng băng, biểu đồ được cố định có được tải bằng cách sử dụng tf.import_graph_def() không?
  5. Bản trình diễn Android cho tải TensorFlow trong mô hình Khởi động của Google từ tệp .pb. Nếu tôi muốn thay thế tệp .pb của riêng tôi, tôi sẽ làm thế nào? Tôi có cần thay đổi bất kỳ mã/phương thức gốc nào không?
  6. Nói chung, sự khác biệt chính xác giữa tất cả các phương pháp này là gì? Hoặc rộng hơn, sự khác biệt giữa as_graph_def() /.ckpt/.pb là gì?

Tóm lại, những gì tôi đang tìm là một phương pháp để lưu cả biểu đồ (như trong, các hoạt động khác nhau và như vậy) và trọng số/biến của nó vào tệp, sau đó có thể được sử dụng để tải biểu đồ và trọng số vào một chương trình khác, để sử dụng (không nhất thiết phải tiếp tục/đào tạo lại).

Tài liệu về chủ đề này không đơn giản, vì vậy mọi câu trả lời/thông tin sẽ được đánh giá cao.

+2

API mới nhất/đầy đủ nhất là biểu đồ meta, sẽ cung cấp cho bạn cách lưu cả ba cùng một lúc - 1) đồ thị 2) giá trị tham số 3) bộ sưu tập: https: //www.tensorflow. org/versions/r0.10/how_tos/meta_graph/index.html –

Trả lời

53

Có nhiều cách để tiếp cận vấn đề tiết kiệm mô hình trong TensorFlow, điều này có thể làm cho nó hơi khó hiểu. Lấy mỗi câu hỏi phụ của bạn lần lượt:

  1. Các file checkpoint (sản xuất ví dụ bằng cách gọi saver.save() trên một đối tượng tf.train.Saver) chỉ chứa trọng lượng, và bất kỳ biến khác quy định tại cùng một chương trình. Để sử dụng chúng trong một chương trình khác, bạn phải tạo lại cấu trúc biểu đồ liên quan (ví dụ: bằng cách chạy mã để tạo lại cấu trúc đó hoặc gọi tf.import_graph_def()), cho TensorFlow biết phải làm gì với những trọng số đó.Lưu ý rằng việc gọi số saver.save() cũng tạo tệp có chứa MetaGraphDef, chứa biểu đồ và chi tiết về cách liên kết trọng số từ điểm kiểm tra với biểu đồ đó. Xem the tutorial để biết thêm chi tiết.

  2. tf.train.write_graph() chỉ ghi cấu trúc biểu đồ; không phải trọng lượng.

  3. Bazel không liên quan đến việc đọc hoặc viết đồ thị TensorFlow. (Có lẽ tôi hiểu lầm câu hỏi của bạn: cảm thấy tự do để làm rõ nó trong một bình luận.)

  4. Một đồ thị đông lạnh có thể được nạp bằng cách sử dụng tf.import_graph_def(). Trong trường hợp này, các trọng số (thường) được nhúng trong biểu đồ, do đó bạn không cần phải tải một trạm kiểm soát riêng biệt.

  5. Thay đổi chính là cập nhật tên của (các) tensor được đưa vào mô hình và tên của (các) tensor được lấy từ mô hình. Trong bản demo Android TensorFlow, điều này sẽ tương ứng với các chuỗi inputNameoutputName được chuyển đến TensorFlowClassifier.initializeTensorFlow().

  6. GraphDef là cấu trúc chương trình, thường không thay đổi thông qua quá trình đào tạo. Điểm kiểm tra là ảnh chụp nhanh trạng thái của quá trình đào tạo, thường thay đổi ở mọi bước của quá trình đào tạo. Kết quả là, TensorFlow sử dụng các định dạng lưu trữ khác nhau cho các loại dữ liệu này và API cấp thấp cung cấp các cách khác nhau để lưu và tải chúng. Thư viện cấp cao hơn, chẳng hạn như thư viện MetaGraphDef, Kerasskflow xây dựng trên các cơ chế này để cung cấp các cách thuận tiện hơn để lưu và khôi phục toàn bộ mô hình.

+0

Điều này có nghĩa là [tài liệu API C++] (https://www.tensorflow.org/versions/r0.11/api_docs/cc/index.html) nằm, khi nó nói rằng bạn có thể tải các đồ thị được lưu với 'tf.train.write_graph()' và sau đó thực hiện nó? – mnicky

+2

Tài liệu API C++ không nói dối, nhưng thiếu một vài chi tiết. Chi tiết quan trọng nhất là, ngoài 'GraphDef' được lưu bởi' tf.train.write_graph() ', bạn cũng cần phải nhớ tên của các tensors mà bạn muốn nạp và tìm nạp khi thực hiện biểu đồ (mục 5 ở trên). – mrry

+0

@mrry: Tôi đã cố gắng sử dụng ví dụ DeepDream tensorflows. nhưng có vẻ như nó cần mô hình pretrained ở định dạng pb! Tôi đã chạy ví dụ Cifar10, nhưng nó chỉ tạo ra các trạm kiểm soát! Tôi không thể tìm thấy bất kỳ tập tin pb hoặc bất cứ điều gì! làm thế nào tôi có thể chuyển đổi các trạm kiểm soát của tôi sang định dạng pb mà ví dụ sâu sắc sử dụng? – Breeze

Các vấn đề liên quan