C++,python,热爱算法和机器学习
全部博文(1214)
分类: IT业界
2020-09-07 17:30:26
最近接触到了Scipy中optimize模块的一些函数,optimize模块中提供了很多数值优化算法,其中,最小二乘法可以说是最经典的数值优化技术了, 通过最小化误差的平方来寻找最符合数据的曲线。在optimize模块中,使用leastsq()函数可以很快速地使用最小二乘法对数据进行拟合。
首先来看leastsq()函数地调用格式:
leastsq(func, x0, args=(), Dfun=None, full_output=0, col_deriv=0, ftol=1.49012e-08, xtol=1.49012e-08, gtol=0.0, maxfev=0, epsfcn=0.0, factor=100, diag=None, warning=True)
参数还是非常多的,一般来说,我们只需要前三个参数就够了他们的作用分别是:
举个例子:
这里要进行拟合的数据点都分布在这条正弦曲线附近:
def func(x): return 2*np.sin(2*np.pi*x)
然后定义误差函数,所谓误差就是指我们拟合的曲线的值对应真实值的差:
def residuals(p, x, y): fun = np.poly1d(p) # poly1d()函数可以按照输入的列表p返回一个多项式函数 return y - fun(x) # 返回真实值 与我们拟合的曲线上对应的值的差
这里设计了一个poly1d()函数,关于这个函数,简单理解下就是输入一个列表,返回以这个列表中的值为参数的多项式,例如:
输入:[1,2,3] 返回:x^2 + 2x + 3 多项式的次数是从0开始记的,要注意这个地方
下面定义关于拟合的曲线的函数:
# 拟合函数 def fitting(p): pars = np.random.rand(p+1) # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p r = leastsq(residuals, pars, args=(X, Y)) # 三个参数:误差函数、函数参数列表、数据点 return r
注释里的内容就是要注意的地方,由于会多次调用拟合,多以写成了函数的形式,这里传入的p是一个数字,表示我们想要得到拟合曲线的次数,比如我想针对这些数据点得到一条3次的曲线,就调用p=3类似,注意这里leastsq()函数的返回值,这里的返回值保存的是拟合的曲线的信息,如果打印这里的r,就会发现返回了一个truple,其中第一维是一个列表,保存的是拟合的曲线的参数,所以要注意如何获得这些参数。
接下来定义一下我们要进行拟合的数据点,这里定义了10个:
# 要进行拟合的数据点 X = np.linspace(0, 1, 10) Y = [np.random.normal(0, 0.1)+num for num in func(X)] # 添加噪声 # 方便绘制曲线,所以创建多一些点 x_ = np.linspace(0, 1, 100) y_ = func(x_)
调用拟合函数,并进行绘图:
fit_pars = fitting(3)[0] # 注意返回值中的第一行才是拟合曲线的参数列表 plt.plot(x_, y_, label='real line') plt.scatter(X, Y, label='real points') plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line') plt.legend() plt.show()
p=3的时候的图像:
当然,这里我直接传入p=3,也就是建立3次的曲线对数据点进行拟合,如果传入的p=1的时候,图像如下:
如果p=2,则是:
可以看到没有变化,也就是说没办法找到一条二次曲线,使得二次误差少于上面的一次曲线了。
完整代码如下:
import numpy as np from scipy.optimize import leastsq import matplotlib.pyplot as plt # 数据点分布在这条曲线附近 def func(x): return 2*np.sin(2*np.pi*x) # 误差函数, 计算拟合曲线与真实数据点之间的差 ,作为leastsq函数的输入 def residuals(p, x, y): fun = np.poly1d(p) # poly1d()函数可以按照输入的列表p返回一个多项式函数 return y - fun(x) # 拟合函数 def fitting(p): pars = np.random.rand(p+1) # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p r = leastsq(residuals, pars, args=(X, Y)) # 三个参数:误差函数、函数参数列表、数据点 return r # 要进行拟合的数据点 X = np.linspace(0, 1, 10) Y = [np.random.normal(0, 0.1)+num for num in func(X)] # 添加噪声 # 方便绘制曲线,所以创建 x_ = np.linspace(0, 1, 100) y_ = func(x_) # print(fitting(3)) 可以看一下返回的是什么 fit_pars = fitting(3)[0] plt.plot(x_, y_, label='real line') plt.scatter(X, Y, label='real points') plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line') plt.legend() plt.show()
以上~