生成对抗网络(Generative Adversarial Network, GAN)是一种通过对抗训练生成数据的深度学习模型,由生成器(Generator)和判别器(Discriminator)两部分组成,其核心思想源于博弈论中的零和博弈。
一、核心组成
生成器(G)
目标:生成逼真的假数据(如图像、文本),试图欺骗判别器。
输入:随机噪声(通常服从高斯分布或均匀分布)。
输出:合成数据(如假图像)。
判别器(D)
目标:区分真实数据(来自训练集)和生成器合成的假数据。
输出:概率值(0到1),表示输入数据是真实的概率。
二、关于对抗训练
1. 动态博弈
1)生成器尝试生成越来越逼真的数据,使得判别器无法区分真假。
2)判别器则不断优化自身,以更准确地区分真假数据。
3)两者交替训练,最终达到纳什均衡(生成器生成的数据与真实数据分布一致,判别器无法区分,输出概率恒为0.5)。
2. 优化目标(极小极大博弈)
\[\min_{G}{\max_D}V(D,G)=E_{x\sim p_{data}}[logD(x)]+E_{z\sim p_z}[log(1-D(G(z)))]\]
其中,
\(D(x)\):判别器对真实数据的判别结果;
\(G(z)\):生成器生成的假数据;
判别器希望最大化\(V(D,G)\)(正确分类真假数据);
生成器希望最小化\(V(D,G)\)(让判别器无法区分)。
3.交替更新
1) 固定生成器,训练判别器:
用真实数据(标签1)和生成数据(标签0)训练判别器,提高其鉴别能力。
2) 固定判别器,训练生成器:
通过反向传播调整生成器参数,使得判别器对生成数据的输出概率接近1(即欺骗判别器)。
三、典型应用
图像生成:生成逼真的人脸、风景、艺术画(如 DCGAN、StyleGAN);
图像编辑:图像修复(填补缺失区域)、风格迁移(如将照片转为油画风格);
数据增强:为小样本任务生成额外的训练数据;
超分辨率重建:将低分辨率图像恢复为高分辨率图像。
四、优势与挑战
优势
无监督学习:无需对数据进行标注,仅通过真实数据即可训练(适用于标注成本高的场景)。
生成高质量数据:相比其他生成模型(如变分自编码器 VAE),GAN 在图像生成等任务中往往能生成更逼真、细节更丰富的数据。
灵活性:生成器和判别器可以采用不同的网络结构(如卷积神经网络 CNN、循环神经网络 RNN 等),适用于多种数据类型(图像、文本、音频等)。
挑战
训练不稳定:容易出现 “模式崩溃”(生成器只生成少数几种相似数据,缺乏多样性)或难以收敛;
平衡难题:生成器和判别器的能力需要匹配,否则可能一方过强导致另一方无法学习(如判别器太弱,生成器无需优化即可欺骗它);
可解释性差:生成器的内部工作机制难以解释,生成结果的可控性较弱(近年通过改进模型如 StyleGAN 缓解了这一问题)。
五、Python示例
使用 PyTorch 实现简单 的GAN 模型,生成手写数字图像。- import matplotlib
- matplotlib.use('TkAgg')
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torchvision import datasets, transforms
- from torch.utils.data import DataLoader
- import matplotlib.pyplot as plt
- import numpy as np
- plt.rcParams['font.sans-serif']=['SimHei'] # 中文支持
- plt.rcParams['axes.unicode_minus']=False # 负号显示
- # 设置随机种子,确保结果可复现
- torch.manual_seed(42)
- np.random.seed(42)
- # 定义设备
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # 数据加载和预处理
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.5,), (0.5,)) # 将图像归一化到 [-1, 1]
- ])
- train_dataset = datasets.MNIST(root='./data', train=True,
- download=True, transform=transform)
- train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
- # 定义生成器网络
- class Generator(nn.Module):
- def __init__(self, latent_dim=100, img_dim=784):
- super(Generator, self).__init__()
- self.model = nn.Sequential(
- nn.Linear(latent_dim, 256),
- nn.LeakyReLU(0.2),
- nn.BatchNorm1d(256),
- nn.Linear(256, 512),
- nn.LeakyReLU(0.2),
- nn.BatchNorm1d(512),
- nn.Linear(512, img_dim),
- nn.Tanh() # 输出范围 [-1, 1]
- )
- def forward(self, z):
- return self.model(z).view(z.size(0), 1, 28, 28)
- # 定义判别器网络
- class Discriminator(nn.Module):
- def __init__(self, img_dim=784):
- super(Discriminator, self).__init__()
- self.model = nn.Sequential(
- nn.Linear(img_dim, 512),
- nn.LeakyReLU(0.2),
- nn.Dropout(0.3),
- nn.Linear(512, 256),
- nn.LeakyReLU(0.2),
- nn.Dropout(0.3),
- nn.Linear(256, 1),
- nn.Sigmoid() # 输出概率值
- )
- def forward(self, img):
- img_flat = img.view(img.size(0), -1)
- return self.model(img_flat)
- # 初始化模型
- latent_dim = 100
- generator = Generator(latent_dim).to(device)
- discriminator = Discriminator().to(device)
- # 定义损失函数和优化器
- criterion = nn.BCELoss()
- lr = 0.0002
- g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
- d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
- # 训练函数
- def train_gan(epochs):
- for epoch in range(epochs):
- for i, (real_imgs, _) in enumerate(train_loader):
- batch_size = real_imgs.size(0)
- real_imgs = real_imgs.to(device)
- # 创建标签
- real_labels = torch.ones(batch_size, 1).to(device)
- fake_labels = torch.zeros(batch_size, 1).to(device)
- # ---------------------
- # 训练判别器
- # ---------------------
- d_optimizer.zero_grad()
- # 计算判别器对真实图像的损失
- real_pred = discriminator(real_imgs)
- d_real_loss = criterion(real_pred, real_labels)
- # 生成假图像
- z = torch.randn(batch_size, latent_dim).to(device)
- fake_imgs = generator(z)
- # 计算判别器对假图像的损失
- fake_pred = discriminator(fake_imgs.detach())
- d_fake_loss = criterion(fake_pred, fake_labels)
- # 总判别器损失
- d_loss = d_real_loss + d_fake_loss
- d_loss.backward()
- d_optimizer.step()
- # ---------------------
- # 训练生成器
- # ---------------------
- g_optimizer.zero_grad()
- # 生成假图像
- fake_imgs = generator(z)
- # 计算判别器对假图像的预测
- fake_pred = discriminator(fake_imgs)
- # 生成器希望判别器将假图像判断为真
- g_loss = criterion(fake_pred, real_labels)
- g_loss.backward()
- g_optimizer.step()
- # 打印训练进度
- if i % 100 == 0:
- print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "
- f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")
- # 每个epoch结束后,生成一些样本图像
- if (epoch + 1) % 10 == 0:
- generate_samples(generator, epoch + 1, latent_dim, device)
- # 生成样本图像
- def generate_samples(generator, epoch, latent_dim, device, n_samples=16):
- generator.eval()
- z = torch.randn(n_samples, latent_dim).to(device)
- with torch.no_grad():
- samples = generator(z).cpu()
- # 可视化生成的样本
- fig, axes = plt.subplots(4, 4, figsize=(8, 8))
- for i, ax in enumerate(axes.flatten()):
- ax.imshow(samples[i][0].numpy(), cmap='gray')
- ax.axis('off')
- plt.tight_layout()
- plt.savefig(f"gan_samples/gan_samples_epoch_{epoch}.png")
- plt.close()
- generator.train()
- # 训练模型
- train_gan(epochs=50)
- # 生成最终样本
- generate_samples(generator, "final", latent_dim, device)
复制代码 最终生成的样本:
六、小结
GAN通过对抗机制实现了强大的生成能力,成为生成模型领域的里程碑技术。衍生变体(如CGAN、CycleGAN等)进一步扩展了其应用场景。
End.
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |