Chinaunix首页 | 论坛 | 博客
  • 博客访问: 1797219
  • 博文数量: 297
  • 博客积分: 285
  • 博客等级: 二等列兵
  • 技术积分: 3006
  • 用 户 组: 普通用户
  • 注册时间: 2010-03-06 22:04
个人简介

Linuxer, ex IBMer. GNU https://hmchzb19.github.io/

文章分类

全部博文(297)

文章存档

2020年(11)

2019年(15)

2018年(43)

2017年(79)

2016年(79)

2015年(58)

2014年(1)

2013年(8)

2012年(3)

分类: 大数据

2020-04-12 18:09:32

logistic_regression 说白了就是个binary classifier.遵从伯努利分布

代码如下:

点击(此处)折叠或打开

  1. # coding: utf-8

  2. import seaborn as sns
  3. import matplotlib.pyplot as plt
  4. import pandas as pd
  5. import numpy as np

  6. raw_data=pd.read_csv('data-analysis/python-jupyter/2.01. Admittance.csv')

  7. type(raw_data)
  8. print(raw_data.columns)


  9. data=raw_data.copy()
  10. data['Admitted']=raw_data['Admitted'].map({'Yes':1, 'No':0})

  11. X1=data['SAT']
  12. y=data['Admitted']
  13. print('X shape: {},y shape: {}'.format(X1.shape, y.shape))

  14. #plot
  15. plt.scatter(X1, y, color='c0')
  16. plt.xlabel('SAT', fontsize=20)
  17. plt.ylabel('Admitted', fontsize=20)
  18. plt.show()


  19. import statsmodels.api as sm
  20. import statsmodels.formula.api as smf

  21. x=sm.add_constant(X1)
  22. reg_log=sm.Logit(y, x)
  23. results_log = reg_log.fit()
  24. from scipy import stats

  25. print(results_log.summary())
  26. print()

  27. from sklearn.linear_model import LogisticRegression
  28. #Scikit-Learn LogisticRegression
  29. reg=LogisticRegression(solver='lbfgs')
  30. #reshape X
  31. x_matrix = X1.values.reshape(168, 1)
  32. reg.fit(x_matrix, y)

  33. #cm_df: confustion matrix
  34. cm_df=pd.DataFrame(results_log.pred_table())
  35. cm_df.columns=['Predicted 0', 'Predicted 1']
  36. cm_df = cm_df.rename(index={0:'Acutal 0', 1:'Actual 1'})

  37. #print the confusion-matrix
  38. print(cm_df)

  39. #calculate the accuracy
  40. cm = np.array(cm_df)
  41. accuracy_train = (cm[0,0] + cm[1,1]) / cm.sum()
  42. print('accuracy: {}'.format(accuracy_train))


阅读(9016) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~