Tôi đang cố gắng bọc chức năng LAPACK dgtsv
(bộ giải cho các hệ phương trình tridiagonal) sử dụng Cython.Gói chức năng LAPACKE bằng cách sử dụng Cython
Tôi đã xem qua this previous answer, nhưng kể từ dgtsv
không phải là một trong các hàm LAPACK được gói trong scipy.linalg
Tôi không nghĩ rằng tôi có thể sử dụng phương pháp tiếp cận cụ thể này. Thay vào đó tôi đã cố gắng theo dõi this example.
Dưới đây là nội dung của tập tin lapacke.pxd
tôi:
ctypedef int lapack_int
cdef extern from "lapacke.h" nogil:
int LAPACK_ROW_MAJOR
int LAPACK_COL_MAJOR
lapack_int LAPACKE_dgtsv(int matrix_order,
lapack_int n,
lapack_int nrhs,
double * dl,
double * d,
double * du,
double * b,
lapack_int ldb)
... đây là wrapper Cython mỏng của tôi trong _solvers.pyx
:
#!python
cimport cython
from lapacke cimport *
cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
double[:, ::1] B):
cdef:
lapack_int n = D.shape[0]
lapack_int nrhs = B.shape[1]
lapack_int ldb = B.shape[0]
double * dl = &DL[0]
double * d = &D[0]
double * du = &DU[0]
double * b = &B[0, 0]
lapack_int info
info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)
return info
... và đây là một Python wrapper và kịch bản thử nghiệm:
import numpy as np
from scipy import sparse
from cymodules import _solvers
def trisolve_lapacke(dl, d, du, b, inplace=False):
if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
or b.shape != d.shape):
raise ValueError('Invalid diagonal shapes')
if b.ndim == 1:
# b is (LDB, NRHS)
b = b[:, None]
# be sure to force a copy of d and b if we're not solving in place
if not inplace:
d = d.copy()
b = b.copy()
# this may also force copies if arrays are improperly typed/noncontiguous
dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
for v in (dl, d, du, b))
# b will now be modified in place to contain the solution
info = _solvers.TDMA_lapacke(dl, d, du, b)
print info
return b.ravel()
def test_trisolve(n=20000):
dl = np.random.randn(n - 1)
d = np.random.randn(n)
du = np.random.randn(n - 1)
M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
x = np.random.randn(n)
b = M.dot(x)
x_hat = trisolve_lapacke(dl, d, du, b)
print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)
Thật không may, test_trisolve
chỉ cần se gfaults trên cuộc gọi đến _solvers.TDMA_lapacke
. Tôi chắc chắn rằng setup.py
của tôi là chính xác - ldd _solvers.so
cho biết rằng _solvers.so
đang được liên kết với các thư viện được chia sẻ chính xác khi chạy.
Tôi không thực sự chắc chắn cách tiếp tục từ đây - bất kỳ ý tưởng nào?
Một cập nhật ngắn gọn:
cho các giá trị nhỏ hơn của n
tôi có xu hướng không để có được segfaults ngay lập tức, nhưng tôi có được kết quả vô nghĩa (|| x - x_hat || nên được rất gần với 0):
In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 6.23202576396
In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| = 3.88623414288
In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 2.60190676562
In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 3.86631743386
In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault
thường LAPACKE_dgtsv
lợi nhuận với mã 0
(mà nên chỉ thành công), nhưng đôi khi tôi nhận được -7
, có nghĩa là đối số 7 (b
) có giá trị bất hợp pháp. Những gì đang xảy ra là chỉ có giá trị đầu tiên của b
thực sự đang được sửa đổi tại chỗ. Nếu tôi tiếp tục gọi số test_trisolve
, tôi cuối cùng sẽ đạt được một khoảng cách ngay cả khi n
là nhỏ.