2014-04-21 15 views
5

Tôi đang cố gắng bọc chức năng LAPACK dgtsv (bộ giải cho các hệ phương trình tridiagonal) sử dụng Cython.Gói chức năng LAPACKE bằng cách sử dụng Cython

Tôi đã xem qua this previous answer, nhưng kể từ dgtsv không phải là một trong các hàm LAPACK được gói trong scipy.linalg Tôi không nghĩ rằng tôi có thể sử dụng phương pháp tiếp cận cụ thể này. Thay vào đó tôi đã cố gắng theo dõi this example.

Dưới đây là nội dung của tập tin lapacke.pxd tôi:

ctypedef int lapack_int 

cdef extern from "lapacke.h" nogil: 

    int LAPACK_ROW_MAJOR 
    int LAPACK_COL_MAJOR 

    lapack_int LAPACKE_dgtsv(int matrix_order, 
          lapack_int n, 
          lapack_int nrhs, 
          double * dl, 
          double * d, 
          double * du, 
          double * b, 
          lapack_int ldb) 

... đây là wrapper Cython mỏng của tôi trong _solvers.pyx:

#!python 

cimport cython 
from lapacke cimport * 

cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU, 
        double[:, ::1] B): 

    cdef: 
     lapack_int n = D.shape[0] 
     lapack_int nrhs = B.shape[1] 
     lapack_int ldb = B.shape[0] 
     double * dl = &DL[0] 
     double * d = &D[0] 
     double * du = &DU[0] 
     double * b = &B[0, 0] 
     lapack_int info 

    info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb) 

    return info 

... và đây là một Python wrapper và kịch bản thử nghiệm:

import numpy as np 
from scipy import sparse 
from cymodules import _solvers 


def trisolve_lapacke(dl, d, du, b, inplace=False): 

    if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1 
      or b.shape != d.shape): 
     raise ValueError('Invalid diagonal shapes') 

    if b.ndim == 1: 
     # b is (LDB, NRHS) 
     b = b[:, None] 

    # be sure to force a copy of d and b if we're not solving in place 
    if not inplace: 
     d = d.copy() 
     b = b.copy() 

    # this may also force copies if arrays are improperly typed/noncontiguous 
    dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64) 
        for v in (dl, d, du, b)) 

    # b will now be modified in place to contain the solution 
    info = _solvers.TDMA_lapacke(dl, d, du, b) 
    print info 

    return b.ravel() 


def test_trisolve(n=20000): 

    dl = np.random.randn(n - 1) 
    d = np.random.randn(n) 
    du = np.random.randn(n - 1) 

    M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc') 
    x = np.random.randn(n) 
    b = M.dot(x) 

    x_hat = trisolve_lapacke(dl, d, du, b) 

    print "||x - x_hat|| = ", np.linalg.norm(x - x_hat) 

Thật không may, test_trisolve chỉ cần se gfaults trên cuộc gọi đến _solvers.TDMA_lapacke. Tôi chắc chắn rằng setup.py của tôi là chính xác - ldd _solvers.so cho biết rằng _solvers.so đang được liên kết với các thư viện được chia sẻ chính xác khi chạy.

Tôi không thực sự chắc chắn cách tiếp tục từ đây - bất kỳ ý tưởng nào?


Một cập nhật ngắn gọn:

cho các giá trị nhỏ hơn của n tôi có xu hướng không để có được segfaults ngay lập tức, nhưng tôi có được kết quả vô nghĩa (|| x - x_hat || nên được rất gần với 0):

In [28]: test_trisolve2.test_trisolve(10) 
0 
||x - x_hat|| = 6.23202576396 

In [29]: test_trisolve2.test_trisolve(10) 
-7 
||x - x_hat|| = 3.88623414288 

In [30]: test_trisolve2.test_trisolve(10) 
0 
||x - x_hat|| = 2.60190676562 

In [31]: test_trisolve2.test_trisolve(10) 
0 
||x - x_hat|| = 3.86631743386 

In [32]: test_trisolve2.test_trisolve(10) 
Segmentation fault 

thường LAPACKE_dgtsv lợi nhuận với mã 0 (mà nên chỉ thành công), nhưng đôi khi tôi nhận được -7, có nghĩa là đối số 7 (b) có giá trị bất hợp pháp. Những gì đang xảy ra là chỉ có giá trị đầu tiên của b thực sự đang được sửa đổi tại chỗ. Nếu tôi tiếp tục gọi số test_trisolve, tôi cuối cùng sẽ đạt được một khoảng cách ngay cả khi n là nhỏ.

Trả lời

3

OK, cuối cùng tôi đã tìm ra nó - có vẻ như tôi đã hiểu lầm về những gì chính hàng và cột đề cập đến trong trường hợp này.

Vì các mảng tiếp giáp C theo thứ tự hàng lớn, tôi giả định rằng tôi phải chỉ định LAPACK_ROW_MAJOR làm đối số đầu tiên cho LAPACKE_dgtsv.

Trong thực tế, nếu tôi thay đổi

info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, ...) 

để

info = LAPACKE_dgtsv(LAPACK_COL_MAJOR, ...) 

sau đó chức năng của tôi làm việc:

test_trisolve2.test_trisolve() 
0 
||x - x_hat|| = 6.67064747632e-12 

này có vẻ khá phản trực giác đối với tôi - bất cứ ai có thể giải thích tại sao đây là trường hợp?

1

Mặc dù câu hỏi cũ dường như vẫn có liên quan. Các hành vi quan sát được là kết quả của một sự hiểu sai của tham số LDB:

  • mảng Fortran là col lớn và kích thước hàng đầu của mảng B tương ứng với N. Do đó LDB> = max (1, N).
  • Với hàng lớn LDB tương ứng với NRHS và do đó điều kiện LDB> = max (1, NRHS) phải được đáp ứng.

Comment # b là (LDB, NRHS) là không đúng vì b có kích thước (LDB, N) và LDB nên 1 trong trường hợp này.

Chuyển từ LAPACK_ROW_MAJOR sang LAPACK_COL_MAJOR khắc phục sự cố miễn là NRHS bằng 1. Bố cục bộ nhớ của một cột chính (N, 1) giống với hàng chính (1, N). Tuy nhiên, sẽ thất bại nếu NRHS lớn hơn 1.

Các vấn đề liên quan