Chinaunix首页 | 论坛 | 博客
  • 博客访问: 962404
  • 博文数量: 33
  • 博客积分: 803
  • 博客等级: 军士长
  • 技术积分: 1755
  • 用 户 组: 普通用户
  • 注册时间: 2010-03-05 18:58
个人简介

《Python科学计算》的作者

文章分类

全部博文(33)

文章存档

2016年(1)

2014年(2)

2013年(3)

2012年(27)

分类: Python/Ruby

2012-11-17 15:23:06

在上有人问道是否能对如下的循环进行提速。

import numpy as np

N = 1000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
    dm[i] = np.linalg.det(data[i])

即调用N次det()计算N个相同大小的矩阵的行列式。NumPy给人的印象是它包装了大量高速运算的Fortran库,因此除非使用编译语言,很难再对其进行加速。然而实际上NumPy除了对Fortran库进行包装之外,它还需要做许多额外的工作,我们可以想办法提高这些额外工作的效率。

NumPy中的det()代码

首先下面是numpy.linalg.det()相关的代码:

def slogdet(a):
    a = asarray(a)
    _assertRank2(a)
    _assertSquareness(a)
    t, result_t = _commonType(a)
    a = _fastCopyAndTranspose(t, a)
    a = _to_native_byte_order(a)
    n = a.shape[0]
    if isComplexType(t):
        lapack_routine = lapack_lite.zgetrf
    else:
        lapack_routine = lapack_lite.dgetrf
    pivots = zeros((n,), fortran_int)
    results = lapack_routine(n, n, a, n, pivots, 0)
    info = results['info']
    if (info < 0):
        raise TypeError, "Illegal input to Fortran routine"
    elif (info > 0):
        return (t(0.0), _realType(t)(-Inf))
    sign = 1. - 2. * (add.reduce(pivots != arange(1, n + 1)) % 2)
    d = diagonal(a)
    absd = absolute(d)
    sign *= multiply.reduce(d / absd)
    log(absd, absd)
    logdet = add.reduce(absd, axis=-1)
    return sign, logdet

def det(a):
    sign, logdet = slogdet(a)
    return sign * exp(logdet)

由这段代码可知,对于实数矩阵会调用Fortran库中的lapack_lite.dgetrf()。在这句关键的Fortran函数调用之前,NumPy对输入数组进行了许多检测和转换工作。而在调用之后,还通过一些其它函数对输出进行运算。

我们可以重新编写这段代码,将循环集中在调用lapack_lite.dgetrf()之上。尽量删除掉对输入数据的检测和转换工作,而对于输出结果我们希望能使用NumPy的广播功能代替循环计算。

不过在着手做这些事情之前,让我们先进行一次Profiling,看看能否真的提高计算速度。

Profiling

在IPython的notebook中,我们可以使用%%prun命令对代码进行Profing:

%%prun
import numpy as np

N = 5000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
    dm[i] = np.linalg.det(data[i])

对上面的代码进行Profiling的结果如下:

165004 function calls in 1.581 seconds

Ordered by: internal time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 5000    0.551    0.000    1.432    0.000 linalg.py:1560(slogdet)
15000    0.130    0.000    0.130    0.000 {method 'reduce' of 'numpy.ufunc' objects}
 5000    0.078    0.000    1.510    0.000 linalg.py:1642(det)
 5000    0.068    0.000    0.068    0.000 {numpy.linalg.lapack_lite.dgetrf}
 5000    0.068    0.000    0.068    0.000 {numpy.core.multiarray._fastCopyAndTranspose}
 5000    0.060    0.000    0.130    0.000 linalg.py:99(_commonType)
 5000    0.052    0.000    0.052    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
10000    0.051    0.000    0.097    0.000 numeric.py:167(asarray)
 5000    0.047    0.000    0.123    0.000 linalg.py:139(_fastCopyAndTranspose)
10000    0.046    0.000    0.046    0.000 {numpy.core.multiarray.array}
    1    0.040    0.040    1.581    1.581 :2()
 5000    0.039    0.000    0.057    0.000 linalg.py:127(_to_native_byte_order)
 5000    0.038    0.000    0.142    0.000 fromnumeric.py:902(diagonal)
10000    0.038    0.000    0.058    0.000 linalg.py:71(isComplexType)
 5000    0.034    0.000    0.056    0.000 linalg.py:157(_assertSquareness)
 5000    0.034    0.000    0.034    0.000 {numpy.core.multiarray.arange}
15000    0.034    0.000    0.034    0.000 {issubclass}
    1    0.031    0.031    0.031    0.031 {method 'rand' of 'mtrand.RandomState' objects}
 5000    0.031    0.000    0.040    0.000 linalg.py:151(_assertRank2)
 5001    0.025    0.000    0.025    0.000 {numpy.core.multiarray.zeros}
15000    0.025    0.000    0.025    0.000 {len}
 5000    0.020    0.000    0.029    0.000 linalg.py:84(_realType)
 5000    0.012    0.000    0.012    0.000 {max}
 5000    0.010    0.000    0.010    0.000 {method 'append' of 'list' objects}
 5000    0.010    0.000    0.010    0.000 {min}
 5000    0.009    0.000    0.009    0.000 {method 'get' of 'dict' objects}
    1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

可以看到Fortran库函数lapack_lite.dgetrf()并不是最耗时的,如果我们将循环集中在它上面,尽量减少其它函数的调用次数,能提高将近10倍的运算速度。

编写高速运算的代码

我们将NumPy中的det()slogdet()的运算放在一起,并对程序做了如下改动:

  • 删除对输入矩阵的检测,直接将其当作是满足条件的双精度浮点数组。
  • 无须对输入矩阵进行转置,因为转置矩阵的行列式与原始矩阵的行列式相同。
  • 不对输入矩阵进行复制,lapack_lite.dgetrf()的计算结果将覆盖输入矩阵。如果需要保留原始矩阵,可以在调用dets_fast()之前复制数据。
  • 创建一个形状为MxN的整数数组pivots,用来保存每次调用lapack_lite.dgetrf()时所得到的pivot数组。
  • 将取对象线元素的diagonal(a)转换为使用高级下标存取:idx = np.arange(n); d = a[:, idx, idx]
  • 采用NumPy的矢量运算,省略Python级别的循环。
import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite

def dets_fast(a):
    m = a.shape[0]
    n = a.shape[1]
    lapack_routine = lapack_lite.dgetrf
    pivots = np.zeros((m, n), intc)
    flags = np.arange(1, n + 1).reshape(1, -1)
    for i in xrange(m):
        tmp = a[i]
        lapack_routine(n, n, tmp, n, pivots[i], 0)
    sign = 1. - 2. * (np.add.reduce(pivots != flags, axis=1) % 2)
    idx = np.arange(n)
    d = a[:, idx, idx]
    absd = np.absolute(d)
    sign *= np.multiply.reduce(d / absd, axis=1)
    np.log(absd, absd)
    logdet = np.add.reduce(absd, axis=-1)
    return sign * np.exp(logdet)

下面是直接采用循环调用linalg.det()的代码:

import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite

def dets(a):
    length = a.shape[0]
    dm = np.zeros(length)
    for i in xrange(length):
        dm[i] = np.linalg.det(M[i])
    return dm

首先检测计算结果是否正确,由于dets_fast()中没有对原始矩阵进行转置,因此运算结果和dets()有微小差别,因此使用numpy.allclose()比较二者的结果:

N = 1000
M = np.random.rand(N*10*10).reshape(N, 10, 10)
print np.allclose(dets(M), dets_fast(M.copy()))
True

下面比较运算速度,可以看出运算速度有10多倍的提升。由于dets_fast()会改变输入数组,因此我们将M复制一份再传递给它。

%timeit dets(M)
%timeit dets_fast(M.copy())
1 loops, best of 3: 173 ms per loop
100 loops, best of 3: 14.1 ms per loop
阅读(8169) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~