2016-04-10 46 views
14

Tôi mới dùng Python, và tôi đang học TensorFlow. Trong hướng dẫn sử dụng bộ dữ liệu notMNIST, chúng cung cấp mã ví dụ để chuyển ma trận nhãn thành mảng được mã hóa một-n.Hiểu == áp dụng cho một mảng NumPy

Mục đích là để có một mảng gồm các số nguyên nhãn 0 ... 9, và trả về một ma trận trong đó mỗi số nguyên đã được chuyển thành một-of-n mảng mã hóa như thế này:

0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] 
1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] 
2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] 
... 

Mã họ đưa ra để thực hiện điều này là:

# Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...] 
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32) 

Tuy nhiên, tôi không hiểu mã này thực sự như thế nào. Dường như nó chỉ tạo ra một mảng các số nguyên trong khoảng 0 đến 9, và sau đó so sánh với các ma trận nhãn, và chuyển đổi kết quả thành một phao. Làm thế nào để kết quả của toán tử == trong một ma trận mã hóa một trong các mã hoá ?

Trả lời

22

Có một vài điều đang diễn ra ở đây: hoạt động vector của numpy, thêm trục đơn và phát sóng.

Trước tiên, bạn sẽ có thể thấy cách thực hiện phép thuật ==.

Giả sử chúng ta bắt đầu với một mảng nhãn đơn giản. == hoạt động theo kiểu vectơ hóa, có nghĩa là chúng ta có thể so sánh toàn bộ mảng với vô hướng và nhận được một mảng bao gồm các giá trị của từng so sánh từng phần tử. Ví dụ:

Đầu tiên chúng tôi nhận được một mảng boolean, sau đó chúng tôi ép buộc để nổi: False == 0 in Python và True == 1. Vì vậy, chúng tôi kết thúc với một mảng là 0 trong đó labels không bằng 0 và 1 ở đâu.

Nhưng không có gì đặc biệt về so với 0 là, chúng ta có thể so sánh với 1 hoặc 2 hoặc 3 thay thế cho kết quả tương tự:

>>> (labels == 2).astype(np.float32) 
array([ 0., 1., 0., 0., 1.], dtype=float32) 

Trong thực tế, chúng ta có thể lặp qua tất cả các nhãn có thể và tạo ra mảng này. Chúng tôi có thể sử dụng một listcomp:

>>> np.array([(labels == i).astype(np.float32) for i in np.arange(3)]) 
array([[ 0., 0., 1., 1., 0.], 
     [ 1., 0., 0., 0., 0.], 
     [ 0., 1., 0., 0., 1.]], dtype=float32) 

nhưng điều này thực sự không tận dụng lợi thế. Những gì chúng ta muốn làm là có mỗi nhãn có thể so sánh với mỗi phần tử, IOW để so sánh

>>> np.arange(3) 
array([0, 1, 2]) 

với

>>> labels 
array([1, 2, 0, 0, 2]) 

Và đây là nơi mà sự kỳ diệu của truyền NumPy đến. Ngay bây giờ, labels là một Đối tượng hình dạng 1 chiều (5,). Nếu chúng ta tạo thành đối tượng hình dạng 2 chiều (5,1) thì thao tác sẽ "phát sóng" trên trục cuối cùng và chúng ta sẽ nhận được kết quả hình dạng (5,3) với kết quả so sánh từng mục nhập trong phạm vi với từng phần tử của nhãn.

Đầu tiên chúng ta có thể thêm một "thêm" trục labels sử dụng None (hoặc np.newaxis), thay đổi hình dạng của nó:

>>> labels[:,None] 
array([[1], 
     [2], 
     [0], 
     [0], 
     [2]]) 
>>> labels[:,None].shape 
(5, 1) 

Và sau đó chúng ta có thể làm cho việc so sánh (đây là transpose của sự sắp xếp chúng tôi nhìn vào trước đó, nhưng điều đó không quan trọng).

>>> np.arange(3) == labels[:,None] 
array([[False, True, False], 
     [False, False, True], 
     [ True, False, False], 
     [ True, False, False], 
     [False, False, True]], dtype=bool) 
>>> (np.arange(3) == labels[:,None]).astype(np.float32) 
array([[ 0., 1., 0.], 
     [ 0., 0., 1.], 
     [ 1., 0., 0.], 
     [ 1., 0., 0.], 
     [ 0., 0., 1.]], dtype=float32) 

Phát sóng trong phần mềm là rất mạnh mẽ và đáng để đọc.

+0

Một lời giải thích rất chi tiết & thoải mái. Hầu hết những người đi qua khóa học Sâu sắc của Udacity, chắc hẳn đã vấp phải câu trả lời này. – AgentX

0

Tóm lại, == được áp dụng cho mảng có nhiều mảng có nghĩa là áp dụng thành phần khôn ngoan == cho mảng. Kết quả là một mảng các boolean. Dưới đây là một ví dụ:

>>> b = np.array([1,0,0,1,1,0]) 
>>> b == 1 
array([ True, False, False, True, True, False], dtype=bool) 

Để đếm biết có bao nhiêu 1s có trong b, bạn không cần phải cast mảng nổi, tức là .astype(np.float32) có thể được lưu lại, bởi vì trong python boolean là một lớp con của int và trong Python 3 bạn có True == 1 False == 0. Vì vậy, đây là cách bạn đếm có bao nhiêu người là trong b:

>>> np.sum((b == 1)) 
3 

Hoặc:

>>> np.count_nonzero(b == 1) 
3 
Các vấn đề liên quan