"VAE + DDPM for MNIST Digit Generation – Conditional Sampling Example
In this implementation, we compress images into a latent space using a VAE, then apply a DDPM (Denoising Diffusion Probabilistic Model) to the latent representations. The entire approach was designed with the help of ChatGPT.
The VAE uses ResidualConvBlock + GroupNorm + SiLU + Dropout to improve stability and expressive power, and combines MSE loss with an edge loss in the reconstruction objective to preserve sharp contours.
By increasing the number of VAE training steps, we were able to significantly reduce the jagged artifacts in generated images.
You can specify a category (digit 0–9) to sample, allowing not only random generation but also conditional generation where the model outputs the desired digit."
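As a quick illustration of the conditional part before the full script, here is a minimal sketch of the classifier-free guidance combination applied at each reverse step. The helper name cfg_combine and the dummy tensors are only for this example; the formula itself matches what DDPM.p_sample computes below.
import torch
def cfg_combine(eps_cond, eps_uncond, w):
    # w = 0 falls back to the purely conditional prediction; larger w pushes
    # samples harder toward the requested class at the cost of diversity.
    return (1 + w) * eps_cond - w * eps_uncond
# Toy usage with random "noise predictions" for a 16-dimensional latent.
eps_c, eps_u = torch.randn(4, 16), torch.randn(4, 16)
eps = cfg_combine(eps_c, eps_u, w=3.0)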
# =========================================================
# MNIST: VAE(latent) + DDPM(latent) with Class-Conditional CFG
# =========================================================
import os, math, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
# ================= Improved CNN VAE for MNIST (28x28x1) =================
class ResidualConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, groups=8, dropout=0.05):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        )
        self.skip = (nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity())
    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.skip(x)
class ConvEncoder(nn.Module):
    # 28x28 -> 14x14 -> 7x7
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.stem = nn.Conv2d(1, ch, 3, padding=1)
        self.res1 = ResidualConvBlock(ch, ch)                        # 28x28
        self.down1 = nn.Conv2d(ch, ch*2, 4, stride=2, padding=1)     # ->14x14
        self.res2 = ResidualConvBlock(ch*2, ch*2)                    # 14x14
        self.down2 = nn.Conv2d(ch*2, ch*4, 4, stride=2, padding=1)   # ->7x7
        self.res3 = ResidualConvBlock(ch*4, ch*4)                    # 7x7
        self.proj = nn.Linear((ch*4)*7*7, 256)
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)
    def forward(self, x):
        h = self.stem(x)
        h = self.res1(h)
        h = self.down1(h)
        h = self.res2(h)
        h = self.down2(h)
        h = self.res3(h)
        h = h.view(h.size(0), -1)
        h = F.silu(self.proj(h))
        mu, logvar = self.mu(h), self.logvar(h)
        return mu, logvar
class ConvDecoder(nn.Module):
    # 7x7 -> 14x14 -> 28x28
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.fc = nn.Linear(latent_dim, (ch*4)*7*7)
        self.res3 = ResidualConvBlock(ch*4, ch*4)                          # 7x7
        self.up1 = nn.ConvTranspose2d(ch*4, ch*2, 4, stride=2, padding=1)  # ->14x14
        self.res2 = ResidualConvBlock(ch*2, ch*2)                          # 14x14
        self.up2 = nn.ConvTranspose2d(ch*2, ch, 4, stride=2, padding=1)    # ->28x28
        self.res1 = ResidualConvBlock(ch, ch)                              # 28x28
        self.head = nn.Conv2d(ch, 1, 3, padding=1)
    def forward(self, z):
        h = F.silu(self.fc(z)).view(z.size(0), -1, 7, 7)
        h = self.res3(h)
        h = self.up1(h)
        h = self.res2(h)
        h = self.up2(h)
        h = self.res1(h)
        x_logits = self.head(h)
        return x_logits
class VAE(nn.Module):
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.enc = ConvEncoder(latent_dim, ch)
        self.dec = ConvDecoder(latent_dim, ch)
    def reparam(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        mu, logvar = self.enc(x)
        z = self.reparam(mu, logvar)
        x_logits = self.dec(z)
        return x_logits, mu, logvar, z
# --------- Edge-aware Reconstruction Loss + KL Warmup ----------
def image_gradients(x):
    kx = torch.tensor([[[[-1, 0, 1],
                         [-2, 0, 2],
                         [-1, 0, 1]]]], device=x.device, dtype=x.dtype)
    ky = torch.tensor([[[[-1, -2, -1],
                         [ 0,  0,  0],
                         [ 1,  2,  1]]]], device=x.device, dtype=x.dtype)
    gx = F.conv2d(x, kx, padding=1)
    gy = F.conv2d(x, ky, padding=1)
    return torch.sqrt(gx*gx + gy*gy + 1e-6)
def vae_loss(x, x_logits, mu, logvar, recon_type="bce", edge_lambda=0.05, beta=1.0):
    if recon_type == "mse":
        x_hat = torch.sigmoid(x_logits)
        recon = F.mse_loss(x_hat, x, reduction="sum")
    else:
        recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")
    if edge_lambda > 0:
        x_hat_prob = torch.sigmoid(x_logits)
        gx = image_gradients(x)
        gx_hat = image_gradients(x_hat_prob)
        edge = F.l1_loss(gx_hat, gx, reduction="sum")
        recon = recon + edge_lambda * edge
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl, recon, kl
def kl_warmup(epoch, total_epochs, min_beta=0.0, max_beta=0.5):
    t = epoch / max(1, total_epochs)
    beta = min_beta + 0.5*(max_beta - min_beta)*(1 - math.cos(math.pi * min(t, 1.0)))
    return beta
# --------- Training (VAE) ----------
def train_vae(model, loader, device, epochs=10, lr=1e-3, out_dir="runs_vae",
              recon="bce", edge_lambda=0.05, beta_max=0.5, beta_min=0.0,
              save_recon_every_epoch=True, weight_decay=1e-4):
    os.makedirs(out_dir, exist_ok=True)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
    model.to(device)
    for ep in range(1, epochs+1):
        model.train()
        beta = kl_warmup(ep, epochs, min_beta=beta_min, max_beta=beta_max)
        pbar = tqdm(loader, desc=f"[VAE] epoch {ep}/{epochs} (beta={beta:.3f})")
        total, total_rec, total_kl = 0.0, 0.0, 0.0
        for x, _ in pbar:
            x = x.to(device)
            with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
                x_logits, mu, logvar, _ = model(x)
                loss, rec, kl = vae_loss(x, x_logits, mu, logvar,
                                         recon_type=recon, edge_lambda=edge_lambda, beta=beta)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping so the norm is computed on true gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            total += loss.item(); total_rec += rec.item(); total_kl += kl.item()
            pbar.set_postfix(loss=f"{total/len(loader):.2f}",
                             rec=f"{total_rec/len(loader):.2f}",
                             kl=f"{total_kl/len(loader):.2f}")
        if save_recon_every_epoch:
            model.eval()
            with torch.no_grad():
                x_vis = next(iter(loader))[0][:64].to(device)
                x_logits, _, _, _ = model(x_vis)
                x_hat = torch.sigmoid(x_logits)
                grid = vutils.make_grid(torch.cat([x_vis, x_hat], dim=0), nrow=16)
                vutils.save_image(grid, os.path.join(out_dir, f"recon_ep{ep}.png"))
    return model
# --------------------
# DDPM
# --------------------
@dataclass
class DDPMConfig:
    timesteps: int = 1000
    beta_start: float = 1e-4  # kept for reference; unused with the cosine schedule below
    beta_end: float = 0.02
class TimeEmbedding(nn.Module):
    def __init__(self, embedding_dim=64, max_period=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period
    def forward(self, t):
        half = self.embedding_dim // 2
        device = t.device
        freqs = torch.exp(-math.log(self.max_period) * torch.arange(0, half, device=device) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.embedding_dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
class LatentNoisePredictor(nn.Module):
    def __init__(self, latent_dim, num_classes=10, hidden=768, emb_dim=64):
        super().__init__()
        self.te = TimeEmbedding(emb_dim)
        self.ce = nn.Embedding(num_classes + 1, emb_dim)  # +1: null token
        nn.init.normal_(self.ce.weight, std=0.02)
        self.null_class_idx = num_classes
        self.fc1 = nn.Linear(latent_dim + emb_dim + emb_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, latent_dim)
    def forward(self, z_t, t, y=None):
        if y is None:
            y = torch.full((z_t.size(0),), self.null_class_idx, device=z_t.device, dtype=torch.long)
        te = self.te(t)
        ce = self.ce(y)
        x = torch.cat([z_t, te, ce], dim=1)
        h = F.silu(self.fc1(x))
        h = F.silu(self.fc2(h))
        return self.out(h)
class DDPM:
    def __init__(self, cfg, device):
        self.cfg = cfg
        self.device = device
        # ---- Cosine ᾱ(t) ----
        def _cosine_alphas_cumprod(T, s=0.008):
            steps = torch.arange(T+1, dtype=torch.float32)
            f = torch.cos(((steps/T + s) / (1+s)) * math.pi * 0.5) ** 2
            a_bar = f / f[0]
            return a_bar[1:]
        alphas_cumprod = _cosine_alphas_cumprod(cfg.timesteps).to(device)  # ᾱ_t
        alphas = torch.cat([alphas_cumprod[0:1],
                            (alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(1e-8, 1.0)], dim=0)
        betas = 1.0 - alphas
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        return (
            self.sqrt_alphas_cumprod[t].unsqueeze(1) * x0 +
            self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1) * noise,
            noise
        )
    def p_sample(self, model, x_t, t, y=None, guidance_w=0.0):
        beta_t = self.betas[t].unsqueeze(1)
        alpha_t = self.alphas[t].unsqueeze(1)
        alpha_bar_t = self.alphas_cumprod[t].unsqueeze(1)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)
        if guidance_w == 0.0 or y is None:
            eps_theta = model(x_t, t, y)
        else:
            eps_cond = model(x_t, t, y)
            null_y = torch.full_like(y, model.null_class_idx)
            eps_uncond = model(x_t, t, null_y)
            eps_theta = (1 + guidance_w) * eps_cond - guidance_w * eps_uncond
        # DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
        mean = (x_t - (beta_t / sqrt_one_minus_ab) * eps_theta) / torch.sqrt(alpha_t)
        if (t > 0).all():
            return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)
        else:
            return mean
    @torch.no_grad()
    def sample(self, model, shape, decode_fn, n_steps=None, y=None, guidance_w=0.0):
        model.eval()
        B, latent_dim = shape
        x_t = torch.randn(B, latent_dim, device=self.device)
        timesteps = self.cfg.timesteps if n_steps is None else n_steps
        # evenly spaced timesteps from T-1 down to 0, rounded to integer indices
        t_space = torch.linspace(self.cfg.timesteps - 1, 0, timesteps, device=self.device).round().long()
        for t_scalar in t_space:
            t = torch.full((B,), int(t_scalar.item()), device=self.device, dtype=torch.long)
            x_t = self.p_sample(model, x_t, t, y=y, guidance_w=guidance_w)
        return decode_fn(x_t).clamp(0, 1)
# --------- EMA ----------
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
    def copy_to(self, model):
        model.load_state_dict(self.shadow, strict=True)
# --------- Training (DDPM) ----------
def train_ddpm(latent_net, ddpm, vae, loader, device,
               epochs=5, lr=2e-4, num_classes=10, cond_drop_prob=0.2,
               weight_decay=0.0, ema_decay=0.999):
    opt = torch.optim.AdamW(latent_net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
    vae.eval()
    ema = EMA(latent_net, decay=ema_decay)
    for ep in range(1, epochs+1):
        latent_net.train()
        pbar = tqdm(loader, desc=f"[DDPM] {ep}/{epochs}")
        for x, y in pbar:
            x = x.to(device); y = y.to(device)
            with torch.no_grad():
                mu, logvar = vae.enc(x)
                z0 = vae.reparam(mu, logvar)
            if cond_drop_prob > 0:
                # randomly replace labels with the null class so the model also learns unconditional prediction (for CFG)
                drop = (torch.rand_like(y.float()) < cond_drop_prob)
                y = y.clone()
                y[drop] = latent_net.null_class_idx
            t = torch.randint(0, ddpm.cfg.timesteps, (z0.size(0),), device=device)
            z_t, noise = ddpm.q_sample(z0, t)
            with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
                pred = latent_net(z_t, t, y)
                loss = F.mse_loss(pred, noise)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping
            torch.nn.utils.clip_grad_norm_(latent_net.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            ema.update(latent_net)
        sched.step()
    return latent_net, ema
# --------------------
# execute
# --------------------
if __name__ == "__main__":
    seed_everything(42)
    device = get_device()
    print("Device:", device)
    # MNIST Data
    tfm = transforms.ToTensor()
    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    # VAE
    latent_dim = 16
    vae = VAE(latent_dim=latent_dim, ch=48)
    vae = train_vae(vae, train_loader, device,
                    epochs=5, lr=1e-3, beta_max=0.5, edge_lambda=0.05,
                    out_dir="runs_vae", save_recon_every_epoch=True)
    # DDPM
    cfg = DDPMConfig()
    ddpm = DDPM(cfg, device)
    num_classes = 10
    latent_net = LatentNoisePredictor(latent_dim=latent_dim, num_classes=num_classes, hidden=768).to(device)
    latent_net, ema = train_ddpm(
        latent_net, ddpm, vae, train_loader, device,
        epochs=10, lr=2e-4, num_classes=num_classes, cond_drop_prob=0.2,
        weight_decay=0.0, ema_decay=0.999
    )
    ema.copy_to(latent_net)
    # Conditional sampling with classifier-free guidance
    B = 64
    target_digit = 5
    y = torch.full((B,), target_digit, device=device, dtype=torch.long)
    decode_fn = lambda z: torch.sigmoid(vae.dec(z))  # logits -> probabilities
    imgs = ddpm.sample(
        latent_net, shape=(B, latent_dim), decode_fn=decode_fn,
        n_steps=250, y=y, guidance_w=3.0
    )
    grid = vutils.make_grid(imgs[:16], nrow=4)
    import matplotlib.pyplot as plt
    plt.figure(figsize=(6, 6)); plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()
Sample quality also varies with the classifier-free guidance (CFG) weight, but overall I’m pretty satisfied with the results.
I recommend training the VAE for 15 epochs. For the DDPM, going up to 50 epochs was suggested, but even 10 epochs works quite well.
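The two sampling-time knobs that matter most are the target digit and the guidance weight. As a rough sketch of how to compare them (this assumes the script above has already run, so ddpm, latent_net, decode_fn, latent_dim, and device are in scope), a small sweep could look like this:
# Hedged sketch: sample the same digit at several guidance weights and save grids.
for w in (0.0, 1.5, 3.0):
    y = torch.full((16,), 7, device=device, dtype=torch.long)  # e.g. digit 7
    imgs = ddpm.sample(latent_net, shape=(16, latent_dim), decode_fn=decode_fn,
                       n_steps=250, y=y, guidance_w=w)
    vutils.save_image(imgs, f"samples_digit7_w{w}.png", nrow=4)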