基础GCN
2023-12-27 16:22:15
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
transform=transforms.Compose([
transforms.ToTensor(), # 0-1归一化;c,h,w,
transforms.Normalize(0.5,0.5)
])
train_ds=torchvision.datasets.MNIST('data',
train=True,
transform=transform,
download=True)
dataloader=torch.utils.data.DataLoader(train_ds,batch_size=512,shuffle=True)
imgs,_=next(iter(dataloader))
imgs.shape
#生成器 输入长度100的正态分布随机分布,输出是 (1,28,28)形状的tensor
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main=nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,28*28),
nn.Tanh()
)
def forward(self,x): # x表示长度为100的noise输入
img=self.main(x)
img=img.view(-1,28,28)
return img
# 判别器 输入是一张图片形状的张量,输出为二分类的概率值,输出使用sigmoid激活 0-1
# pytorch提供的二分类损失函数 BCEloss 计算二分类的交叉熵损失 判别器推荐使用 LeakyReLU 激活函数
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main=nn.Sequential(
nn.Linear(28*28,512),
nn.LeakyReLU(),
nn.Linear(512,256),
nn.LeakyReLU(),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self,x):
x=x.view(-1,28*28)
x=self.main(x)
return x
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen=Generator().to(device)
dis=Discriminator().to(device)
d_optim=torch.optim.Adam(dis.parameters(),lr=0.0001)
g_optim=torch.optim.Adam(gen.parameters(),lr=0.0001)
loss_fn=torch.nn.BCELoss()
def gen_img_plot(model,test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4,4))
for i in range(16):
plt.subplot(4,4,i+1)
plt.imshow((prediction[i]+1)/2)
plt.axis('off')
plt.show()
test_input = torch.randn(16,100,device=device)
D_loss = []
G_loss = []
# 训练循环
for epoch in range(20):
d_epoch_loss = 0
g_epoch_loss = 0
count=len(dataloader)
for step, (img, _) in enumerate(dataloader):
img = img.to(device)
size=img.size(0)
random_noisy=torch.randn(size,100,device=device)
d_optim.zero_grad()
real_output=dis(img) #判别器输入真实图片
d_real_loss = loss_fn(real_output,
torch.ones_like(real_output)) # 判别器在真实图像上的损失
d_real_loss.backward()
gen_img = gen(random_noisy) #生成图像
fake_output = dis(gen_img.detach()) # 判别器输入生成图片,优化目标是判别器,对生成器作梯度截断
d_fake_loss = loss_fn(fake_output,
torch.zeros_like(fake_output)) # 判别器在生成器上的损失
d_fake_loss.backward()
d_loss = d_real_loss + d_fake_loss
d_optim.step()
# 以上是判别器损失和优化,以下是生成器损失和优化
g_optim.zero_grad()
gen_img = gen(random_noisy)
fake_output = dis(gen_img)
g_loss = loss_fn(fake_output,
torch.ones_like(fake_output)) # 生成器的损失
g_loss.backward()
g_optim.step()
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch:',epoch)
gen_img_plot(gen,test_input)
文章来源:https://blog.csdn.net/m0_56294205/article/details/135236350
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!