使用残差网络resnet与WGAN制作一个生成二次元人物头像的GAN(pytorch)

a GAN using Wasserstein loss and resnet to generate anime pics.

一个resnet-WGAN用于生成各种二次元头像(你也可以使用别的图像数据集,用于生成图片)

@本项目用于深度学习中的学习交流,如有任何问题,欢迎联系我!联系方式QQ:741533684

#我使用了残差模块设计了两个相对对称的残差网络,分别作为生成对抗网络的生成器与判别器,基本原理其实与DCGAN类似。在此基础上,使用了不同于Binary cross entropy loss的Wasserstein loss, 并将优化器从Adam修改为RMSprop(注意:Adam会导致训练不稳定,建议使用RMSprop或者SGD,且学习率不能太大,最好使用学习率衰减。)

之后我会上传我训练的模型,以供大家使用作为预训练模型。

train.py文件重要代码如下:


def weights_init(m):  # 初始化模型权重
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Returns None if the optimizer has no param groups (matches the
    fall-through behavior of iterating an empty list).
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

# Command-line options for training.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=96)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)  # generator feature width (not read below in this snippet)
parser.add_argument('--ndf', type=int, default=64)  # discriminator feature width (not read below in this snippet)
parser.add_argument('--epoch', type=int, default=500, help='number of epochs to train for')
parser.add_argument('--lrd', type=float, default=5e-5,
                    help="Discriminator's learning rate, default=0.00005")  # Discriminator's learning rate
parser.add_argument('--lrg', type=float, default=5e-5,
                    help="Generator's learning rate, default=0.00005")  # Generator's learning rate
parser.add_argument('--data_path', default='data/', help='folder to train data')  # place the training dataset here
parser.add_argument('--outf', default='imgv3/',
                    help='folder to output images and model checkpoints')  # generated images and checkpoints are saved here
opt = parser.parse_args()
# Select GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image loading and preprocessing.
# NOTE(review): this name shadows torchvision.transforms — works here, but fragile.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(opt.imageSize),  # resizes the shorter side; presumably inputs are square 96x96 — TODO confirm
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])  # map pixels to [-1, 1], matching the generator's Tanh output

# ImageFolder expects class subdirectories under data_path; labels are ignored in the loop below.
dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    drop_last=True,  # keep every batch exactly batchSize — the noise tensor below is sized with opt.batchSize
)
# Build the generator and discriminator and report their parameter counts.
netG = NetG().to(device)
netG.apply(weights_init)
print('Generator:' )
print(sum(p.numel() for p in netG.parameters()))

netD = NetD().to(device)
netD.apply(weights_init)
print('Discriminator')
print(sum(p.numel() for p in netD.parameters()))

print(dataset)

# Resume from a pretrained checkpoint.
# NOTE(review): paths and epoch are hard-coded — keep in sync with start_epoch below.
netG.load_state_dict(torch.load('imgv2.5/netG_0280.pth', map_location=device))  # load pretrained generator weights
netD.load_state_dict(torch.load('imgv2.5/netD_0280.pth', map_location=device))  # load pretrained discriminator weights
criterionG = Hinge()
# RMSprop is used instead of Adam for WGAN-style training stability (see notes above).
optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.lrg)
optimizerD = torch.optim.RMSprop(netD.parameters(), lr=opt.lrd)
# Decay both learning rates by 8% every 5 epochs.
lrd_scheduler    = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=5, gamma=0.92)
lrg_scheduler    = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=5, gamma=0.92)
criterionD = Hinge()
criterion = torch.nn.BCELoss()  # NOTE(review): unused leftover from a BCE/DCGAN variant
label = torch.FloatTensor(opt.batchSize)  # NOTE(review): unused in the hinge-loss loop below
real_label = 1
fake_label = 0
total_lossD = 0.0  # running per-epoch discriminator loss
total_lossG = 0.0  # running per-epoch generator loss
label = label.unsqueeze(1)

start_epoch = 280 # resume epoch — must match the checkpoint loaded above
for epoch in range(start_epoch + 1, opt.epoch + 1):
    with tqdm(total=len(dataloader), desc=f'Epoch {epoch}/{opt.epoch}', postfix=dict, mininterval=0.3) as pbar:
        for i, (imgs, _) in enumerate(dataloader):
            # --- Train the discriminator D with the generator G fixed ---
            # D should score real images high
            imgs = imgs.to(device)
            # for k in range(1,5):
            outputreal = netD(imgs)

            optimizerD.zero_grad()
            ## D should score fake images low
            # label.data.fill_(fake_label)
            noise = torch.randn(opt.batchSize, opt.nz)
            # noise = torch.randn(opt.batchSize, opt.nz)
            noise = noise.to(device)

            fake = netG(noise)  # generate a batch of fake images
            outputfake = netD(fake.detach())  # detach so no gradient flows into G (G is not updated in this step)
            lossD = criterionD(outputreal, outputfake)
            total_lossD += lossD.item()
            lossD.backward()
            optimizerD.step()
            # --- Train the generator G with the discriminator D fixed ---
            optimizerG.zero_grad()
            # G wants the (just-updated) D to score its fakes as real

            output = netD(fake)
            lossG = criterionG(output)
            total_lossG += lossG.item()
            lossG.backward()
            optimizerG.step()
            # print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f'% (epoch, opt.epoch, i, len(dataloader), lossD.item(), lossG.item()))
            pbar.set_postfix(**{'total_lossD': total_lossD / (i + 1),
                                'lrd':get_lr(optimizerD), 'total_lossG': total_lossG / (i + 1), 'lrg': get_lr(optimizerG)})
            pbar.update(1)

    # Per-epoch bookkeeping: decay learning rates, dump a sample grid, append losses to the log.
    lrg_scheduler.step()
    lrd_scheduler.step()
    vutils.save_image(fake.data,
                      '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)
    log = open("./log.txt", 'a')
    print('[%d/%d] total_Loss_D: %.3f total_Loss_G %.3f' % (
    epoch, opt.epoch, total_lossD / (len(dataloader)), total_lossG / (len(dataloader))),
          file=log)
    total_lossG = 0.0
    total_lossD = 0.0
    log.close()
    if epoch % 5 == 0:  # save model checkpoints every 5 epochs
        torch.save(netG.state_dict(), '%s/netG_%04d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(), '%s/netD_%04d.pth' % (opt.outf, epoch))

以下是残差模块,residual block的定义:


class BasicBlock(nn.Module):
    """Channel-preserving residual block: 1x1 conv (expand x2) -> LeakyReLU
    -> 3x3 conv (back to in1) -> LeakyReLU, plus an identity skip.

    Note: bn1/bn2 are instantiated (kept so checkpoints and weight init
    line up) but are intentionally not applied in forward().
    """

    def __init__(self, in1):
        super(BasicBlock, self).__init__()
        # 1x1 pointwise conv doubling the channel count.
        self.conv1 = nn.Conv2d(in1, 2 * in1, kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(2 * in1)
        self.relu1 = nn.LeakyReLU(0.2)
        # 3x3 conv restoring the original channel count; padding keeps H/W.
        self.conv2 = nn.Conv2d(2 * in1, in1, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in1)
        self.relu2 = nn.LeakyReLU(0.2)

    def forward(self, x):
        branch = self.relu1(self.conv1(x))
        branch = self.relu2(self.conv2(branch))
        # Identity shortcut: output shape equals input shape.
        return branch + x

损失函数使用了wasserstein loss,相比于BCEloss(JS距离),它能更准确地衡量生成器产生图片的质量;而Hinge loss相对于W loss,能缓解W loss梯度爆炸导致的训练不稳定问题。

class Wasserstein(nn.Module):
    """Wasserstein (critic-score) adversarial loss.

    Called with two arguments it returns the critic loss
    ``mean(fake) - mean(real)``; called with one argument it returns the
    generator loss ``-mean(pred)``.
    """

    def forward(self, pred_real, pred_fake=None):
        if pred_fake is None:
            # Generator step: maximize the critic's score on fakes.
            return -pred_real.mean()
        # Critic step: separate real scores (high) from fake scores (low).
        return pred_fake.mean() - pred_real.mean()


class Hinge(nn.Module):
    """Hinge adversarial loss (compared with plain Wasserstein loss, the
    margins bound the gradients and avoid blow-ups).

    Two arguments -> critic loss ``mean(relu(1-real)) + mean(relu(1+fake))``;
    one argument -> generator loss ``-mean(pred)``.
    """

    def forward(self, pred_real, pred_fake=None):
        if pred_fake is None:
            # Generator step: push critic scores on fakes upward.
            return -pred_real.mean()
        # Critic step: hinge margins on real and fake scores.
        real_term = torch.clamp(1 - pred_real, min=0).mean()
        fake_term = torch.clamp(1 + pred_fake, min=0).mean()
        return real_term + fake_term

所使用的判别器代码:


class RestNet18(nn.Module):
    """ResNet-style discriminator/critic for 96x96 RGB images.

    Four AvgPool2d(kernel=3, stride=2) stages shrink 96 -> 47 -> 23 -> 11 -> 5,
    so the final feature map is 64*5*5 = 1600 values, flattened into a
    single-scalar linear head.  Output is squashed through sigmoid to (0, 1).

    NOTE(review): a sigmoid-bounded output is unusual for a WGAN/hinge critic
    (those losses expect an unbounded score) — confirm this is intentional.
    """

    def __init__(self):
        super(RestNet18, self).__init__()
        # Stem: 3 -> 64 channels, spatial size unchanged (3x3, pad 1).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

        self.layer1 = nn.Sequential(
            BasicBlock(64),
            nn.AvgPool2d(3, 2),  # 96 -> 47
            BasicBlock(64),
            BasicBlock(64),
        )

        self.layer2 = nn.Sequential(
            nn.AvgPool2d(3, 2),  # 47 -> 23
            BasicBlock(64),
            BasicBlock(64),
        )

        self.layer3 = nn.Sequential(
            nn.AvgPool2d(3, 2),  # 23 -> 11
            BasicBlock(64),
            BasicBlock(64),
        )

        self.layer4 = nn.Sequential(
            nn.AvgPool2d(3, 2),  # 11 -> 5
            BasicBlock(64),
            BasicBlock(64),
            #  nn.LayerNorm([64,5,5]),
        )

        self.layer5 = nn.Sequential(
            nn.BatchNorm2d(64),
            #   nn.LayerNorm([64,5,5]),
            nn.ReLU(True),
        )
        self.fc = nn.Sequential(
            nn.Linear(1600, 1),  # 64 channels * 5 * 5 spatial
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = torch.flatten(out, start_dim=1)  # (N, 64, 5, 5) -> (N, 1600)
        out = self.fc(out)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid.
        out = torch.sigmoid(out)
        return out

生成器代码如下:


class Generator(nn.Module):
    """ResNet-style generator: latent vector z -> 96x96 RGB image in [-1, 1].

    The linear layer maps z to a (64, 3, 3) feature map; five nearest-neighbor
    x2 upsamplings then grow it 3 -> 6 -> 12 -> 24 -> 48 -> 96.

    NOTE(review): `nz` is read from module scope at construction time —
    confirm it matches opt.nz used to size the noise batches.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(nz, 64 * 3 * 3)
        self.layer1 = nn.Sequential(
            BasicBlock(64),
            nn.UpsamplingNearest2d(scale_factor=2),  # 3 -> 6
            BasicBlock(64),
            nn.UpsamplingNearest2d(scale_factor=2),  # 6 -> 12
        )
        self.layer2 = nn.Sequential(
            BasicBlock(64),
            nn.UpsamplingNearest2d(scale_factor=2),  # 12 -> 24
            BasicBlock(64),
            nn.UpsamplingNearest2d(scale_factor=2),  # 24 -> 48
        )
        self.layer3 = nn.Sequential(
            BasicBlock(64),
            BasicBlock(64),
            nn.UpsamplingNearest2d(scale_factor=2),  # 48 -> 96
        )
        self.layer4 = nn.Sequential(
            BasicBlock(64),
        )
        self.Conv = nn.Sequential(
            BasicBlock(64),
            nn.BatchNorm2d(64),
            #  nn.LayerNorm([64,96,96]),
            nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1),  # 64 -> 3 channels
            nn.Tanh(),  # output range [-1, 1], matching the Normalize(0.5, 0.5) preprocessing
        )

    def forward(self, z):
        x = self.linear(z)
        # Fix: derive the batch size from the input instead of a global
        # `batch_size` name (undefined here, and it would break for any
        # batch of a different size).
        x = x.view(z.size(0), 64, 3, 3)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.Conv(x)
        return x

所有代码在我的Github中已经上传:

https://github.com/rabbitdeng/anime-WGAN-resnet-pytorch

readme.md中有所使用的数据集的百度云盘链接!

你可能感兴趣的