About the author

Linuxer, ex IBMer. GNU https://hmchzb19.github.io/


Category: Python/Ruby

2018-06-21 14:41:33

polynomial regression

Not all relationships are linear.
Linear formula: y = mx + b
  This is a "first-order" or "first-degree" polynomial, because the highest power of x is 1.
Second-order polynomial: y = ax**2 + bx + c
Third-order polynomial: y = ax**3 + bx**2 + cx + d
Higher-order polynomials produce more complex curves.
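In NumPy, polynomials like these can be represented with `np.poly1d`, which takes its coefficients from the highest power down to the constant term. This small sketch uses made-up coefficients and is meant only to show how the order and the coefficients map onto the formulas above:

```python
import numpy as np

# y = 2x**2 - 3x + 1: a second-order polynomial (coefficients listed highest power first)
quadratic = np.poly1d([2, -3, 1])
print(quadratic.order)  # 2
print(quadratic(2))     # 2*4 - 3*2 + 1 = 3

# y = x**3 - x: a third-order polynomial; the extra power gives the curve two turning points
cubic = np.poly1d([1, 0, -1, 0])
print(cubic.order)      # 3
print(cubic(3))         # 27 - 3 = 24
```

The same `np.poly1d` wrapper is what turns the coefficients returned by `np.polyfit` into a callable model in the code below.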

#Beware overfitting
Don't use more degrees than you need.
Visualize your data first to see how complex a curve it really calls for.
Visualize the fit: is your curve going out of its way to accommodate outliers?
A high R-squared only means your curve fits your training data well; it may still be a poor predictor of new data.
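One common way to check for overfitting is to hold some data back, fit on the rest, and compare R-squared on both parts. The sketch below assumes hypothetical data with a genuinely quadratic trend; the helper `r_squared` is written here for illustration (it computes the same quantity as sklearn's `r2_score`), and the degrees 1, 2 and 10 are arbitrary picks:

```python
import numpy as np


def r_squared(y_true, y_pred):
    """Hypothetical helper: coefficient of determination, 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1.0 - ss_res / ss_tot


# Hypothetical data: a quadratic trend plus Gaussian noise.
rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, 60)
y = 0.5 * x**2 - x + rng.normal(0.0, 1.0, x.size)

# Hold out 20 of the 60 points; fit on the other 40 only.
idx = rng.permutation(x.size)
train, test = idx[:40], idx[40:]

results = {}
for degree in (1, 2, 10):
    model = np.poly1d(np.polyfit(x[train], y[train], degree))
    train_r2 = r_squared(y[train], model(x[train]))
    test_r2 = r_squared(y[test], model(x[test]))
    results[degree] = (train_r2, test_r2)
    print(f"degree {degree:2d}: train R2 = {train_r2:.3f}, test R2 = {test_r2:.3f}")
```

Training R-squared can never drop as the degree rises, because every lower-degree polynomial is a special case of a higher-degree one. The held-out column is the one that tells you whether the extra degrees are actually helping. The exact numbers depend on the random draw, so none are claimed here.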

code:


    # fabricate (synthetic) data: purchase amounts that fall off as page speed grows
    import numpy as np
    import matplotlib.pyplot as plt

    np.random.seed(2)
    pageSpeeds = np.random.normal(3.0, 1.0, 1000)
    purchaseAmount = np.random.normal(50.0, 10.0, 1000) / pageSpeeds
    plt.scatter(pageSpeeds, purchaseAmount)
    plt.show()

    # numpy's polyfit builds an nth-degree polynomial model of the data that minimizes
    # squared error; np.poly1d wraps the coefficients in a callable. Try a 4th-degree polynomial.
    x = np.array(pageSpeeds)
    y = np.array(purchaseAmount)
    p4 = np.poly1d(np.polyfit(x, y, 4))

    # visualize the degree-4 fit over the range of the data
    xp = np.linspace(0, 7, 100)
    plt.scatter(x, y)
    plt.plot(xp, p4(xp), c='r')
    plt.show()

    # measure R-squared (coefficient of determination): 1 is a perfect fit,
    # 0 means the model does no better than always predicting the mean
    from sklearn.metrics import r2_score
    r2 = r2_score(y, p4(x))
    print(r2)
    # output (pretty good):
    # 0.82937663963

    # raise the degree to 8
    p8 = np.poly1d(np.polyfit(x, y, 8))
    xp = np.linspace(0, 7, 100)
    plt.scatter(x, y)
    plt.plot(xp, p8(xp), c='r')
    plt.show()

    r2 = r2_score(y, p8(x))
    print(r2)
    # output: a higher R-squared on the training data than the degree-4 fit
    # (per the overfitting warning above, that alone does not make it a better predictor)
    # 0.881439566368

    # set the degree to 1: this is ordinary linear regression (y = mx + b)
    p1 = np.poly1d(np.polyfit(x, y, 1))
    xp = np.linspace(0, 7, 100)
    plt.scatter(x, y)
    plt.plot(xp, p1(xp), c='r')
    plt.show()

    r2 = r2_score(y, p1(x))
    print(r2)
    # R-squared is only about 0.50:
    # 0.502494130455
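Every fit in the listing above comes from `np.polyfit`, which returns the coefficients (highest power first) that minimize the sum of squared errors. As a sanity check, this sketch solves the same least-squares problem directly with a Vandermonde matrix and `np.linalg.lstsq`. The data are hypothetical stand-ins for `pageSpeeds`/`purchaseAmount`, and the `np.clip` floor is only there to avoid dividing by values near zero:

```python
import numpy as np

# Hypothetical data standing in for pageSpeeds / purchaseAmount.
rng = np.random.default_rng(1)
x = rng.normal(3.0, 1.0, 200)
y = 50.0 / np.clip(x, 0.5, None) + rng.normal(0.0, 2.0, x.size)

degree = 4
coef_polyfit = np.polyfit(x, y, degree)

# Columns x**4, x**3, x**2, x, 1: the same highest-power-first order polyfit uses.
A = np.vander(x, degree + 1)
coef_lstsq, *_ = np.linalg.lstsq(A, y, rcond=None)

# Both approaches minimize the same squared error, so their fitted values agree.
fitted_polyfit = np.polyval(coef_polyfit, x)
fitted_lstsq = A @ coef_lstsq
print(np.allclose(fitted_polyfit, fitted_lstsq))  # True
```

The comparison is on fitted values rather than raw coefficients, because high-power coefficients can differ in their last digits between solvers even when the curves are identical for practical purposes.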
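`r2_score` reports the coefficient of determination, R² = 1 - SS_res / SS_tot. This minimal NumPy version uses a hypothetical helper `r_squared`, which is not part of any library. It makes the "0 is bad, 1 is good" reading concrete and shows that R² can even go negative when a model does worse than predicting the mean. For a least-squares fit scored on its own training data that cannot happen, but it can on new data:

```python
import numpy as np


def r_squared(y_true, y_pred):
    """Hypothetical helper: the quantity sklearn.metrics.r2_score reports."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot


y = np.array([1.0, 2.0, 3.0, 4.0])
print(r_squared(y, y))                               # 1.0  (perfect fit)
print(r_squared(y, np.full(4, y.mean())))            # 0.0  (always predicting the mean)
print(r_squared(y, np.array([4.0, 3.0, 2.0, 1.0])))  # -3.0 (worse than the mean)
```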
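The degree-1 fit is ordinary linear regression, so its slope and intercept should match the textbook closed form m = cov(x, y) / var(x) and b = mean(y) - m * mean(x). This sketch checks that on hypothetical data drawn around the line y = 4x + 7:

```python
import numpy as np

# Hypothetical data scattered around y = 4x + 7.
rng = np.random.default_rng(3)
x = rng.normal(3.0, 1.0, 500)
y = 4.0 * x + 7.0 + rng.normal(0.0, 2.0, x.size)

m_fit, b_fit = np.polyfit(x, y, 1)  # degree 1 returns [slope, intercept]

# Closed-form ordinary least squares for y = m*x + b
m = np.sum((x - x.mean()) * (y - y.mean())) / np.sum((x - x.mean()) ** 2)
b = y.mean() - m * x.mean()

print(np.isclose(m_fit, m), np.isclose(b_fit, b))  # True True
```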


