深度卷积生成对抗网络(DCGAN)|完整代码实现
2024-01-07 18:46:42
生成对抗网络(GAN)由Ian Goodfellow在2014年提出。GAN通过训练两个神经网络解决了非监督问题。这两个网络一个称为生成网络,一个称为判别网络。
事实上,该网络的训练过程很有趣。我们可以借助一个例子来理解。最初,伪造者(生成网络)向警察(判别网络)展示假币,警察识别出币是假的,伪造者根据接收到的反馈制造了新的假币,如此重复多次,直到伪造者可以造出警察无法识别的假币。

在GAN中,也就是最后得到了可以生成和真实图片非常类似的生成网络,以及可以高度识别伪造品的判别网络。
训练的过程就是两个网络互相博弈的过程,最后达到纳什均衡。
而DCGAN则是GAN的一个变体,它在生成网络和判别网络中使用了卷积层和转置卷积层。
代码如下:
import torchimport torchvisionfrom torch import nnfrom torch import optimfrom torchvision import transformsfrom torchvision.datasets import CIFAR10import matplotlib.pyplot as pltlr = 0.0002nz = 100 # noise dimensionimage_size = 64nc = 3 # chanel of imgngf = 64 # generate channelndf = 64 # discriminative channelbeta1 = 0.5BatchSize = 32max_epoch = 2 #transform=transforms.Compose([transforms.Resize(64) ,transforms.ToTensor(),transforms.Normalize([0.5]*3,[0.5]*3)])dataset=CIFAR10(root='cifar10/',transform=transform,download=True)dataloader=torch.utils.data.DataLoader(dataset,BatchSize,shuffle = True)def weights_init(m):classname=m.__class__.__name__if classname.find('Conv')!=-1:m.weight.data.normal_(0.0,0.02)elif classname.find('BatchNorm')!=-1:m.weight.data.normal_(1.0,0.02)m.bias.data.fill_(0)# define modelclass Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.main = nn.Sequential(nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),nn.BatchNorm2d(ngf*8),nn.ReLU(True),nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),nn.BatchNorm2d(ngf*4),nn.ReLU(True),nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),nn.BatchNorm2d(ngf*2),nn.ReLU(True),nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),nn.ConvTranspose2d(ngf,nc,4,2,1,bias=False),nn.Tanh())def forward(self,input):output=self.main(input)return outputnetG=Generator()netG.apply(weights_init)class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(nc,ndf,4,2,1,bias=False),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),nn.BatchNorm2d(ndf*2),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),nn.BatchNorm2d(ndf*4),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),nn.BatchNorm2d(ndf*8),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*8,1,4,1,0,bias=False),nn.Sigmoid())def forward(self,input):output=self.main(input)return output.view(-1,1).squeeze(1)netD=Discriminator()netD.apply(weights_init)# optimizeroptimizerD = optim.Adam(netD.parameters(),lr,betas=(beta1,0.999))optimizerG = optim.Adam(netG.parameters(),lr,betas=(beta1,0.999))# criterioncriterion = nn.BCELoss()fix_noise = torch.randn(BatchSize,nz,1,1).normal_(0,1)if torch.cuda.is_available():fix_noise = fix_noise.cuda()netG.cuda()netD.cuda()criterion.cuda()print('begin training, be patient')for epoch in range(max_epoch):for ii, data in enumerate(dataloader,0):real,_=databatch_size=real.size(0)input=reallabel = torch.ones(batch_size) # 1 for reallabel2 = torch.zeros(batch_size)noise = torch.randn(batch_size,nz,1,1).normal_(0,1)if torch.cuda.is_available:input = input.cuda()label = label.cuda()label2 = label.cuda()noise = noise.cuda()# ----- train netd -----netD.zero_grad()## train netd with real imgoutput=netD(input)errorD_real=criterion(output,label)errorD_real.backward()D_x=output.data.mean()## train netd with fake imgfake_pic=netG(noise)output2=netD(fake_pic.detach())errorD_fake=criterion(output2,label2)errorD_fake.backward()D_x2=output2.data.mean()error_D=errorD_real+errorD_fakeoptimizerD.step()# ------ train netg -------netG.zero_grad()fake_pic=netG(noise)output=netD(fake_pic)error_G=criterion(output,label)error_G.backward()D_G_z2=output.data.mean()optimizerG.step()#生成图片fake_u=netG(fix_noise)imgs = torchvision.utils.make_grid(fake_u*0.5+0.5).cpu()plt.imshow(imgs.permute(1,2,0).numpy())plt.show()
GAN在多个应用领域都取得了许多令人振奋的结果,比如,利用CycleGAN进行图像转换,利用StackGAN自动从文本中制作逼真的图像,利用SRGAN通过预训练模型提高图像品质...
文章来源:https://blog.csdn.net/m0_57569438/article/details/135441364
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!