Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3657017
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2022-10-07 11:36:54

from torch import nn

import torch

from torch.nn import functional as F

class Conv_Block(nn.Module):
    """Two stacked 3x3 convolution stages that preserve spatial size.

    Each stage is: reflect-padded 3x3 conv (stride 1, padding 1, no bias)
    -> BatchNorm2d -> Dropout2d(p=0.3) -> LeakyReLU. The first stage maps
    ``in_channel`` to ``out_channel`` channels; the second keeps
    ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        for src_channels in (in_channel, out_channel):
            stages += [
                nn.Conv2d(src_channels, out_channel, 3, 1, 1,
                          padding_mode='reflect', bias=False),
                nn.BatchNorm2d(out_channel),
                nn.Dropout2d(0.3),
                nn.LeakyReLU(),
            ]
        # One flat Sequential keeps the parameter names layer.0 ... layer.7.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        result = self.layer(x)
        return result

class DownSample(nn.Module):
    """Halve the spatial resolution with a strided convolution.

    Reflect-padded 3x3 conv with stride 2 and padding 1 (channel count
    unchanged, no bias), followed by BatchNorm2d and LeakyReLU. An input of
    height/width H produces an output of size ceil(H / 2).
    """

    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
            )

    def forward(self, x):
        """Return ``x`` downsampled by a factor of 2 in each spatial dimension."""
        return self.layer(x)

class UpSample(nn.Module):
    """Nearest-neighbour upsampling followed by a skip-connection merge.

    The input is resized to the skip tensor's spatial size with
    nearest-neighbour interpolation. A 1x1 conv then halves its channel
    count, and the result is concatenated with the skip tensor along the
    channel dimension.
    """

    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel//2, 1, 1)

    def forward(self, x, feature_map):
        """Upsample ``x`` to ``feature_map``'s size and concatenate the two.

        Args:
            x: decoder tensor with ``channel`` channels, expected as
                (N, C, h, w).
            feature_map: encoder skip tensor, expected as (N, C_skip, H, W).

        Returns:
            Tensor of shape (N, channel // 2 + C_skip, H, W).
        """
        # Resize to the skip tensor's exact size instead of a fixed x2.
        # DownSample yields ceil(H / 2), so for odd H a plain x2 upsample
        # gives H + 1 and torch.cat fails. For even sizes this is
        # identical to scale_factor=2.
        up = F.interpolate(x, size=feature_map.shape[-2:], mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)

class UNet(nn.Module):
    """Four-level U-Net with a sigmoid output head.

    Encoder: Conv_Block stages of 64/128/256/512/1024 channels separated by
    DownSample. Decoder: UpSample + Conv_Block stages back down to 64
    channels, each merging the matching encoder feature map. A final 3x3
    conv maps to ``out_channels`` and a sigmoid squashes values into (0, 1).

    Args:
        in_channels: channels of the input image (default 3, the original
            hard-coded value).
        out_channels: channels of the output map (default 3, the original
            hard-coded value).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        # Encoder
        self.c1 = Conv_Block(in_channels, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        # Decoder: each UpSample halves channels, then concatenation with the
        # skip tensor restores the count that the following Conv_Block expects.
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # Output head
        self.out = nn.Conv2d(64, out_channels, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder/decoder and return sigmoid-activated outputs."""
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))
        return self.Th(self.out(O4))

if __name__ == "__main__":
    # Quick smoke test: feed a random batch through the network and print
    # the resulting tensor shape.
    sample_batch = torch.randn(2, 3, 256, 256)
    model = UNet()
    prediction = model(sample_batch)
    print(prediction.shape)

阅读(815) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~