Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3648218
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2022-08-10 18:13:18

import numpy as np

import matplotlib.pyplot as plt

class DrawConfusionMatrix:

    def __init__(self, labels_name, normalize=True):

        """

normalize:是否设元素为百分比形式

        """

        self.normalize = normalize

        self.labels_name = labels_name

        self.num_classes = len(labels_name)

        self.matrix = np.zeros((self.num_classes, self.num_classes), dtype="float32")

    def update(self, predicts, labels):

        """

        :param predicts: 一维预测向量,egarray([0,5,1,6,3,...],dtype=int64)

        :param labels:   一维标签向量:egarray([0,5,0,6,2,...],dtype=int64)

        :return:

        """

        for predict, label in zip(predicts, labels):

            self.matrix[predict, label] += 1

    def getMatrix(self,normalize=True):

        """

        根据传入的normalize判断要进行percent的转换,

        如果normalizeTrue,则矩阵元素转换为百分比形式,

        如果normalizeFalse,跟单网gendan5.com则矩阵元素就为数量

        Returns:返回一个以百分比或者数量为元素的矩阵

        """

        if normalize:

            per_sum = self.matrix.sum(axis=1)  # 计算每行的和,用于百分比计算

            for i in range(self.num_classes):

                self.matrix[i] =(self.matrix[i] / per_sum[i])   # 百分比转换

            self.matrix=np.around(self.matrix, 2)   # 保留2位小数点

            self.matrix[np.isnan(self.matrix)] = 0  # 可能存在NaN,将其设为0

        return self.matrix

    def drawMatrix(self):

        self.matrix = self.getMatrix(self.normalize)

        plt.imshow(self.matrix, cmap=plt.cm.Blues)  # 仅画出颜色格子,没有值

        plt.title("Normalized confusion matrix")  # title

        plt.xlabel("Predict label")

        plt.ylabel("Truth label")

        plt.yticks(range(self.num_classes), self.labels_name)  # y轴标签

        plt.xticks(range(self.num_classes), self.labels_name, rotation=45)  # x轴标签

        for x in range(self.num_classes):

            for y in range(self.num_classes):

                value = float(format('%.2f' % self.matrix[y, x]))  # 数值处理

                plt.text(x, y, value, verticalalignment='center', horizontalalignment='center')  # 写值

        plt.tight_layout()  # 自动调整子图参数,使之填充整个图像区域

        plt.colorbar()  # 色条

        plt.savefig('./ConfusionMatrix.png', bbox_inches='tight')  # bbox_inches='tight'可确保标签信息显示全

        plt.show()

阅读(1250) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~