2015-02-24 17 views
15

Tôi không thể hiểu đầu ra của argmaxargmin khi sử dụng với thông số trục. Ví dụ:numpy: logic của hàm argmin() và argmax() là gì?

>>> a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]]) 
>>> a 
array([[ 1, 2, 4, 7], 
     [ 9, 88, 6, 45], 
     [ 9, 76, 3, 4]]) 
>>> a.shape 
(3, 4) 
>>> a.size 
12 
>>> np.argmax(a) 
5 
>>> np.argmax(a,axis=0) 
array([1, 1, 1, 1]) 
>>> np.argmax(a,axis=1) 
array([3, 1, 1]) 
>>> np.argmin(a) 
0 
>>> np.argmin(a,axis=0) 
array([0, 0, 2, 2]) 
>>> np.argmin(a,axis=1) 
array([0, 2, 2]) 

Như bạn có thể thấy, giá trị tối đa là điểm (1,1) và điểm tối thiểu là điểm (0,0). Vì vậy, trong logic của tôi khi tôi chạy:

  • np.argmin(a,axis=0) tôi mong đợi array([0,0,0,0])
  • np.argmin(a,axis=1) tôi mong đợi array([0,0,0])
  • np.argmax(a,axis=0) tôi mong đợi array([1,1,1,1])
  • np.argmax(a,axis=1) tôi mong đợi array([1,1,1])

Điều gì là sai với tôi hiểu biết về mọi thứ?

Trả lời

22

Bằng cách thêm đối số axis, NumPy sẽ xem xét các hàng và cột riêng lẻ. Khi nó không được đưa ra, mảng a được san phẳng thành một mảng 1D duy nhất.

axis=0 có nghĩa là hoạt động được thực hiện xuống các cột của mảng 2D a lần lượt.

Ví dụ np.argmin(a, axis=0) trả về chỉ mục giá trị nhỏ nhất trong mỗi cột: giá trị tối thiểu của cột đầu tiên là chỉ số 0, giá trị tối thiểu của cột thứ ba là chỉ số 2 và tiếp tục.

axis=1 có nghĩa là thao tác được thực hiện trên các hàng của a.

Vì vậy, np.argmin(a, axis=1) trả về [0, 2, 2]a có ba hàng. Chỉ số giá trị nhỏ nhất trong hàng đầu tiên là 0, chỉ số giá trị tối thiểu của hàng thứ hai là 2, v.v.

+0

còn trục = -1 thì sao? –

+0

đã nhận nó, nó phải được kích thước cuối cùng, ở đây cho 2d nó là cột tôi đoán –

3

Chức năng np.argmax theo mặc định hoạt động along the flattened array, trừ khi bạn chỉ định trục. Để xem những gì đang xảy ra, bạn có thể sử dụng flatten một cách rõ ràng:

np.argmax(a) 
>>> 5 

a.flatten() 
>>>> array([ 1, 2, 4, 7, 9, 88, 6, 45, 9, 76, 3, 4]) 
      0 1 2 3 4 5 

Tôi đã đánh số các chỉ số dưới mảng trên để làm cho nó rõ ràng hơn. Lưu ý rằng các chỉ số được đánh số từ số không trong numpy.

Trong trường hợp bạn chỉ định trục, nó cũng đang làm việc như mong đợi:

np.argmax(a,axis=0) 
>>> array([1, 1, 1, 1]) 

này sẽ cho bạn biết rằng giá trị lớn nhất là ở hàng 1 (giá trị thứ 2), cho mỗi cột dọc axis=0 (giảm). Bạn có thể thấy điều này rõ ràng hơn nếu bạn thay đổi dữ liệu của bạn một chút:

a=np.array([[100,2,4,7],[9,88,6,45],[9,76,3,100]]) 
a 
>>> array([[100, 2, 4, 7], 
      [ 9, 88, 6, 45], 
      [ 9, 76, 3, 100]]) 

np.argmax(a, axis=0) 
>>> array([0, 1, 1, 2]) 

Như bạn có thể thấy nó bây giờ xác định giá trị lớn nhất trong hàng 0 cho cột 1, hàng 1 cho cột 2 và 3 và hàng 3 cho cột 4.

Có hướng dẫn hữu ích để numpy lập chỉ mục trong documentation.

0

Trục trong đối số hàm argmax, tham chiếu đến trục dọc theo đó mảng sẽ được cắt.

Trong một từ khác, np.argmin(a,axis=0) có hiệu quả giống như np.apply_along_axis(np.argmin, 0, a), đó là tìm ra vị trí tối thiểu cho các vectơ cắt lát này dọc theo trục = 0.

Vì vậy trong ví dụ của bạn, np.argmin(a, axis=0)[0, 0, 2, 2] mà tương ứng với giá trị của [1, 2, 3, 4] trên các cột tương ứng

+0

Cảm ơn bạn sir. Tôi upvote nhưng tôi chấp nhận một câu trả lời khác mà là quá rõ ràng cho con đường của tôi về sự hiểu biết. –

4

Như một mặt lưu ý: nếu bạn muốn tìm tọa độ của giá trị lớn nhất của bạn trong mảng đầy đủ, bạn có thể sử dụng

a=np.array([[1,2,4,7],[9,88,6,45],[9,76,3,4]]) 
>>> a 
[[ 1 2 4 7] 
[ 9 88 6 45] 
[ 9 76 3 4]] 

c=(np.argmax(a)/len(a[0]),np.argmax(a)%len(a[0])) 
>>> c 
(1, 1) 
+3

Hoặc đơn giản: 'np.unravel_index (np.argmax (a), a.shape)' –

0
""" ....READ THE COMMENTS FOR CLARIFICATION.....""" 

import numpy as np 
a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]]) 

"""np.argmax(a) will give index of max value in flatted array of given matrix """ 
>>np.arg(max) 
5 

"""np.argmax(a,axis=0) will return list of indexes of max value coloumnwise""" 
>>print(np.argmax(a,axis=0)) 
[1,1,1,1] 

"""np.argmax(a,axis=1) will return list of indexes of max value rowwise""" 
>>print(np.argmax(a,axis=1)) 
[3,1,1] 

"""np.argmin(a) will give index of min value in flatted array of given matrix """ 
>>np.arg(min) 
0 

"""np.argmin(a,axis=0) will return list of indexes of min value coloumnwise""" 
>>print(np.argmin(a,axis=0)) 
[0,0,2,2] 

"""np.argmin(a,axis=0) will return list of indexes of min value rowwise""" 
>>print(np.argmin(a,axis=1)) 
[0,2,2] 
Các vấn đề liên quan