# 05360171 created on 2022-03-18 (commit history)
# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the BSD 3-Clause License  (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# https://opensource.org/licenses/BSD-3-Clause

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import torch.nn as nn

import numpy as np



# Channel count and side length (in pixels) of the square images.
channels, image_size = 1, 28

# Shape of a single image, laid out as (channels, height, width).
img_shape = (channels, image_size, image_size)

# Size of the latent input; used as the in_features of the Generator's
# first Linear layer.
latent_dim = 100



class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network stacks ``Linear`` layers of width 128, 256, 512 and 1024.
    Each is followed by ``LeakyReLU(0.2)``; every stage except the first also
    has a ``BatchNorm1d`` between the two. A final ``Linear`` projects to
    ``prod(img_shape)`` values, which are passed through ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        # Feature widths of consecutive hidden stages, starting at the input.
        widths = (latent_dim, 128, 256, 512, 1024)

        stages = []
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.Linear(n_in, n_out))
            if index:  # the first stage is left un-normalized
                # NOTE(review): 0.8 is bound to BatchNorm1d's ``eps`` (its
                # second positional parameter), not to ``momentum``. Kept
                # as-is because changing it alters the numerical output;
                # confirm the intent before touching it.
                stages.append(nn.BatchNorm1d(n_out, eps=0.8))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Output head: one value per image element, squashed by Tanh.
        stages.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        stages.append(nn.Tanh())

        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from latent input.

        Args:
            z: Latent input tensor; presumably shaped ``(batch, latent_dim)``
                since the first layer is ``Linear(latent_dim, 128)`` --
                confirm against the caller.

        Returns:
            The network output viewed as ``(batch, *img_shape)``; values come
            out of a ``Tanh`` and so lie in ``[-1, 1]``.
        """
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)



class Discriminator(nn.Module):
    """Fully connected discriminator that scores images.

    Each image is flattened and passed through ``Linear`` layers of width 512
    and 256 (each followed by ``LeakyReLU(0.2)``), then a ``Linear`` layer to
    a single value that is passed through ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()

        # Number of values in one flattened image.
        flat_dim = int(np.prod(img_shape))

        layers = [
            nn.Linear(flat_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Image tensor whose first dimension is the batch; all
                remaining dimensions are flattened, and their product must
                equal ``prod(img_shape)`` to match the first ``Linear`` layer.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the ``Sigmoid`` output for
            each image, i.e. values in ``[0, 1]``.
        """
        # Collapse everything but the batch dimension before the MLP.
        return self.model(img.view(img.size(0), -1))