2017-07-10 17 views
5

Tôi muốn vượt qua một mảng multidimensionnal vào reluprime chức năngReLU Thủ với NumPy mảng

def reluprime(x): 
    if x > 0: 
     return 1 
    else: 
     return 0 

... nơi x là toàn bộ mảng. Nó trả

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Tôi đã có vấn đề này với relu chức năng bình thường, và thay vì sử dụng các chức năng python max() tôi đã sử dụng np.max() và nó làm việc. Nhưng với nguyên tố relu, nó không hoạt động theo cách nào cả. Tôi đã thử:

def reluprime(x): 
    if np.greater(x, 0): 
     return 1 
    else: 
     return 0 

... và nó vẫn trả về cùng lỗi. Làm thế nào tôi có thể sửa lỗi này? Cảm ơn bạn.

+2

vấn đề của bạn ở đây là Câu lệnh 'if' không hoạt động với mã vectơ. Toán tử '>' là tốt – Eric

Trả lời

5

Câu lệnh if không có ý nghĩa vì nó chỉ được đánh giá một lần, cho toàn bộ mảng. Nếu bạn muốn tương đương với một câu lệnh if cho mỗi phần tử của mảng, bạn nên làm một cái gì đó như:

def reluprime(x): 
    return np.where(x > 0, 1.0, 0.0) 
+3

Tốc độ này chậm hơn 3 lần so với câu trả lời của Miriam, nhưng thực sự là vector hóa tự nhiên nhất của mã gốc, và nói chung là cách tiếp cận đúng để dịch câu lệnh 'if'. – Eric

6

Kể từ relu lợi nhuận thủ 1 nếu một mục trong một vector lớn hơn 0 và 0 nếu ngược lại, bạn có thể chỉ cần làm:

def reluprime(x): 
    return (x>0).astype(x.dtype) 

trong đoạn mã trên, các mảng đầu vào x được giả định là một mảng numPy. Ví dụ: reluprime(np.array([-1,1,2])) trả về array([0, 1, 1]).

+0

'astype (x.dtype)' có khả năng hữu ích hơn ở đây – Eric

+0

@Eric, cảm ơn, điều đó có ý nghĩa! Tôi đã sửa đổi câu trả lời. –

+0

Vì trong ngữ cảnh học máy, bạn thường sẽ sử dụng 'float32' hoặc' float64' và sẽ muốn có kết quả chính xác giống nhau – Eric

4

"relu prime" hoặc gradient của hàm ReLU, được gọi là "hàm bước đột ngột".

NumPy 1,13 giới thiệu một ufunc cho việc này:

def reluprime(x): 
    return np.heaviside(x, 0) 
    # second value is value at x == 0 
    # note that ReLU is not differentiable at x==0, so there is no right value to 
    # pass here 

Thời gian kết quả trên cho thấy máy tính của tôi này để thực hiện khá kém, cho thấy nhiều việc phải làm ở đó:

In [1]: x = np.random.randn(100000) 

In [2]: %timeit np.heaviside(x, 0) #mine 
1.31 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 

In [3]: %timeit np.where(x > 0, 1.0, 0.0) # Jonas Adler's 
658 µs ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 

In [4]: %timeit (x>0).astype(x.dtype) # Miriam Farber's 
172 µs ± 34.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)