Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3593510
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2021-03-26 17:22:09

import numpy as np

from sklearn.datasets import load_iris

import matplotlib.pyplot as plt

import pandas as pd

from matplotlib.colors import ListedColormap

class Perceptron(object):

    def __init__(self,lr, n_iter):

        self.lr = lr

        self.n_iter = n_iter

    def train(self,X,y):

        #这个向量包含了两重信息,第一w_[0]代表了偏置项bw_[1:]代表了权重向量

        self.w_ = np.zeros(1+X.shape[1])

        self.errors_ = []

        for _ in range(self.n_iter):

            errors = 0

            for xi, target in zip(X,y):

                #这里的target - self.predict(xi)对应公式(5)中的yi

                update = self.lr * (target - self.predict(xi))

                self.w_[1:] += update * xi

                self.w_[0] += update

                errors += int(update != 0.0)

            self.errors_.append(errors)

        return self

    #这个函数完成了感知机定义里面的w·x+b这个操作

    def net_input(self, X):

        return np.dot(X, self.w_[1:])+self.w_[0]

    #按照学习算法的第三步,对式子判断其是否小于0

    def predict(self, X):

        return np.where(self.net_input(X) >= 0.0, 1, -1)

# bunch格式的数据集转化为pandasdataframe

def sklearn_to_df(datasets):

    df = pd.DataFrame(datasets.data, columns=datasets.feature_names)

    df['target'] = pd.Series(datasets.target)

    return df

iris = load_iris()

df_iris = sklearn_to_df(iris)

y = df_iris.iloc[0:100,4].values

y = np.where(y == 0,-1,1)

X = df_iris.iloc[0:100,[0,2]].values

#鸢尾花数据集的可视化

# plt.scatter(X[:50,0],X[:50,1],

#             color='r',marker='o',label='setosa')

# plt.scatter(X[50:100,0],X[50:100,1],

#             color='b',marker='x',label='versicolor')

# plt.xlabel('petal length')

# plt.ylabel('sepal length')

# plt.legend(loc='upper left')

# plt.show()

# 利用鸢尾花数据集来训练感知机

pn = Perceptron(lr=0.01,n_iter=10)

pn.train(X, y)

# plt.plot(range(1,len(pn.errors_)+1),pn.errors_,marker='o')

# plt.xlabel('Epoch')

# plt.ylabel('Number of miscalssifications')

# plt.show()

#画出决策边界

def plot_decision_regions(X, y, classifier, resolution=0.02):

    markers = ('s','x','o','^','v')

    colors = ('red','blue','lightgreen','gray','cyan')

    cmap = ListedColormap(colors[:len(np.unique(y))])

    x1_min, x1_max = X[:,0].min() - 1, X[:,0].max() + 1

    x2_min, x2_max = X[:,1].min() - 1, X[:,1].max() + 1

    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),

                           np.arange(x2_min, x2_max, resolution))

    z = classifier.predict(np.array([xx1.ravel(),xx2.ravel()]).T)

    z = z.reshape(xx1.shape)

    plt.contourf(xx1,xx2,z,alpha=0.4,cmap=cmap)

    plt.xlim(xx1.min(),xx1.max())

    plt.ylim(xx2.min(),xx2.max())

    for idx, cl in enumerate(np.unique(y)):

        plt.scatter(x = X[y == cl,0], y = X[y == cl,1],

                    alpha=0.8, c=cmap(idx),

                    marker=markers[idx],label=cl)

plot_decision_regions(X, y, classifier=pn)

plt.xlabel('sepal length [cm]')

plt.ylabel('petal length [cm]')

plt.legend(loc = 'upper left')

plt.show()

阅读(16959) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~