2014-04-29 16 views
9

Tôi nhận thấy một hành vi không nhất quán trong numpy.dot khi có nan s và số không.Lỗi Numpy.dot? Hành vi NaN không nhất quán

Có ai có thể hiểu được không? Đây có phải là một lỗi? Điều này có dành riêng cho hàm dot không?

Tôi đang sử dụng phiên bản v1.6.1, 64bit, chạy trên Linux (cũng được thử nghiệm trên v1.6.2). Tôi cũng đã thử nghiệm trên v1.8.0 trên windows 32bit (vì vậy tôi không thể biết được sự khác biệt là do phiên bản hoặc hệ điều hành hay vòm).

from numpy import * 
0*nan, nan*0 
=> (nan, nan) # makes sense 

#1 
a = array([[0]]) 
b = array([[nan]]) 
dot(a, b) 
=> array([[ nan]]) # OK 

#2 -- adding a value to b. the first value in the result is 
#  not expected to be affected. 
a = array([[0]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ 0., 0.]]) # EXPECTED : array([[ nan, 0.]]) 
# (also happens in 1.6.2 and 1.8.0) 
# Also, as @Bill noted, a*b works as expected, but not dot(a,b) 

#3 -- changing a from 0 to 1, the first value in the result is 
#  not expected to be affected. 
a = array([[1]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ nan, 1.]]) # OK 

#4 -- changing shape of a, changes nan in result 
a = array([[0],[0]]) 
b = array([[ nan, 1.]]) 
dot(a, b) 
=> array([[ 0., 0.], [ 0., 0.]]) # EXPECTED : array([[ nan, 0.], [ nan, 0.]]) 
# (works as expected in 1.6.2 and 1.8.0) 

Case # 4 dường như được làm việc một cách chính xác trong v1.6.2 và v1.8.0, nhưng không phân # 2 ...


EDIT: @seberg chỉ ra đây là một vấn đề blas , vì vậy đây là các thông tin về quá trình cài đặt blas tôi tìm thấy bằng cách chạy from numpy.distutils.system_info import get_info; get_info('blas_opt'):

1.6.1 linux 64bit 
/usr/lib/python2.7/dist-packages/numpy/distutils/system_info.py:1423: UserWarning: 
    Atlas (http://math-atlas.sourceforge.net/) libraries not found. 
    Directories to search for the libraries can be specified in the 
    numpy/distutils/site.cfg file (section [atlas]) or by setting 
    the ATLAS environment variable. 
    warnings.warn(AtlasNotFoundError.__doc__) 
{'libraries': ['blas'], 'library_dirs': ['/usr/lib'], 'language': 'f77', 'define_macros': [('NO_ATLAS_INFO', 1)]} 

1.8.0 windows 32bit (anaconda) 
c:\Anaconda\Lib\site-packages\numpy\distutils\system_info.py:1534: UserWarning: 
    Blas (http://www.netlib.org/blas/) sources not found. 
    Directories to search for the sources can be specified in the 
    numpy/distutils/site.cfg file (section [blas_src]) or by setting 
    the BLAS_SRC environment variable. 
warnings.warn(BlasSrcNotFoundError.__doc__) 
{} 

(cá nhân tôi không biết phải làm gì với nó)

+1

Thật thú vị đối với trường hợp 2, 'a * b' cho kết quả mong muốn nhưng không phải là' np.dot (a, b) '. – wflynny

+3

Kết quả của dấu chấm phụ thuộc vào thư viện blas bạn đang sử dụng. Ví dụ, tôi nhìn thấy cùng với openblas (nhưng không phải với tập bản đồ), do đó, hoặc điều này là không xác định, hoặc một lỗi trong thư viện blas. Phép nhân không liên quan thực sự ... – seberg

+2

Hmm, hãy thử 'từ numpy.distutils.system_info nhập get_info; get_info ('blas_opt') ' – seberg

Trả lời

3

Tôi nghĩ rằng, như seberg đề nghị, đây là một vấn đề với thư viện BLAS được sử dụng. Nếu bạn nhìn vào cách numpy.dot được thực hiện herehere bạn sẽ tìm thấy một cuộc gọi đến cblas_dgemm() cho trường hợp ma trận ma trận thời gian ma trận kép chính xác.

Chương trình C này, sao chép một số ví dụ của bạn, cho ra cùng một đầu ra khi sử dụng BLAS "đơn giản" và câu trả lời đúng khi sử dụng ATLAS.

#include <stdio.h> 
#include <math.h> 

#include "cblas.h" 

void onebyone(double a11, double b11, double expectc11) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=1; 
    int K=1; 
    double alpha=1.0; 
    double A[1]={a11}; 
    int lda=1; 
    double B[1]={b11}; 
    int ldb=1; 
    double beta=0.0; 
    double C[1]; 
    int ldc=1; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g]) -> [%.18g]; expected [%.18g]\n",a11,b11,C[0],expectc11); 
} 

void onebytwo(double a11, double b11, double b12, 
       double expectc11, double expectc12) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=2; 
    int K=1; 
    double alpha=1.0; 
    double A[]={a11}; 
    int lda=1; 
    double B[2]={b11,b12}; 
    int ldb=2; 
    double beta=0.0; 
    double C[2]; 
    int ldc=2; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g, %.18g]) -> [%.18g, %.18g]; expected [%.18g, %.18g]\n", 
     a11,b11,b12,C[0],C[1],expectc11,expectc12); 
} 

int 
main() 
{ 
    onebyone(0, 0, 0); 
    onebyone(2, 3, 6); 
    onebyone(NAN, 0, NAN); 
    onebyone(0, NAN, NAN); 
    onebytwo(0, 0,0, 0,0); 
    onebytwo(2, 3,5, 6,10); 
    onebytwo(0, NAN,0, NAN,0); 
    onebytwo(NAN, 0,0, NAN,NAN); 
    return 0; 
} 

Output với BLAS:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [0]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [0, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

Output với ATLAS:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [nan]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [nan, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

BLAS dường như đã dự kiến ​​hành vi khi các toán hạng đầu tiên có một NaN, và sai khi toán hạng đầu tiên là số không và số thứ hai có NaN.

Dù sao, tôi không nghĩ lỗi này nằm trong lớp Numpy; nó ở trong BLAS. Có vẻ như có thể giải quyết bằng cách sử dụng ATLAS thay thế.

Được tạo ở trên trên Ubuntu 14.04, sử dụng gcc do Ubuntu cung cấp, BLAS và ATLAS.

Các vấn đề liên quan