5

Tôi muốn quản lý hoạt động đào tạo của mình với số tf.estimator.Estimator nhưng gặp sự cố khi sử dụng nó cùng với API tf.data.Làm thế nào để sử dụng các biến lặp khởi tạo của tf.data trong input_fn của một tf.estimator?

Tôi có một cái gì đó như thế này:

def model_fn(features, labels, params, mode): 
    # Defines model's ops. 
    # Initializes with tf.train.Scaffold. 
    # Returns an tf.estimator.EstimatorSpec. 

def input_fn(): 
    dataset = tf.data.TextLineDataset("test.txt") 
    # map, shuffle, padded_batch, etc. 

    iterator = dataset.make_initializable_iterator() 

    return iterator.get_next() 

estimator = tf.estimator.Estimator(model_fn) 
estimator.train(input_fn) 

Như tôi đã không thể sử dụng một make_one_shot_iterator đối với trường hợp sử dụng của tôi, vấn đề của tôi là input_fn chứa một iterator rằng nên được khởi tạo trong vòng model_fn (ở đây, tôi sử dụng tf.train.Scaffold để khởi tạo ops cục bộ).

Ngoài ra, tôi hiểu rằng chúng tôi không thể chỉ sử dụng input_fn = iterator.get_next nếu không các ops khác sẽ không được thêm vào cùng một biểu đồ.

Cách được khuyến nghị để khởi tạo trình lặp là gì?

+0

@guillaumeklin - bạn đã thêm 'tf.add_to_collection (tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) 'trong input_fn()? – reese0106

+0

Có, bạn có thể thêm dòng này vào 'input_fn()' ngay trước 'return iterator.get_next()'. – guillaumekln

Trả lời

7

Tính đến TensorFlow 1.5, nó có thể làm cho input_fn trở lại một tf.data.Dataset, ví dụ .:

def input_fn(): 
    dataset = tf.data.TextLineDataset("test.txt") 
    # map, shuffle, padded_batch, etc. 
    return dataset 

Xem c294fcfd.


Đối với các phiên bản trước, bạn có thể thêm initializer của iterator trong tf.GraphKeys.TABLE_INITIALIZERS bộ sưu tập và dựa vào initializer mặc định.

tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) 
+0

Cảm ơn! +1. Chỉ cần làm rõ câu trả lời: cần thêm dòng 'tf.add_to_collection ...' trước khi trả về 'input_fn()' và sau đó nó hoạt động tốt và không cần phải làm gì với 'Scaffold' và' local_init_ops'. – Pekka

Các vấn đề liên quan