第G1周:生成对抗网络(GAN)入门
2023-12-21 19:40:57
前期工作
定义超参数:
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
## Create output folders.
# Raw strings: "\G", "\i", "\s", "\d", "\m" are not valid escape sequences and
# raise SyntaxWarning on Python 3.12+ (slated to become an error); the path
# values themselves are unchanged.
# NOTE(review): the training code writes samples to ./images, checkpoints to
# ./generator.pth / ./discriminator.pth and MNIST to ./datasets/, not into
# these folders — confirm which locations are intended.
os.makedirs(r"D:\GAN-Data\images", exist_ok=True)  # meant for training-progress images
os.makedirs(r"D:\GAN-Data\save", exist_ok=True)  # meant for the trained models
os.makedirs(r"D:\GAN-Data\datasets\mnist", exist_ok=True)  # meant for the downloaded dataset

## Hyper-parameters
n_epochs = 50  # passes over the training set
batch_size = 512  # images per mini-batch
lr = 0.0002  # Adam learning rate
b1 = 0.5  # Adam beta1
b2 = 0.999  # Adam beta2
n_cpu = 2  # NOTE(review): not passed to the DataLoader below
latent_dim = 100  # length of the generator's input noise vector
img_size = 28  # images are resized to img_size x img_size
channels = 1  # MNIST is grayscale
sample_interval = 500  # save a grid of generated images every N batches

## Image shape (1, 28, 28) and number of pixels per image (784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

## True when a CUDA device is available
cuda = torch.cuda.is_available()
print(cuda)
下载数据集训练模型(以下代码二选一):
方法一:如果 GPU 驱动程序是最新的,并且与已安装的 CUDA 版本兼容,则在支持 CUDA 的 PyTorch 下运行模型
## Download the MNIST training split.
# ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) rescales them
# to [-1, 1], the same range as the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist = datasets.MNIST(
    root='./datasets/',
    train=True,
    download=True,
    transform=mnist_transform,
)

## Wrap the dataset in a shuffling mini-batch loader.
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
class Discriminator(nn.Module):
    """MLP discriminator that estimates the probability an image is real.

    Each input image is flattened to ``in_features`` values and passed through
    Linear(512) -> LeakyReLU(0.2) -> Linear(256) -> LeakyReLU(0.2) ->
    Linear(1) -> Sigmoid, so the output lies in [0, 1] as nn.BCELoss expects.

    Args:
        in_features: number of values per flattened image. Defaults to the
            module-level ``img_area`` (1 * 28 * 28 = 784 for MNIST).
    """

    def __init__(self, in_features=None):
        super(Discriminator, self).__init__()
        # Fall back to the global hyper-parameter so ``Discriminator()`` still works.
        in_features = img_area if in_features is None else in_features
        self.model = nn.Sequential(
            nn.Linear(int(in_features), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # maps the score to a probability for binary classification
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of real-probabilities for ``img``."""
        # Flatten every sample to a 1-D vector: (batch, in_features).
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to an image.

    Layer widths: z_dim -> 128 -> 256 -> 512 -> 1024 -> prod(out_shape).
    Hidden blocks use BatchNorm1d (except the first) and LeakyReLU(0.2). The
    final Tanh puts every pixel in [-1, 1].

    Args:
        z_dim: length of the input noise vector. Defaults to the module-level
            ``latent_dim``.
        out_shape: (C, H, W) shape of one generated image. Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super(Generator, self).__init__()
        # Fall back to the global hyper-parameters so ``Generator()`` still works.
        z_dim = latent_dim if z_dim is None else z_dim
        out_shape = img_shape if out_shape is None else out_shape
        self.img_shape = tuple(out_shape)
        out_features = int(np.prod(self.img_shape))  # e.g. 1 * 28 * 28 = 784

        def block(in_feat, out_feat, normalize=True):
            """Linear layer, optional BatchNorm1d, then LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's 2nd positional argument is ``eps``,
                # so the original ``BatchNorm1d(out_feat, 0.8)`` means eps=0.8,
                # not momentum. Kept as-is to preserve training behaviour;
                # confirm the intent.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, out_features),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, z_dim) to images (batch, *img_shape)."""
        imgs = self.model(z)
        return imgs.view(imgs.size(0), *self.img_shape)
## Instantiate the generator and the discriminator.
generator = Generator()
discriminator = Discriminator()
## Move the models to the GPU *before* building the optimizers (as the
## torch.optim documentation recommends), so each optimizer is constructed
## from the parameters that will actually be trained.
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
## Loss: binary cross-entropy between D's probability and the 0/1 label.
## BCELoss has no parameters, so it does not need to be moved to the GPU.
criterion = torch.nn.BCELoss()
## Adam optimizers using the ``lr`` hyper-parameter (0.0002).
## betas: decay rates of the running averages of the gradient and its square.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# Run on whatever device the models live on (GPU if they were moved there,
# otherwise CPU) instead of calling .cuda() unconditionally.
device = next(generator.parameters()).device
# save_image() does not create folders; make sure the sample folder exists.
os.makedirs("./images", exist_ok=True)
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # imgs: (B, 1, 28, 28) with B <= batch_size (the last batch can be
        # smaller); flatten to (B, 784). torch.autograd.Variable is a no-op
        # since PyTorch 0.4, so plain tensors are used.
        real_img = imgs.view(imgs.size(0), -1).to(device)
        real_label = torch.ones(imgs.size(0), 1, device=device)  # label 1 = real
        fake_label = torch.zeros(imgs.size(0), 1, device=device)  # label 0 = fake

        # ---------------- Train the discriminator ----------------
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out  # D's probability for real images (closer to 1 is better for D)
        # detach(): cut the graph so this loss does not back-propagate into G.
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out
        loss_D = loss_real_D + loss_fake_D  # real-image loss + fake-image loss
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---------------- Train the generator ----------------
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        fake_img = generator(z)
        output = discriminator(fake_img)
        # G wants D to call its samples real, so the target is real_label.
        # (This backward also fills D's .grad; optimizer_D.zero_grad() clears
        # it before D's next update.)
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Log every 100 batches; .item() turns 1-element tensors into floats.
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(),
                   real_scores.mean().item(), fake_scores.mean().item())
            )
        # Save a 5x5 grid of generated samples every sample_interval batches.
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

# Save the trained weights once training has finished.
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')
方法二:如果系统没有可用的 CUDA 支持,或者您不想使用 GPU 进行计算,可以将模型切换到 CPU 运行。
## Download the MNIST training split.
# ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) rescales them
# to [-1, 1], the same range as the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist = datasets.MNIST(
    root='./datasets/',
    train=True,
    download=True,
    transform=mnist_transform,
)

## Wrap the dataset in a shuffling mini-batch loader.
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
class Discriminator(nn.Module):
    """MLP discriminator that estimates the probability an image is real.

    Each input image is flattened to ``in_features`` values and passed through
    Linear(512) -> LeakyReLU(0.2) -> Linear(256) -> LeakyReLU(0.2) ->
    Linear(1) -> Sigmoid, so the output lies in [0, 1] as nn.BCELoss expects.

    Args:
        in_features: number of values per flattened image. Defaults to the
            module-level ``img_area`` (1 * 28 * 28 = 784 for MNIST).
    """

    def __init__(self, in_features=None):
        super(Discriminator, self).__init__()
        # Fall back to the global hyper-parameter so ``Discriminator()`` still works.
        in_features = img_area if in_features is None else in_features
        self.model = nn.Sequential(
            nn.Linear(int(in_features), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # maps the score to a probability for binary classification
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of real-probabilities for ``img``."""
        # Flatten every sample to a 1-D vector: (batch, in_features).
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to an image.

    Layer widths: z_dim -> 128 -> 256 -> 512 -> 1024 -> prod(out_shape).
    Hidden blocks use BatchNorm1d (except the first) and LeakyReLU(0.2). The
    final Tanh puts every pixel in [-1, 1].

    Args:
        z_dim: length of the input noise vector. Defaults to the module-level
            ``latent_dim``.
        out_shape: (C, H, W) shape of one generated image. Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super(Generator, self).__init__()
        # Fall back to the global hyper-parameters so ``Generator()`` still works.
        z_dim = latent_dim if z_dim is None else z_dim
        out_shape = img_shape if out_shape is None else out_shape
        self.img_shape = tuple(out_shape)
        out_features = int(np.prod(self.img_shape))  # e.g. 1 * 28 * 28 = 784

        def block(in_feat, out_feat, normalize=True):
            """Linear layer, optional BatchNorm1d, then LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's 2nd positional argument is ``eps``,
                # so the original ``BatchNorm1d(out_feat, 0.8)`` means eps=0.8,
                # not momentum. Kept as-is to preserve training behaviour;
                # confirm the intent.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, out_features),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, z_dim) to images (batch, *img_shape)."""
        imgs = self.model(z)
        return imgs.view(imgs.size(0), *self.img_shape)
## Build both networks and keep them on the CPU (done before the optimizers
## are created, so they reference the parameters that will be trained).
generator = Generator().cpu()
discriminator = Discriminator().cpu()
## Binary cross-entropy between D's probability and the 0/1 label.
criterion = torch.nn.BCELoss()
## One Adam optimizer per network, using the ``lr`` hyper-parameter (0.0002);
## betas = (b1, b2) are the decay rates of the running averages of the
## gradient and of its square.
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(b1, b2),
)
# Run on whatever device the models live on (the CPU in this variant).
device = next(generator.parameters()).device
# save_image() does not create folders; make sure the sample folder exists.
os.makedirs("./images", exist_ok=True)
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # imgs: (B, 1, 28, 28) with B <= batch_size (the last batch can be
        # smaller); flatten to (B, 784). torch.autograd.Variable is a no-op
        # since PyTorch 0.4, so plain tensors are used.
        real_img = imgs.view(imgs.size(0), -1).to(device)
        real_label = torch.ones(imgs.size(0), 1, device=device)  # label 1 = real
        fake_label = torch.zeros(imgs.size(0), 1, device=device)  # label 0 = fake

        # ---------------- Train the discriminator ----------------
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out  # D's probability for real images (closer to 1 is better for D)
        # detach(): cut the graph so this loss does not back-propagate into G.
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out
        loss_D = loss_real_D + loss_fake_D  # real-image loss + fake-image loss
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---------------- Train the generator ----------------
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        fake_img = generator(z)
        output = discriminator(fake_img)
        # G wants D to call its samples real, so the target is real_label.
        # (This backward also fills D's .grad; optimizer_D.zero_grad() clears
        # it before D's next update.)
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Log every 100 batches; .item() turns 1-element tensors into floats.
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(),
                   real_scores.mean().item(), fake_scores.mean().item())
            )
        # Save a 5x5 grid of generated samples every sample_interval batches.
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

# Save the trained weights once training has finished.
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')
[Epoch 0/50] [Batch 99/118] [D loss: 1.348898] [G loss: 0.768190] [D real: 0.680194] [D fake: 0.608467]
[Epoch 1/50] [Batch 99/118] [D loss: 1.088056] [G loss: 0.825508] [D real: 0.743123] [D fake: 0.538369]
[Epoch 2/50] [Batch 99/118] [D loss: 1.114249] [G loss: 1.370899] [D real: 0.772192] [D fake: 0.571547]
[Epoch 3/50] [Batch 99/118] [D loss: 1.074770] [G loss: 1.163773] [D real: 0.687619] [D fake: 0.492250]
[Epoch 4/50] [Batch 99/118] [D loss: 1.126985] [G loss: 0.981320] [D real: 0.586216] [D fake: 0.420572]
[Epoch 5/50] [Batch 99/118] [D loss: 1.402424] [G loss: 0.648313] [D real: 0.352514] [D fake: 0.158908]
[Epoch 6/50] [Batch 99/118] [D loss: 1.128472] [G loss: 1.799203] [D real: 0.806444] [D fake: 0.586730]
[Epoch 7/50] [Batch 99/118] [D loss: 1.066108] [G loss: 1.737643] [D real: 0.764517] [D fake: 0.531414]
[Epoch 8/50] [Batch 99/118] [D loss: 1.162140] [G loss: 1.896749] [D real: 0.797096] [D fake: 0.600104]
[Epoch 9/50] [Batch 99/118] [D loss: 0.931134] [G loss: 1.207216] [D real: 0.600548] [D fake: 0.271762]
[Epoch 10/50] [Batch 99/118] [D loss: 0.906784] [G loss: 1.635424] [D real: 0.649781] [D fake: 0.306829]
[Epoch 11/50] [Batch 99/118] [D loss: 1.128253] [G loss: 0.814038] [D real: 0.455113] [D fake: 0.181365]
[Epoch 12/50] [Batch 99/118] [D loss: 0.656877] [G loss: 2.148012] [D real: 0.799558] [D fake: 0.318620]
[Epoch 13/50] [Batch 99/118] [D loss: 0.758273] [G loss: 1.781502] [D real: 0.849074] [D fake: 0.437607]
[Epoch 14/50] [Batch 99/118] [D loss: 0.982824] [G loss: 2.315076] [D real: 0.795012] [D fake: 0.504426]
[Epoch 15/50] [Batch 99/118] [D loss: 0.846314] [G loss: 1.125144] [D real: 0.594142] [D fake: 0.150700]
[Epoch 16/50] [Batch 99/118] [D loss: 0.788453] [G loss: 1.134926] [D real: 0.598429] [D fake: 0.113131]
[Epoch 17/50] [Batch 99/118] [D loss: 0.860472] [G loss: 1.416159] [D real: 0.554753] [D fake: 0.070943]
[Epoch 18/50] [Batch 99/118] [D loss: 0.729715] [G loss: 2.033889] [D real: 0.813916] [D fake: 0.380846]
[Epoch 19/50] [Batch 99/118] [D loss: 0.699210] [G loss: 2.655535] [D real: 0.845672] [D fake: 0.384237]
[Epoch 20/50] [Batch 99/118] [D loss: 0.608509] [G loss: 1.670838] [D real: 0.758573] [D fake: 0.230607]
[Epoch 21/50] [Batch 99/118] [D loss: 0.669346] [G loss: 2.555538] [D real: 0.817196] [D fake: 0.330501]
[Epoch 22/50] [Batch 99/118] [D loss: 0.811412] [G loss: 3.608017] [D real: 0.880692] [D fake: 0.466917]
[Epoch 23/50] [Batch 99/118] [D loss: 0.879888] [G loss: 1.472922] [D real: 0.610781] [D fake: 0.124722]
[Epoch 24/50] [Batch 99/118] [D loss: 0.767168] [G loss: 3.407905] [D real: 0.906761] [D fake: 0.470930]
[Epoch 25/50] [Batch 99/118] [D loss: 0.534345] [G loss: 2.263444] [D real: 0.890575] [D fake: 0.311338]
[Epoch 26/50] [Batch 99/118] [D loss: 0.473837] [G loss: 1.867095] [D real: 0.807679] [D fake: 0.173667]
[Epoch 27/50] [Batch 99/118] [D loss: 0.672992] [G loss: 2.960940] [D real: 0.846083] [D fake: 0.356172]
[Epoch 28/50] [Batch 99/118] [D loss: 0.726250] [G loss: 2.020569] [D real: 0.658650] [D fake: 0.034624]
[Epoch 29/50] [Batch 99/118] [D loss: 0.503680] [G loss: 2.267217] [D real: 0.826359] [D fake: 0.216285]
[Epoch 30/50] [Batch 99/118] [D loss: 0.987975] [G loss: 1.588039] [D real: 0.544412] [D fake: 0.043705]
[Epoch 31/50] [Batch 99/118] [D loss: 1.162546] [G loss: 2.823585] [D real: 0.729459] [D fake: 0.494907]
[Epoch 32/50] [Batch 99/118] [D loss: 0.924303] [G loss: 1.293745] [D real: 0.582127] [D fake: 0.117892]
[Epoch 33/50] [Batch 99/118] [D loss: 0.747387] [G loss: 2.206166] [D real: 0.797877] [D fake: 0.343705]
[Epoch 34/50] [Batch 99/118] [D loss: 0.623693] [G loss: 3.111738] [D real: 0.898811] [D fake: 0.381497]
[Epoch 35/50] [Batch 99/118] [D loss: 0.567340] [G loss: 2.021876] [D real: 0.757147] [D fake: 0.179998]
[Epoch 36/50] [Batch 99/118] [D loss: 0.727314] [G loss: 1.915004] [D real: 0.755489] [D fake: 0.287818]
[Epoch 37/50] [Batch 99/118] [D loss: 0.826854] [G loss: 1.472841] [D real: 0.674383] [D fake: 0.238948]
[Epoch 38/50] [Batch 99/118] [D loss: 1.143365] [G loss: 0.757286] [D real: 0.489352] [D fake: 0.069917]
[Epoch 39/50] [Batch 99/118] [D loss: 0.818748] [G loss: 1.114080] [D real: 0.601727] [D fake: 0.127117]
[Epoch 40/50] [Batch 99/118] [D loss: 0.918430] [G loss: 1.276388] [D real: 0.629529] [D fake: 0.249412]
[Epoch 41/50] [Batch 99/118] [D loss: 0.727234] [G loss: 1.541735] [D real: 0.718813] [D fake: 0.211931]
[Epoch 42/50] [Batch 99/118] [D loss: 0.979106] [G loss: 0.957361] [D real: 0.568877] [D fake: 0.127781]
[Epoch 43/50] [Batch 99/118] [D loss: 0.683977] [G loss: 1.902616] [D real: 0.765684] [D fake: 0.275655]
[Epoch 44/50] [Batch 99/118] [D loss: 0.681833] [G loss: 2.164286] [D real: 0.775665] [D fake: 0.293220]
[Epoch 45/50] [Batch 99/118] [D loss: 0.762346] [G loss: 1.543463] [D real: 0.613166] [D fake: 0.084738]
[Epoch 46/50] [Batch 99/118] [D loss: 0.780659] [G loss: 1.477143] [D real: 0.697691] [D fake: 0.234303]
[Epoch 47/50] [Batch 99/118] [D loss: 0.709177] [G loss: 1.837770] [D real: 0.750658] [D fake: 0.254356]
[Epoch 48/50] [Batch 99/118] [D loss: 0.884956] [G loss: 2.488509] [D real: 0.832655] [D fake: 0.457649]
[Epoch 49/50] [Batch 99/118] [D loss: 0.990627] [G loss: 4.116466] [D real: 0.913515] [D fake: 0.563187]
文章来源:https://blog.csdn.net/qq_60245590/article/details/135022925
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!