2017-08-22 18 views
5

Tôi cần phải thực hiện tích hợp số trong 6D trong python. Bởi vì hàm scipy.integrate.nquad là chậm Tôi hiện đang cố gắng tăng tốc mọi thứ bằng cách định nghĩa tích phân là scipy.LowLevelCallable với Numba.Làm thế nào bạn có thể thực hiện một C có thể gọi từ Numba để tích hợp hiệu quả với nquad?

tôi đã có thể làm điều này trong 1D với scipy.integrate.quad bằng cách tái tạo ví dụ đưa ra here:

import numpy as np 
from numba import cfunc 
from scipy import integrate 

def integrand(t): 
    return np.exp(-t)/t**2 

nb_integrand = cfunc("float64(float64)")(integrand) 

# regular integration 
%timeit integrate.quad(integrand, 1, np.inf) 

10000 vòng, tốt nhất là 3: 128 ms mỗi vòng lặp

# integration with compiled function 
%timeit integrate.quad(nb_integrand.ctypes, 1, np.inf) 

100000 vòng, tốt nhất là 3: 7,08 µs trên mỗi vòng

Khi tôi muốn làm điều này ngay bây giờ với nquad, tài liệu nquad cho biết:

If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:

double func(int n, double *xx) 
double func(int n, double *xx, void *user_data) 

where n is the number of extra parameters and args is an array of doubles of the additional parameters, the xx array contains the coordinates. The user_data is the data contained in the scipy.LowLevelCallable.

Nhưng đoạn mã sau mang lại cho tôi một lỗi:

from numba import cfunc 
import ctypes 

def func(n_arg,x): 
    xe = x[0] 
    xh = x[1] 
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh) 

nb_func = cfunc("float64(int64,CPointer(float64))")(func) 

integrate.nquad(nb_func.ctypes, [[0,1],[0,1]], full_output=True) 

lỗi: quad: số đầu tiên là một con trỏ hàm ctypes có chữ ký không chính xác

Có thể biên dịch một hàm với numba rằng có thể được sử dụng với nquad trực tiếp trong mã và không định nghĩa hàm trong một tệp bên ngoài?

Cảm ơn bạn rất nhiều trước!

Trả lời

3

Bao bì các chức năng trong một scipy.LowLevelCallable làm nquad hạnh phúc:

si.nquad(sp.LowLevelCallable(nb_func.ctypes), [[0,1],[0,1]], full_output=True) 
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323}) 
Các vấn đề liên quan