Tôi muốn quản lý hoạt động đào tạo của mình với số tf.estimator.Estimator
nhưng gặp sự cố khi sử dụng nó cùng với API tf.data
.Làm thế nào để sử dụng các biến lặp khởi tạo của tf.data trong input_fn của một tf.estimator?
Tôi có một cái gì đó như thế này:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
Như tôi đã không thể sử dụng một make_one_shot_iterator
đối với trường hợp sử dụng của tôi, vấn đề của tôi là input_fn
chứa một iterator rằng nên được khởi tạo trong vòng model_fn
(ở đây, tôi sử dụng tf.train.Scaffold
để khởi tạo ops cục bộ).
Ngoài ra, tôi hiểu rằng chúng tôi không thể chỉ sử dụng input_fn = iterator.get_next
nếu không các ops khác sẽ không được thêm vào cùng một biểu đồ.
Cách được khuyến nghị để khởi tạo trình lặp là gì?
@guillaumeklin - bạn đã thêm 'tf.add_to_collection (tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) 'trong input_fn()? – reese0106
Có, bạn có thể thêm dòng này vào 'input_fn()' ngay trước 'return iterator.get_next()'. – guillaumekln