• 博客访问： 952959
• 博文数量： 33
• 博客积分： 803
• 博客等级： 军士长
• 技术积分： 1755
• 用 户 组： 普通用户
• 注册时间： 2010-03-05 18:58

《Python科学计算》的作者

2016年（1）

2014年（2）

2013年（3）

2012年（27）

2012-05-04 18:51:52

# Parameters of the ODE m*x'' + c*x' + k*x = f(t) integrated below
# (see mass_dump_spring): spring stiffness, damping coefficient, mass.
k, c, m = 3.0, 0.2, 0.1


def f(t):
    """External driving force: the constant 10.0, independent of t."""
    return 10.0


# Initial state (x, v): start displaced at x = 5.0 and at rest.
init = (5.0, 0.0)

mass_dump_spring() 计算系统在时刻 t、状态为 status 时各状态变量的导数值（dx/dt 与 dv/dt）。

def mass_dump_spring(status, t):
    """Right-hand side of the mass-damper-spring ODE in first-order form.

    The state is the pair (x, v). Returns (dx/dt, dv/dt) with
    dx/dt = v and dv/dt = (f(t) - k*x - c*v) / m, using the module-level
    parameters k, c, m and forcing function f. The (state, t) argument order
    is what scipy.integrate.odeint passes to its callback.
    """
    # NOTE(review): "dump" presumably means "damp(er)"; the name is kept
    # because callers reference it.
    position, velocity = status
    acceleration = (f(t) - k * position - c * velocity) / m
    return velocity, acceleration

import numpy as np
from scipy.integrate import odeint
import pylab as pl

def solve_by_odeint(time, h):
    """Integrate the system with scipy's odeint as the reference solution.

    Samples the solution on the grid np.arange(0, time, h) and returns
    (t, x, v): the grid itself plus the position and velocity columns.
    """
    grid = np.arange(0, time, h)
    solution = odeint(mass_dump_spring, init, grid)
    positions = solution[:, 0]
    velocities = solution[:, 1]
    return grid, positions, velocities

def add(status, dstatus, h):
    """Return ``status + dstatus * h`` computed component-wise, as a new list.

    This is one explicit step: each state component is moved along its
    derivative by step ``h``. The scraped source had lost this ``def`` line
    (rk4 calls ``add(status, k1, h2)``) and used Python-2-only ``xrange``;
    ``zip`` works on both Python 2 and 3.
    """
    return [s + d * h for s, d in zip(status, dstatus)]

def euler(func, status, time, h):
    """Integrate ``status' = func(status, t)`` with the explicit Euler method.

    Parameters
    ----------
    func : callable
        ``func(status, t)`` returning the derivative of each state component.
    status : sequence of float
        Initial state.
    time : float
        End of the interval; samples are taken at ``np.arange(0, time, h)``.
    h : float
        Fixed step size.

    Returns
    -------
    tlist : list of float
        The sample times.
    result : numpy.ndarray
        Row ``i`` holds the state at ``tlist[i]``; row 0 is the initial state.
    """
    tlist = np.arange(0, time, h).tolist()
    result = []
    for t in tlist:
        result.append(status)
        dstatus = func(status, t)
        # Advance one step: y_{n+1} = y_n + h * f(y_n, t_n). This update was
        # missing, so every row of the result equalled the initial state.
        status = [s + d * h for s, d in zip(status, dstatus)]
    return tlist, np.array(result)

def solve_by_euler(time, h):
    """Integrate the system with the fixed-step Euler integrator.

    Returns (t, x, v): the list of sample times and the position and
    velocity columns of the state array produced by euler().
    """
    times, states = euler(mass_dump_spring, init, time, h)
    x_col, v_col = states[:, 0], states[:, 1]
    return times, x_col, v_col

def euler_plot():
    """Plot odeint vs. Euler position traces for h = 0.01 and h = 0.001.

    Each step size gets its own panel (subplots 211 and 212), with a legend
    and an "h = ..." title.
    """
    for panel, h in zip((211, 212), (0.01, 0.001)):
        pl.subplot(panel)
        _, x, _ = solve_by_odeint(5, h)
        t, x_euler, _ = solve_by_euler(5, h)
        pl.plot(t, x, label="odeint")
        pl.plot(t, x_euler, "r", label="euler")
        pl.legend(loc="best")
        pl.title("h = %g" % h)

def midpoint(func, status, time, h):
    """Integrate ``status' = func(status, t)`` with the explicit midpoint method.

    Each step evaluates the slope at the start, uses it to estimate the
    state half a step ahead, and then advances the full step using the
    slope evaluated at that midpoint.

    Parameters
    ----------
    func : callable
        ``func(status, t)`` returning the derivative of each state component.
    status : sequence of float
        Initial state.
    time : float
        End of the interval; samples are taken at ``np.arange(0, time, h)``.
    h : float
        Fixed step size.

    Returns
    -------
    tlist : list of float
        The sample times.
    result : numpy.ndarray
        Row ``i`` holds the state at ``tlist[i]``; row 0 is the initial state.
    """
    tlist = np.arange(0, time, h).tolist()
    half = 0.5 * h
    result = []
    for t in tlist:
        result.append(status)
        dstatus = func(status, t)
        # Half-step predictor. status2 was previously used without being
        # defined, which raised NameError.
        status2 = [s + d * half for s, d in zip(status, dstatus)]
        dstatus2 = func(status2, t + half)
        # Full step using the midpoint slope. This update was missing, so the
        # state never advanced.
        status = [s + d * h for s, d in zip(status, dstatus2)]
    return tlist, np.array(result)

def solve_by_midpoint(time, h):
    """Integrate the system with the fixed-step midpoint integrator.

    Returns (t, x, v): the list of sample times and the position and
    velocity columns of the state array produced by midpoint().
    """
    times, states = midpoint(mass_dump_spring, init, time, h)
    x_col, v_col = states[:, 0], states[:, 1]
    return times, x_col, v_col

def rk4(func, status, time, h):
    """Integrate ``status' = func(status, t)`` with classical 4th-order Runge-Kutta.

    Parameters
    ----------
    func : callable
        ``func(status, t)`` returning the derivative of each state component.
    status : sequence of float
        Initial state.
    time : float
        End of the interval; samples are taken at ``np.arange(0, time, h)``.
    h : float
        Fixed step size.

    Returns
    -------
    tlist : list of float
        The sample times.
    result : numpy.ndarray
        Row ``i`` holds the state at ``tlist[i]``; row 0 is the initial state.
    """
    def shifted(base, slope, scale):
        # Component-wise base + slope * scale (a local helper, so this block
        # does not depend on a module-level ``add``).
        return [b + s * scale for b, s in zip(base, slope)]

    tlist = np.arange(0, time, h).tolist()
    h2 = 0.5 * h
    h6 = h / 6.0
    result = []
    for t in tlist:
        result.append(status)
        k1 = func(status, t)
        k2 = func(shifted(status, k1, h2), t + h2)
        k3 = func(shifted(status, k2, h2), t + h2)
        k4 = func(shifted(status, k3, h), t + h)
        dstatus = [v1 + 2 * v2 + 2 * v3 + v4
                   for v1, v2, v3, v4 in zip(k1, k2, k3, k4)]
        # y_{n+1} = y_n + h/6 * (k1 + 2*k2 + 2*k3 + k4). This update was
        # missing, which left the state frozen at its initial value.
        status = shifted(status, dstatus, h6)
    return tlist, np.array(result)

def solve_by_rk4(time, h):
    """Integrate the system with the fixed-step RK4 integrator.

    Returns (t, x, v): the list of sample times and the position and
    velocity columns of the state array produced by rk4().
    """
    times, states = rk4(mass_dump_spring, init, time, h)
    x_col, v_col = states[:, 0], states[:, 1]
    return times, x_col, v_col

def error(func1, func2, time, h_list):
    """Mean absolute difference between two solvers for each step size.

    For every h in h_list, calls func1(time, h) and then func2(time, h). Each
    is expected to return (t, x, v). Records mean(|x1 - x2|) and
    mean(|v1 - v2|).

    Returns two lists (x errors, v errors), aligned with h_list.
    """
    pairs = []
    for step in h_list:
        _, x_ref, v_ref = func1(time, step)
        _, x_cmp, v_cmp = func2(time, step)
        pairs.append((np.mean(np.abs(x_ref - x_cmp)),
                      np.mean(np.abs(v_ref - v_cmp))))
    x_errors = [px for px, _ in pairs]
    v_errors = [pv for _, pv in pairs]
    return x_errors, v_errors

def error_plot(func1, func2, title):
    """Log-log plot of the x and v errors of func2 relative to func1.

    Uses 20 step sizes log-spaced from 1e-3 to 1e-1 over the interval
    [0, 5.0). Each curve is labelled with *title*.
    """
    steps = np.logspace(-3, -1, 20)
    err_x, err_v = error(func1, func2, 5.0, steps)
    labels = ("error x of %s" % title, "error v of %s" % title)
    for errs, label in zip((err_x, err_v), labels):
        pl.loglog(steps, errs, lw=2, label=label)

if __name__ == "__main__":
    # Script entry point, guarded so that importing this module does not
    # start plotting or block on pl.show().
    # Compare each fixed-step integrator against odeint on one log-log figure.
    pl.rcParams["legend.fontsize"] = "small"  # read when the legend is built
    error_plot(solve_by_odeint, solve_by_euler, "euler")
    error_plot(solve_by_odeint, solve_by_midpoint, "midpoint")
    error_plot(solve_by_odeint, solve_by_rk4, "rk4")
    pl.legend(loc="best")
    pl.show()