深度卷积生成对抗网络(DCGAN)|完整代码实现
2024-01-07 18:46:42
生成对抗网络(GAN)由Ian Goodfellow在2014年提出。GAN通过训练两个神经网络解决了非监督问题。这两个网络一个称为生成网络,一个称为判别网络。
事实上,该网络的训练过程很有趣。我们可以借助一个例子来理解。最初,伪造者(生成网络)向警察(判别网络)展示假币,警察识别出币是假的,伪造者根据接收到的反馈制造了新的假币,如此重复多次,直到伪造者可以造出警察无法识别的假币。
在GAN中,也就是最后得到了可以生成和真实图片非常类似的生成网络,以及可以高度识别伪造品的判别网络。
训练的过程就是两个网络互相博弈的过程,最后达到纳什均衡。
而DCGAN则是GAN的一个变体,它在生成网络和判别网络中使用了卷积层和转置卷积层。
代码如下:
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
lr = 0.0002
nz = 100 # noise dimension
image_size = 64
nc = 3 # chanel of img
ngf = 64 # generate channel
ndf = 64 # discriminative channel
beta1 = 0.5
BatchSize = 32
max_epoch = 2 #
transform=transforms.Compose([
transforms.Resize(64) ,
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3)
])
dataset=CIFAR10(root='cifar10/',transform=transform,download=True)
dataloader=torch.utils.data.DataLoader(dataset,BatchSize,shuffle = True)
def weights_init(m):
classname=m.__class__.__name__
if classname.find('Conv')!=-1:
m.weight.data.normal_(0.0,0.02)
elif classname.find('BatchNorm')!=-1:
m.weight.data.normal_(1.0,0.02)
m.bias.data.fill_(0)
# define model
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
nn.BatchNorm2d(ngf*8),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
nn.BatchNorm2d(ngf*4),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
nn.BatchNorm2d(ngf*2),
nn.ReLU(True),
nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
nn.ConvTranspose2d(ngf,nc,4,2,1,bias=False),
nn.Tanh()
)
def forward(self,input):
output=self.main(input)
return output
netG=Generator()
netG.apply(weights_init)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main = nn.Sequential(
nn.Conv2d(nc,ndf,4,2,1,bias=False),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
nn.BatchNorm2d(ndf*8),
nn.LeakyReLU(0.2,inplace=True),
nn.Conv2d(ndf*8,1,4,1,0,bias=False),
nn.Sigmoid()
)
def forward(self,input):
output=self.main(input)
return output.view(-1,1).squeeze(1)
netD=Discriminator()
netD.apply(weights_init)
# optimizer
optimizerD = optim.Adam(netD.parameters(),lr,betas=(beta1,0.999))
optimizerG = optim.Adam(netG.parameters(),lr,betas=(beta1,0.999))
# criterion
criterion = nn.BCELoss()
fix_noise = torch.randn(BatchSize,nz,1,1).normal_(0,1)
if torch.cuda.is_available():
fix_noise = fix_noise.cuda()
netG.cuda()
netD.cuda()
criterion.cuda()
print('begin training, be patient')
for epoch in range(max_epoch):
for ii, data in enumerate(dataloader,0):
real,_=data
batch_size=real.size(0)
input=real
label = torch.ones(batch_size) # 1 for real
label2 = torch.zeros(batch_size)
noise = torch.randn(batch_size,nz,1,1).normal_(0,1)
if torch.cuda.is_available:
input = input.cuda()
label = label.cuda()
label2 = label.cuda()
noise = noise.cuda()
# ----- train netd -----
netD.zero_grad()
## train netd with real img
output=netD(input)
errorD_real=criterion(output,label)
errorD_real.backward()
D_x=output.data.mean()
## train netd with fake img
fake_pic=netG(noise)
output2=netD(fake_pic.detach())
errorD_fake=criterion(output2,label2)
errorD_fake.backward()
D_x2=output2.data.mean()
error_D=errorD_real+errorD_fake
optimizerD.step()
# ------ train netg -------
netG.zero_grad()
fake_pic=netG(noise)
output=netD(fake_pic)
error_G=criterion(output,label)
error_G.backward()
D_G_z2=output.data.mean()
optimizerG.step()
#生成图片
fake_u=netG(fix_noise)
imgs = torchvision.utils.make_grid(fake_u*0.5+0.5).cpu()
plt.imshow(imgs.permute(1,2,0).numpy())
plt.show()
GAN在多个应用领域都取得了许多令人振奋的结果,比如,利用CycleGAN进行图像转换,利用StackGAN自动从文本中制作逼真的图像,利用SRGAN通过预训练模型提高图像品质...
文章来源:https://blog.csdn.net/m0_57569438/article/details/135441364
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!