Category: Big Data

2020-09-15 16:51:30

The example comes from Hacker's Guide to Machine Learning with Python.

The data file used is very small, with only two columns: Speed and Stopping Distances of Cars (the classic cars dataset).

The code is as follows:



# coding: utf-8
import pandas as pd
import tensorflow as tf
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style='whitegrid', palette='muted', font_scale=1.5)

# read the two columns of the cars data: speed and stopping distance
data_csv = pd.read_csv('datasets/cars.csv', usecols=['speed', 'dist'])
print(data_csv.head())

target = data_csv['dist']
speed = data_csv['speed']

# plot the dataset
sns.scatterplot(x=speed, y=target)
plt.xlabel("speed")
plt.ylabel("stopping distance")
plt.show()

# wrap the two columns in a tf.data.Dataset and peek at the first few pairs
dataset = tf.data.Dataset.from_tensor_slices((speed.values, target.values))
for feat, targ in dataset.take(5):
    print('Feature: {}, Target: {}'.format(feat, targ))
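# Note: the Dataset above is only printed for inspection; the fit() calls below
# are given the pandas Series directly. As a sketch (the batch size of 8 is my
# own choice, not from the book), the same data could also be fed through the
# tf.data pipeline; train_ds is not used further below.
train_ds = dataset.shuffle(len(speed)).batch(8)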

from tensorflow.keras import layers

# simple linear regression: one Dense unit with a linear activation
lin_reg = tf.keras.Sequential([
    layers.Dense(1, activation='linear', input_shape=[1]),
])
optimizer = tf.keras.optimizers.RMSprop(0.001)

lin_reg.compile(loss='mse', optimizer=optimizer, metrics=['mse'])
history = lin_reg.fit(x=speed, y=target, shuffle=True, epochs=1000,
                      validation_split=0.2, verbose=0)

def plot_error(history):
    """Plot training and validation MSE over the epochs."""
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch

    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean Square Error')
    plt.plot(hist['epoch'], hist['mse'], label='Train Error')
    plt.plot(hist['epoch'], hist['val_mse'], label='Val Error')
    plt.legend()
    plt.show()

plot_error(history)

print(lin_reg.summary())

# get the weights: Dense.get_weights() returns [kernel, bias], so the kernel holds
# the slope and the bias holds the intercept; the layer is auto-named "dense"
# because it is the first Dense layer created in this session
weights = lin_reg.get_layer("dense").get_weights()
slope = weights[0][0][0]
intercept = weights[1][0]

print('weights: {}'.format(weights))
print('intercept: {}'.format(intercept))
print('slope: {}'.format(slope))
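# Sketch (my own addition, not part of the original listing): overlay the learned
# line on the scatter plot to eyeball the fit, using the slope and intercept
# just extracted.
plt.figure()
sns.scatterplot(x=speed, y=target)
plt.plot(speed, slope * speed + intercept, color='red', label='lin_reg fit')
plt.xlabel("speed")
plt.ylabel("stopping distance")
plt.legend()
plt.show()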


# build a small fully-connected neural network for the same regression task
def build_neural_net():
    net = tf.keras.Sequential([
        layers.Dense(32, activation='relu', input_shape=[1]),
        layers.Dense(16, activation='relu'),
        layers.Dense(1),
    ])

    optimizer = tf.keras.optimizers.RMSprop(0.001)

    # note: 'accuracy' is a classification metric and is not informative for this
    # regression task; 'mse' is the number worth watching
    net.compile(loss='mse',
                optimizer=optimizer,
                metrics=['mse', 'accuracy'])

    return net

net = build_neural_net()

history = net.fit(
    x=speed,
    y=target,
    shuffle=True,
    epochs=1000,
    validation_split=0.2,
    verbose=0
)
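# Sketch (my own addition, not from the book): evaluate the trained network on the
# full data to get a single set of numbers to compare with the linear model above;
# evaluate() returns [loss, mse, accuracy] for this compile() configuration.
print(net.evaluate(x=speed, y=target, verbose=0))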

plot_error(history)

# stop training early once val_loss stops improving for 10 epochs
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10
)
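# Note (my own addition, not from the book): EarlyStopping also accepts
# restore_best_weights=True, which rolls the model back to the weights of the
# best val_loss epoch instead of keeping the last epoch's weights, e.g.:
# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
#                                               restore_best_weights=True)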


net = build_neural_net()

# simple neural network summary and initial weights
print(net.summary())
print(net.weights)

history = net.fit(
    x=speed,
    y=target,
    shuffle=True,
    epochs=1000,
    validation_split=0.2,
    verbose=0,
    callbacks=[early_stop]
)

plot_error(history)

# save the trained model to disk and restore it
net.save('simple_net.h5')
simple_net = tf.keras.models.load_model('simple_net.h5')
print(simple_net.summary())
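A quick round-trip check (my own addition, not from the book): compare predictions from the in-memory model and the reloaded one on the training speeds; they should match.

import numpy as np
original_pred = net.predict(speed.values)
restored_pred = simple_net.predict(speed.values)
print('max difference between original and restored predictions:',
      np.max(np.abs(original_pred - restored_pred)))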

The slope and intercept produced by the simple linear regression: I ran it a few times and the values differ a lot from run to run. Why is that?


intercept: 2.5482685565948486
slope: 0.8580179810523987

intercept: 0.6354206800460815
slope: 1.985124111175537

The resulting equations would then be, respectively:

y = 0.8580179810523987 * x + 2.5482685565948486
y = 1.985124111175537 * x + 0.6354206800460815
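Two things seem worth keeping in mind here (my own reading, not from the book). First, Dense.get_weights() returns [kernel, bias], i.e. the slope first and the intercept second. Second, 1000 epochs of RMSprop at learning rate 0.001 on unscaled inputs may simply not be enough to reach the least-squares solution, so each run, with its own random initialization and its own shuffled 20% validation split, stops at a different point along the way. The closed-form fit gives the values the network should be approaching; a minimal check with numpy.polyfit (assuming data_csv is still loaded):

import numpy as np

# closed-form least-squares fit of dist on speed, for comparison with the network
ols_slope, ols_intercept = np.polyfit(data_csv['speed'], data_csv['dist'], deg=1)
print('closed-form slope: {}, intercept: {}'.format(ols_slope, ols_intercept))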

