import torch
import torch.nn as nn
import torch.optim as optim
from torch import device
from torchvision import datasets, transforms
这段代码导入了 PyTorch 相关的核心库以及用于处理图像数据的torchvision库中的部分模块。torch是 PyTorch 的基础库,nn用于构建神经网络模型,optim用于定义优化器,device用于指定计算设备(如 CPU 或 GPU),datasets和transforms用于加载和预处理图像数据。
# 定义超参数
latent_dim = 100 # 噪声向量维度
image_size = # 图像大小
channels = 3 # 图像通道数(RGB)
batch_size = 128
num_epochs = 100
learning_rate = 0.0002
这里定义了一系列超参数,包括生成对抗网络(GAN)中噪声向量的维度、生成图像的大小、图像的通道数(RGB 图像为 3 通道)、每次训练的批量大小、训练的轮数以及学习率。这些超参数将在后续的模型训练和生成过程中起到关键作用。
# 数据预处理与加载
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.ImageFolder('path/to/image/dataset', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
- 数据预处理:首先定义了一个
transform操作,它由一系列的图像变换组成。包括将图像调整为指定的大小、进行中心裁剪、转换为张量格式以及对图像的像素值进行归一化处理,使其均值为 0.5,标准差为 0.5。 - 数据加载:使用
ImageFolder类从指定路径加载图像数据集,并应用前面定义的transform操作。然后通过DataLoader将数据集包装成可迭代的数据加载器,以便在训练过程中按批次获取数据,并且设置了随机打乱数据的功能。
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
# 输入层
nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 层叠反卷积层
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh() # 输出层,使用tanh激活函数以确保生成图像像素值在[-1, 1]范围内
)
def forward(self, input):
return self.main(input)
- 模型结构:生成器是一个继承自
nn.Module的神经网络类。在其构造函数__init__中,定义了一个由多个层组成的main序列模型。包括一个输入层的转置卷积层(将噪声向量转换为具有一定特征维度的张量),随后跟着几个层叠的转置卷积层用于逐步上采样和生成图像特征,每个转置卷积层后都跟着批量归一化层和 ReLU 激活函数,最后输出层使用tanh激活函数将生成图像的像素值映射到[-1, 1]范围内。 - 前向传播:
forward方法定义了数据在模型中的正向传播路径,即输入噪声向量通过main序列模型得到生成的图像。
# 定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# 输入层
nn.Conv2d(channels, , kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 层叠卷积层
nn.Conv2d(, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# 输出层
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Sigmoid() # 输出判别概率
)
def forward(self, input):
return self.main(input).view(-1)
- 模型结构:判别器同样是继承自
nn.Module的类。在构造函数中定义的main序列模型由多个卷积层组成。从输入层开始,依次通过多个层叠的卷积层对输入图像进行下采样和特征提取,每个卷积层后跟着批量归一化层和 LeakyReLU 激活函数(斜率为 0.2),最后输出层使用sigmoid函数将判别结果映射到[0, 1]区间,表示输入图像是真实图像的概率。 - 前向传播:
forward方法定义了数据在判别器中的正向传播路径,输入图像通过main序列模型后经过view(-1)操作将输出张量展平为一维向量,得到判别器对输入图像的判别概率。
# 实例化模型
generator = Generator()
discriminator = Discriminator()
# 使用Adam优化器
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# 定义损失函数
criterion = nn.BCELoss()
- 模型实例化:分别实例化了生成器和判别器模型。
- 优化器定义:为生成器和判别器分别定义了 Adam 优化器,指定了学习率以及
betas参数(用于控制优化器的动量和自适应学习率的衰减)。 - 损失函数定义:使用
BCELoss作为损失函数,用于衡量判别器和生成器的输出与真实标签之间的差异。这种损失函数适用于二分类问题,在这里用于判断图像是真实的还是生成的。
# 训练循环
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(dataloader):
real_images = real_images.to(device) # 将数据转移到GPU(如果可用)
# 训练判别器
discriminator.zero_grad()
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
# 计算真实图像损失
real_output = discriminator(real_images)
d_real_loss = criterion(real_output, real_labels)
# 生成假图像并计算损失
noise = torch.randn(batch_size, latent_dim, device=device)
fake_images = generator(noise)
fake_output = discriminator(fake_images.detach()) # detach()避免反向传播到生成器
d_fake_loss = criterion(fake_output, fake_labels)
# 判别器总损失
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
optimizer_D.step()
# 训练生成器
generator.zero_grad()
fake_labels.fill_(1) # 期望生成器产生的图像被判别器判断为真
# 重新计算生成图像的判别损失
fake_output = discriminator(fake_images)
g_loss = criterion(fake_output, fake_labels)
g_loss.backward()
optimizer_G.step()
if (i + 1) % 100 == 0:
print(f"Epoch [{epoch}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
- 整体训练流程:通过外层的
for循环遍历训练的轮数,内层的for循环遍历数据加载器中的每个批次数据。 - 判别器训练:
- 首先将判别器的梯度清零,然后定义真实图像的标签为全 1 向量,生成图像的标签为全 0 向量。
- 计算判别器对真实图像的输出与真实标签之间的损失(
d_real_loss),以及对生成图像(通过生成器生成并从计算图中分离,避免反向传播到生成器)的输出与假标签之间的损失(d_fake_loss)。 - 将这两个损失相加得到判别器的总损失
d_loss,然后进行反向传播更新判别器的参数。
- 生成器训练:
- 先将生成器的梯度清零,然后将生成图像的期望标签设置为全 1 向量,表示希望生成器生成的图像能被判别器判断为真实图像。
- 重新计算生成图像通过判别器后的输出与期望标签之间的损失(
g_loss),进行反向传播更新生成器的参数。
- 训练信息打印:每隔 100 个批次打印当前轮数、批次数以及判别器和生成器的损失值。
# 生成新图像
fixed_noise = torch.randn(, latent_dim, device=device)
fake_images = generator(fixed_noise).detach().cpu()
# 可以保存或显示fake_images,查看生成的图像效果
在训练完成后,通过生成一个固定的噪声向量fixed_noise,将其输入到训练好的生成器中得到生成的图像fake_images,然后将其从 GPU(如果在 GPU 上训练)转移到 CPU,并可以进一步保存或显示这些生成的图像来查看生成效果。
总体而言,这段代码实现了一个基本的生成对抗网络(GAN),包括数据预处理、模型定义、训练过程以及生成新图像的功能,用于生成与训练数据相似的新图像。