Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3593386
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2022-07-27 17:35:09

import numpy as np

from layers import FCLayer

from dataloader import build_dataloader

from network import Network

from optimizer import SGD

from loss import SoftmaxCrossEntropyLoss

from visualize import plot_loss_and_acc

class Solver(object):

    def __init__(self, cfg):

        self.cfg = cfg

        # build dataloader

        train_loader, val_loader, test_loader = self.build_loader(cfg)

        self.train_loader = train_loader

        self.val_loader = val_loader

        self.test_loader = test_loader

        # build model

        self.model = self.build_model(cfg)

        # build optimizer

        self.optimizer = self.build_optimizer(self.model, cfg)

        # build evaluation criterion

        self.criterion = SoftmaxCrossEntropyLoss()

    @staticmethod

    def build_loader(cfg):

        train_loader = build_dataloader(

            cfg['data_root'], cfg['max_epoch'], cfg['batch_size'], shuffle=True, mode='train')

        val_loader = build_dataloader(

            cfg['data_root'], 1, cfg['batch_size'], shuffle=False, mode='val')

        test_loader = build_dataloader(

            cfg['data_root'], 1, cfg['batch_size'], shuffle=False, mode='test')

        return train_loader, val_loader, test_loader

    @staticmethod

    def build_model(cfg):

        model = Network()

        model.add(FCLayer(784, 10))

        return model

    @staticmethod

    def build_optimizer(model, cfg):

        return SGD(model, cfg['learning_rate'], cfg['momentum'])

    def train(self):

        max_epoch = self.cfg['max_epoch']

        epoch_train_loss, epoch_train_acc = [], []

        for epoch in range(max_epoch):

            iteration_train_loss, iteration_train_acc = [], []

            for iteration, (images, labels) in enumerate(self.train_loader):

                # forward pass

                logits = self.model.forward(images)

                loss, acc = self.criterion.forward(logits, labels)

                # backward_pass

                delta = self.criterion.backward()

                self.model.backward(delta)

                # updata the model weights

                self.optimizer.step()

                # restore loss and accuracy

                iteration_train_loss.append(loss)

                iteration_train_acc.append(acc)

                # display iteration training info

                if iteration % self.cfg['display_freq'] == 0:

                    print("Epoch [{}][{}]\t Batch [{}][{}]\t Training Loss {:.4f}\t Accuracy {:.4f}".format(

                        epoch, max_epoch, iteration, len(self.train_loader), loss, acc))

            avg_train_loss, avg_train_acc = np.mean(iteration_train_loss), np.mean(iteration_train_acc)

            epoch_train_loss.append(avg_train_loss)

            epoch_train_acc.append(avg_train_acc)

            # validate

            avg_val_loss, avg_val_acc = self.validate()

            # display epoch training info

            print('\nEpoch [{}]\t Average training loss {:.4f}\t Average training accuracy {:.4f}'.format(

                epoch, avg_train_loss, avg_train_acc))

            # display epoch valiation info

            print('Epoch [{}]\t Average validation loss {:.4f}\t Average validation accuracy {:.4f}\n'.format(

                epoch, avg_val_loss, avg_val_acc))

        return epoch_train_loss, epoch_train_acc

    def validate(self):

        logits_set, labels_set = [], []

        for images, labels in self.val_loader:

            logits = self.model.forward(images)

            logits_set.append(logits)

            labels_set.append(labels)

        logits = np.concatenate(logits_set)

        labels = np.concatenate(labels_set)

        loss, acc = self.criterion.forward(logits, labels)

        return loss, acc

    def test(self):

        logits_set, labels_set = [], []

        for images, labels in self.test_loader:

            logits = self.model.forward(images)

            logits_set.append(logits)

            labels_set.append(labels)

        logits =跟单网gendan5.com np.concatenate(logits_set)

        labels = np.concatenate(labels_set)

        loss, acc = self.criterion.forward(logits, labels)

        return loss, acc

if __name__ == '__main__':

    # You can modify the hyerparameters by yourself.

    relu_cfg = {

        'data_root': 'data',

        'max_epoch': 10,

        'batch_size': 100,

        'learning_rate': 0.1,

        'momentum': 0.9,

        'display_freq': 50,

        'activation_function': 'relu',

    }

    runner = Solver(relu_cfg)

    relu_loss, relu_acc = runner.train()

    test_loss, test_acc = runner.test()

    print('Final test accuracy {:.4f}\n'.format(test_acc))

    # You can modify the hyerparameters by yourself.

    sigmoid_cfg = {

        'data_root': 'data',

        'max_epoch': 10,

        'batch_size': 100,

        'learning_rate': 0.1,

        'momentum': 0.9,

        'display_freq': 50,

        'activation_function': 'sigmoid',

    }

    runner = Solver(sigmoid_cfg)

    sigmoid_loss, sigmoid_acc = runner.train()

    test_loss, test_acc = runner.test()

    print('Final test accuracy {:.4f}\n'.format(test_acc))

    plot_loss_and_acc({

        "relu": [relu_loss, relu_acc],

        "sigmoid": [sigmoid_loss, sigmoid_acc],

    })

阅读(536) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~