《Python科学计算》的作者
分类: Python/Ruby
2012-11-17 15:23:06
在上有人问道是否能对如下的循环进行提速。
import numpy as np
N = 1000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
dm[i] = np.linalg.det(data[i])
即调用N
次det()计算N
个相同大小的矩阵的行列式。NumPy给人的印象是它包装了大量高速运算的Fortran库,因此除非使用编译语言,很难再对其进行加速。然而实际上NumPy除了对Fortran库进行包装之外,它还需要做许多额外的工作,我们可以想办法提高这些额外工作的效率。
首先下面是numpy.linalg.det()
相关的代码:
def slogdet(a):
a = asarray(a)
_assertRank2(a)
_assertSquareness(a)
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
a = _to_native_byte_order(a)
n = a.shape[0]
if isComplexType(t):
lapack_routine = lapack_lite.zgetrf
else:
lapack_routine = lapack_lite.dgetrf
pivots = zeros((n,), fortran_int)
results = lapack_routine(n, n, a, n, pivots, 0)
info = results['info']
if (info < 0):
raise TypeError, "Illegal input to Fortran routine"
elif (info > 0):
return (t(0.0), _realType(t)(-Inf))
sign = 1. - 2. * (add.reduce(pivots != arange(1, n + 1)) % 2)
d = diagonal(a)
absd = absolute(d)
sign *= multiply.reduce(d / absd)
log(absd, absd)
logdet = add.reduce(absd, axis=-1)
return sign, logdet
def det(a):
sign, logdet = slogdet(a)
return sign * exp(logdet)
由这段代码可知,对于实数矩阵会调用Fortran库中的lapack_lite.dgetrf()
。在这句关键的Fortran函数调用之前,NumPy对输入数组进行了许多检测和转换工作。而在调用之后,还通过一些其它函数对输出进行运算。
我们可以重新编写这段代码,将循环集中在调用lapack_lite.dgetrf()
之上。尽量删除掉对输入数据的检测和转换工作,而对于输出结果我们希望能使用NumPy的广播功能代替循环计算。
不过在着手做这些事情之前,让我们先进行一次Profiling,看看能否真的提高计算速度。
在IPython的notebook中,我们可以使用%%prun命令对代码进行Profing:
%%prun
import numpy as np
N = 5000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
dm[i] = np.linalg.det(data[i])
对上面的代码进行Profiling的结果如下:
165004 function calls in 1.581 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
5000 0.551 0.000 1.432 0.000 linalg.py:1560(slogdet)
15000 0.130 0.000 0.130 0.000 {method 'reduce' of 'numpy.ufunc' objects}
5000 0.078 0.000 1.510 0.000 linalg.py:1642(det)
5000 0.068 0.000 0.068 0.000 {numpy.linalg.lapack_lite.dgetrf}
5000 0.068 0.000 0.068 0.000 {numpy.core.multiarray._fastCopyAndTranspose}
5000 0.060 0.000 0.130 0.000 linalg.py:99(_commonType)
5000 0.052 0.000 0.052 0.000 {method 'diagonal' of 'numpy.ndarray' objects}
10000 0.051 0.000 0.097 0.000 numeric.py:167(asarray)
5000 0.047 0.000 0.123 0.000 linalg.py:139(_fastCopyAndTranspose)
10000 0.046 0.000 0.046 0.000 {numpy.core.multiarray.array}
1 0.040 0.040 1.581 1.581 :2()
5000 0.039 0.000 0.057 0.000 linalg.py:127(_to_native_byte_order)
5000 0.038 0.000 0.142 0.000 fromnumeric.py:902(diagonal)
10000 0.038 0.000 0.058 0.000 linalg.py:71(isComplexType)
5000 0.034 0.000 0.056 0.000 linalg.py:157(_assertSquareness)
5000 0.034 0.000 0.034 0.000 {numpy.core.multiarray.arange}
15000 0.034 0.000 0.034 0.000 {issubclass}
1 0.031 0.031 0.031 0.031 {method 'rand' of 'mtrand.RandomState' objects}
5000 0.031 0.000 0.040 0.000 linalg.py:151(_assertRank2)
5001 0.025 0.000 0.025 0.000 {numpy.core.multiarray.zeros}
15000 0.025 0.000 0.025 0.000 {len}
5000 0.020 0.000 0.029 0.000 linalg.py:84(_realType)
5000 0.012 0.000 0.012 0.000 {max}
5000 0.010 0.000 0.010 0.000 {method 'append' of 'list' objects}
5000 0.010 0.000 0.010 0.000 {min}
5000 0.009 0.000 0.009 0.000 {method 'get' of 'dict' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
可以看到Fortran库函数lapack_lite.dgetrf()
并不是最耗时的,如果我们将循环集中在它上面,尽量减少其它函数的调用次数,能提高将近10倍的运算速度。
我们将NumPy中的det()
和slogdet()
的运算放在一起,并对程序做了如下改动:
lapack_lite.dgetrf()
的计算结果将覆盖输入矩阵。如果需要保留原始矩阵,可以在调用dets_fast()
之前复制数据。lapack_lite.dgetrf()
时所得到的pivot数组。diagonal(a)
转换为使用高级下标存取:idx = np.arange(n); d = a[:, idx, idx]
import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite
def dets_fast(a):
m = a.shape[0]
n = a.shape[1]
lapack_routine = lapack_lite.dgetrf
pivots = np.zeros((m, n), intc)
flags = np.arange(1, n + 1).reshape(1, -1)
for i in xrange(m):
tmp = a[i]
lapack_routine(n, n, tmp, n, pivots[i], 0)
sign = 1. - 2. * (np.add.reduce(pivots != flags, axis=1) % 2)
idx = np.arange(n)
d = a[:, idx, idx]
absd = np.absolute(d)
sign *= np.multiply.reduce(d / absd, axis=1)
np.log(absd, absd)
logdet = np.add.reduce(absd, axis=-1)
return sign * np.exp(logdet)
下面是直接采用循环调用linalg.det()
的代码:
import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite
def dets(a):
length = a.shape[0]
dm = np.zeros(length)
for i in xrange(length):
dm[i] = np.linalg.det(M[i])
return dm
首先检测计算结果是否正确,由于dets_fast()
中没有对原始矩阵进行转置,因此运算结果和dets()
有微小差别,因此使用numpy.allclose()
比较二者的结果:
N = 1000
M = np.random.rand(N*10*10).reshape(N, 10, 10)
print np.allclose(dets(M), dets_fast(M.copy()))
下面比较运算速度,可以看出运算速度有10多倍的提升。由于dets_fast()
会改变输入数组,因此我们将M
复制一份再传递给它。
%timeit dets(M)
%timeit dets_fast(M.copy())