Bạn nói đúng, tf.nn.batch_normalization
chỉ cung cấp chức năng cơ bản để thực hiện bình thường hóa hàng loạt. Bạn phải thêm logic bổ sung để theo dõi di chuyển phương tiện và phương sai trong quá trình đào tạo, và sử dụng các phương tiện được đào tạo và phương sai trong suy luận. Bạn có thể nhìn vào example này cho một thực hiện rất chung chung, nhưng một phiên bản nhanh chóng mà không sử dụng gamma
là ở đây:
beta = tf.Variable(tf.zeros(shape), name='beta')
moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
trainable=False)
moving_variance = tf.Variable(tf.ones(shape),
name='moving_variance',
trainable=False)
control_inputs = []
if is_training:
mean, variance = tf.nn.moments(image, [0, 1, 2])
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, self.decay)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, self.decay)
control_inputs = [update_moving_mean, update_moving_variance]
else:
mean = moving_mean
variance = moving_variance
with tf.control_dependencies(control_inputs):
return tf.nn.batch_normalization(
image, mean=mean, variance=variance, offset=beta,
scale=None, variance_epsilon=0.001)
Xem xét sử dụng các lớp được xác định trước từ apis cấp cao như 'tf.contrib .layers'. – danijar