【VAE】使用变分自编码器生成MNIST数字图像
一、AE(AutoEncoder)-自编码器
自编码器是一种无监督学习模型,其核心思想是:将输入样本 x 经过编码器(Encoder)压缩到一个潜在空间(Latent Space),再通过解码器(Decoder)将其还原为输出 \hat{x}。
然而,传统自编码器存在一个问题:其潜在空间通常是离散且不连续的。这意味着,如果我们在潜在空间中随机采样一个点并输入给解码器,往往会得到无意义的结果——因为这些点可能并不对应任何真实样本的概率分布区域,也就是说,它们“落在空白区”,无法生成合理的样本。
二、VAE-变分自编码器
为了解决潜在空间不连续的问题,VAE在自编码器的基础上引入了概率建模思想。它不再直接将输入压缩成一个固定向量,而是将每个样本编码为一个高斯分布,用均值向量 μ和方差向量 σ2来表示。每次输入一个样本,都会通过Encoder将其变为 μ 和 σ
在训练过程中,VAE通过优化目标函数,使这些高斯分布在潜在空间中尽可能连续且符合标准正态分布 N(0,1)。这样,在生成阶段,我们就可以直接从 N(0,1)中采样潜在变量 zz,再通过解码器生成新的样本。
可以说,VAE通过“给潜在空间加上分布约束”,让模型具备了真正的生成能力。
三、实例:MNIST手写数字的图像生成
1、实际表现
训练Epoch = 50,进行输出


效果只能说可以看,但是还需要优化。通过VAE生成的图像有个很明显的特点,就是中间清晰,四周模糊。以下是造成周围模糊的三个主要原因:
重构目标的均方误差(MSE)导致的模糊性
VAE通常采用均方误差作为重构损失,这相当于假设像素服从独立的高斯分布。模型在优化时会趋向生成“平均意义上的正确像素值”,从而在多个可能结果之间取均值,导致输出图像整体变得平滑、细节模糊。潜在空间采样的不确定性
由于VAE在潜在空间中对每个输入都引入了随机采样(即 z∼N(μ,σ2)),生成时的这种随机性会带来细节的模糊扩散,尤其在解码器的非线性映射不够强时更为明显。潜在空间的全局性编码倾向
编码器通常优先捕捉输入图像的整体结构信息,而对局部高频细节(如边缘、纹理)保留不足。因此生成的图像中心部分(结构清晰的区域)通常重构较好,而边缘部分往往较为模糊。
2、代码
class VAE(nn.Module): def __init__(self, latent_dim): super(VAE, self).__init__() self.encoder = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), # 28x28 -> 14x14 nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7 nn.ReLU(), nn.Flatten(), nn.Linear(64 * 7 * 7, 128), nn.ReLU() ) self.fc_mu = nn.Linear(128, latent_dim) self.fc_logvar = nn.Linear(128, latent_dim) self.decoder = nn.Sequential( nn.Linear(latent_dim, 64 * 7 * 7), nn.ReLU(), nn.Unflatten(1, (64, 7, 7)), nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # 7x7 -> 14x14 nn.ReLU(), nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 14x14 -> 28x28 nn.Sigmoid() ) def encode(self, x): h = self.encoder(x) mu = self.fc_mu(h) logvar = self.fc_logvar(h) return mu, logvar def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def decode(self, z): return self.decoder(z) def forward(self, x): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) recon_x = self.decode(z) return recon_x, mu, logvar def vae_loss(recon_x, x, mu, logvar): recon_loss = nn.functional.mse_loss(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum') kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return recon_loss + kl_loss
四、扩展应用
在使用的时候,一般都是以VAE作为编码器使用,通过输入样本,将样本压缩至潜在空间,然后进行后续处理;处理后再通过冻结的预训练VAE解码器进行解码输出。
登录后方可回帖