2017-05-03 23 views
6

Đọc https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py cho chức năng average_gradients nhận xét sau đây được cung cấp: Note that this function provides a synchronization point across all towers. Có chức năng average_gradients cuộc gọi chặn và có nghĩa là synchronization point?Điểm đồng bộ hóa cation của Tensorflow

Tôi giả định đây là cuộc gọi chặn để tính trung bình của các gradient mà mỗi gradient phải được tính riêng lẻ? Nhưng đâu là mã khóa chờ đợi cho tất cả các tính toán gradient cá nhân?

Trả lời

6

Bản thân số average_gradients không phải là chức năng chặn. Nó có thể là một chức năng khác với các hoạt động tensorflow và điều này vẫn sẽ là một điểm đồng bộ hóa. Điều gì làm cho nó chặn là một thực tế rằng nó sử dụng các đối số tower_grads mà phụ thuộc vào tất cả các đồ thị được tạo ra trong vòng lặp trước đó.

Về cơ bản những gì xảy ra ở đây là việc tạo biểu đồ đào tạo. Đầu tiên, trong vòng lặp cho for i in xrange(FLAGS.num_gpus) một số đồ thị "chuỗi" được tạo. Mỗi trông như thế này:

mất tính toán -> tính toán gradient -> append để tower_grads

Mỗi của những đồ thị "đề" được gán cho một gpu khác nhau thông qua with tf.device('/gpu:%d' % i) và mỗi người có thể chạy độc lập với nhau (và sau đó sẽ chạy song song). Bây giờ, lần tiếp theo tower_grads được sử dụng mà không có thông số kỹ thuật của thiết bị, nó sẽ tạo ra một tiếp tục biểu đồ trên thiết bị chính, ràng buộc tất cả các "chuỗi" đồ thị riêng biệt đó thành một chuỗi duy nhất. Tensorflow sẽ đảm bảo rằng mọi đồ thị "thread" là một phần của việc tạo ra tower_grads được hoàn tất trước khi chạy biểu đồ bên trong hàm average_gradients. Vì vậy sau này khi sess.run([train_op, loss]) được gọi, đây sẽ là điểm đồng bộ của biểu đồ.

Các vấn đề liên quan