Bạn cần sửa đổi np.linalg.det
để tăng tốc độ. Ý tưởng là det()
là một hàm Python, nó thực hiện rất nhiều kiểm tra đầu tiên, và gọi thường trình fortran, và có một số mảng tính toán để có được kết quả.
Đây là mã từ NumPy:
def slogdet(a):
a = asarray(a)
_assertRank2(a)
_assertSquareness(a)
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
a = _to_native_byte_order(a)
n = a.shape[0]
if isComplexType(t):
lapack_routine = lapack_lite.zgetrf
else:
lapack_routine = lapack_lite.dgetrf
pivots = zeros((n,), fortran_int)
results = lapack_routine(n, n, a, n, pivots, 0)
info = results['info']
if (info < 0):
raise TypeError, "Illegal input to Fortran routine"
elif (info > 0):
return (t(0.0), _realType(t)(-Inf))
sign = 1. - 2. * (add.reduce(pivots != arange(1, n + 1)) % 2)
d = diagonal(a)
absd = absolute(d)
sign *= multiply.reduce(d/absd)
log(absd, absd)
logdet = add.reduce(absd, axis=-1)
return sign, logdet
def det(a):
sign, logdet = slogdet(a)
return sign * exp(logdet)
Để tăng tốc chức năng này, bạn có thể bỏ qua việc kiểm tra (nó trở thành trách nhiệm của bạn để giữ cho các đầu vào bên phải), và thu thập kết quả fortran trong một mảng, và thực hiện các phép tính cuối cùng cho tất cả các mảng nhỏ mà không có vòng lặp.
Dưới đây là kết quả của tôi:
import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite
N = 1000
M = np.random.rand(N*10*10).reshape(N, 10, 10)
def dets(a):
length = a.shape[0]
dm = np.zeros(length)
for i in xrange(length):
dm[i] = np.linalg.det(M[i])
return dm
def dets_fast(a):
m = a.shape[0]
n = a.shape[1]
lapack_routine = lapack_lite.dgetrf
pivots = np.zeros((m, n), intc)
flags = np.arange(1, n + 1).reshape(1, -1)
for i in xrange(m):
tmp = a[i]
lapack_routine(n, n, tmp, n, pivots[i], 0)
sign = 1. - 2. * (np.add.reduce(pivots != flags, axis=1) % 2)
idx = np.arange(n)
d = a[:, idx, idx]
absd = np.absolute(d)
sign *= np.multiply.reduce(d/absd, axis=1)
np.log(absd, absd)
logdet = np.add.reduce(absd, axis=-1)
return sign * np.exp(logdet)
print np.allclose(dets(M), dets_fast(M.copy()))
và tốc độ là:
timeit dets(M)
10 loops, best of 3: 159 ms per loop
timeit dets_fast(M)
100 loops, best of 3: 10.7 ms per loop
Vì vậy, bằng cách làm này, bạn có thể tăng tốc bằng 15 lần. Đó là một kết quả tốt mà không có bất kỳ mã được biên dịch nào.
lưu ý: Tôi bỏ qua kiểm tra lỗi cho thường trình fortran.
Cảm ơn bạn rất nhiều vì mã ví dụ của bạn và rằng bạn thậm chí đã làm thời gian. Nó hoạt động rất tốt cho các ma trận bậc hai nhỏ (O (MxM)) và không trở nên tồi tệ hơn numpy.linalg.det được thực hiện cho N ~ M. – user1825991