# Commit 05360171, created 2022-03-18 (see commit history).
# Copyright 2020 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

# ============================================================================

import torch

import torch.nn as nn





def weights_init_normal(m):
    """Initialise one module's parameters in place, DCGAN-style.

    Meant to be used as ``model.apply(weights_init_normal)``. Matching is done
    on the class name:

    * any class whose name contains "Conv" (e.g. Conv2d, ConvTranspose2d) gets
      its weight drawn from N(0, 0.02); its bias is left untouched;
    * any class whose name contains "BatchNorm2d" gets its weight drawn from
      N(1, 0.02) and its bias set to 0;
    * every other module is left unchanged.

    Modules that match by name but carry no weight/bias tensor (for example
    ``nn.BatchNorm2d(affine=False)``, whose weight and bias are None, or a
    user-defined container class such as ``ConvBlock``) are skipped instead of
    raising AttributeError.

    Args:
        m: the module to initialise.
    """
    class_name = m.__class__.__name__
    # getattr with a default: containers have no `weight` attribute at all,
    # and non-affine norm layers register it as None.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in class_name:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in class_name:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias, 0.0)





class Generator(nn.Module):
    """DCGAN-style generator mapping latent codes to images.

    A linear layer projects each latent vector to a 128-channel feature map of
    side ``img_size // 4``. Two nearest-neighbour x2 upsampling stages then
    bring it back to side ``(img_size // 4) * 4``, and a final 3x3 conv maps
    it to ``channels`` outputs squashed by tanh into [-1, 1].

    NOTE(review): if ``img_size`` is not a multiple of 4, the output side is
    ``(img_size // 4) * 4``, not ``img_size``. Callers presumably pass a
    multiple of 4; confirm.

    Args:
        img_size: target image side length.
        latent_dim: size of each latent vector.
        channels: number of channels of the generated image.
    """

    def __init__(self, img_size, latent_dim, channels):
        super().__init__()

        # Two x2 upsampling stages follow, so start at a quarter of the side.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): the second positional argument of nn.BatchNorm2d is
        # `eps`, so the upsampling stages use eps=0.8 (spelled out below).
        # This may originally have been meant as momentum=0.8. It is kept
        # as-is so existing trained weights behave the same; confirm intent.
        layers = [nn.BatchNorm2d(128)]
        for c_in, c_out in ((128, 128), (128, 32)):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Conv2d(32, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from latent codes.

        Args:
            z: latent codes, presumably shaped (batch, latent_dim); confirm
                with callers.

        Returns:
            Tensor of shape (batch, channels, 4 * init_size, 4 * init_size)
            with values in [-1, 1].
        """
        side = self.init_size
        seed = self.l1(z)
        seed = seed.view(seed.size(0), 128, side, side)
        return self.conv_blocks(seed)





class Discriminator(nn.Module):

    def __init__(self, img_size, channels):

        super(Discriminator, self).__init__()



        def discriminator_block(in_filters, out_filters, bn=True):

            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]

            if bn:

                block.append(nn.BatchNorm2d(out_filters, 0.8))

            return block



        self.model = nn.Sequential(

            *discriminator_block(channels, 16, bn=False),

            *discriminator_block(16, 32),

            *discriminator_block(32, 64),

            *discriminator_block(64, 128)

        )



        # The height and width of down_sampled image

        ds_size = img_size // 2 ** 4

        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1))



    def forward(self, img):

        out = self.model(img)

        out = out.view(out.shape[0], -1)

        validity = self.adv_layer(out)



        return validity