common.title

Docs
Quantum Circuit
TYTAN CLOUD

QUANTUM GAMING


Overview
Contact
Event
Project
Research

Terms of service (Web service)

Terms of service (Quantum and ML Cloud service)

Privacy policy


Sign in
Sign up
common.title

VAE + DDPM for MNIST Digit Generation – Conditional Sampling Example

Yuichiro Minato

2025/08/11 13:42

VAE + DDPM for MNIST Digit Generation – Conditional Sampling Example

In this implementation, we compress images into a latent space using a VAE, then apply a DDPM (Denoising Diffusion Probabilistic Model) to the latent representations. The entire approach was designed with the help of ChatGPT.

The VAE uses ResidualConvBlock + GroupNorm + SiLU + Dropout to improve stability and expressive power. Its reconstruction objective adds a Sobel edge loss to the pixel loss (binary cross-entropy by default, MSE optionally) to preserve sharp contours.

By increasing the number of VAE training epochs, we were able to significantly reduce the jagged artifacts in generated images.

You can specify a category (digit 0–9) to sample. This allows not only random generation but also conditional generation, where classifier-free guidance steers the model to output the desired digit.

# =========================================================
# MNIST: VAE(latent) + DDPM(latent) with Class-Conditional CFG
# =========================================================

import os, math, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with one value.

    Also requests deterministic cuDNN algorithms and turns off cuDNN
    benchmark autotuning so repeated runs pick the same kernels.

    Args:
        seed: integer seed shared by every random number generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_device():
    """Pick the compute device: CUDA if available, else Apple MPS, else CPU.

    Returns:
        A ``torch.device``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = mps_backend is not None and mps_backend.is_available()
    return torch.device("mps" if use_mps else "cpu")
# ================= Improved CNN VAE for MNIST (28x28x1) =================

class ResidualConvBlock(nn.Module):
    """Residual block of two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    The second stage applies ``Dropout2d`` before its convolution. Spatial size is
    preserved (3x3 kernels, padding 1). The skip path is a 1x1 convolution when
    the channel count changes and the identity otherwise.

    Args:
        in_ch: input channels (GroupNorm requires divisibility by ``groups``).
        out_ch: output channels (GroupNorm requires divisibility by ``groups``).
        groups: number of GroupNorm groups.
        dropout: channel-dropout probability inside the second stage.
    """

    def __init__(self, in_ch, out_ch, groups=8, dropout=0.05):
        super().__init__()
        # Layers are created in a fixed order (stage-1 conv, stage-2 conv, skip conv)
        # so seeded weight initialisation stays reproducible.
        first_stage = [
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        ]
        self.block1 = nn.Sequential(*first_stage)
        second_stage = [
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        ]
        self.block2 = nn.Sequential(*second_stage)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        residual = self.block2(self.block1(x))
        return residual + self.skip(x)

class ConvEncoder(nn.Module):
    """Convolutional VAE encoder: 1x28x28 image -> (mu, logvar) of the latent posterior.

    Resolution path 28x28 -> 14x14 -> 7x7 via two stride-2 4x4 convolutions, with
    channels ch -> 2*ch -> 4*ch and a residual block at each resolution. The 7x7
    feature map is flattened, projected to 256 units with SiLU, and fed to two
    linear heads. The input must be single-channel 28x28, because the projection
    expects (4*ch)*7*7 features.

    Args:
        latent_dim: size of the latent vector.
        ch: base channel width.
    """

    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        c1, c2, c4 = ch, ch * 2, ch * 4
        self.stem = nn.Conv2d(1, c1, 3, padding=1)
        self.res1 = ResidualConvBlock(c1, c1)                      # 28x28
        self.down1 = nn.Conv2d(c1, c2, 4, stride=2, padding=1)     # 28 -> 14
        self.res2 = ResidualConvBlock(c2, c2)                      # 14x14
        self.down2 = nn.Conv2d(c2, c4, 4, stride=2, padding=1)     # 14 -> 7
        self.res3 = ResidualConvBlock(c4, c4)                      # 7x7
        self.proj = nn.Linear(c4 * 7 * 7, 256)
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        feats = x
        for stage in (self.stem, self.res1, self.down1, self.res2, self.down2, self.res3):
            feats = stage(feats)
        hidden = F.silu(self.proj(torch.flatten(feats, start_dim=1)))
        return self.mu(hidden), self.logvar(hidden)

class ConvDecoder(nn.Module):
    """Convolutional VAE decoder: latent vector -> 1x28x28 logits.

    A linear layer with SiLU produces a (4*ch)x7x7 map. Two stride-2 transposed
    convolutions upsample it 7 -> 14 -> 28 (channels 4*ch -> 2*ch -> ch), with a
    residual block at each resolution. A final 3x3 conv emits one channel of
    logits; apply ``torch.sigmoid`` to get pixel probabilities.

    Args:
        latent_dim: size of the latent vector.
        ch: base channel width.
    """

    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        c1, c2, c4 = ch, ch * 2, ch * 4
        self.fc = nn.Linear(latent_dim, c4 * 7 * 7)
        self.res3 = ResidualConvBlock(c4, c4)                                 # 7x7
        self.up1 = nn.ConvTranspose2d(c4, c2, 4, stride=2, padding=1)         # 7 -> 14
        self.res2 = ResidualConvBlock(c2, c2)                                 # 14x14
        self.up2 = nn.ConvTranspose2d(c2, c1, 4, stride=2, padding=1)         # 14 -> 28
        self.res1 = ResidualConvBlock(c1, c1)                                 # 28x28
        self.head = nn.Conv2d(c1, 1, 3, padding=1)

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to logits of shape (batch, 1, 28, 28)."""
        batch = z.size(0)
        feats = F.silu(self.fc(z)).view(batch, -1, 7, 7)
        for stage in (self.res3, self.up1, self.res2, self.up2, self.res1):
            feats = stage(feats)
        return self.head(feats)

class VAE(nn.Module):
    """Convolutional VAE made of a ``ConvEncoder`` and a ``ConvDecoder``.

    ``forward(x)`` returns ``(x_logits, mu, logvar, z)``:
      * ``x_logits``: decoder logits; apply a sigmoid for pixel probabilities.
      * ``mu`` and ``logvar``: posterior mean and log-variance.
      * ``z``: the reparameterised latent sample.
    """

    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.enc = ConvEncoder(latent_dim, ch)
        self.dec = ConvDecoder(latent_dim, ch)

    def reparam(self, mu, logvar):
        """Sample z = mu + sigma * eps, with sigma = exp(logvar / 2) and eps ~ N(0, I).

        The sample is differentiable with respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        mu, logvar = self.enc(x)
        z = self.reparam(mu, logvar)
        return self.dec(z), mu, logvar, z
# --------- Edge-aware Reconstruction Loss + KL Warmup ----------
def image_gradients(x):
    """Return the per-pixel Sobel gradient magnitude of a single-channel image batch.

    Args:
        x: tensor of shape (N, 1, H, W). The 1-in/1-out kernels support one channel only.

    Returns:
        Tensor of shape (N, 1, H, W) equal to sqrt(gx**2 + gy**2 + 1e-6), with
        zero padding at the borders. The epsilon keeps the square root
        differentiable where the gradient vanishes.
    """
    sobel = torch.tensor([[-1, 0, 1],
                          [-2, 0, 2],
                          [-1, 0, 1]], device=x.device, dtype=x.dtype)
    kernel_x = sobel.view(1, 1, 3, 3)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    kernel_y = sobel.t().contiguous().view(1, 1, 3, 3)
    grad_x = F.conv2d(x, kernel_x, padding=1)
    grad_y = F.conv2d(x, kernel_y, padding=1)
    return torch.sqrt(grad_x * grad_x + grad_y * grad_y + 1e-6)

def vae_loss(x, x_logits, mu, logvar, recon_type="bce", edge_lambda=0.05, beta=1.0):
    """VAE objective: reconstruction (+ optional Sobel edge term) + beta * KL.

    Args:
        x: target images, expected in [0, 1].
        x_logits: decoder outputs before the sigmoid, same shape as ``x``.
        mu, logvar: posterior mean and log-variance from the encoder.
        recon_type: ``"mse"`` for squared error on ``sigmoid(x_logits)``; any other
            value selects binary cross-entropy on the logits.
        edge_lambda: weight of the L1 distance between Sobel gradient magnitudes
            of the reconstruction and of the target; skipped when ``<= 0``.
        beta: weight of the KL term.

    Returns:
        ``(total, recon, kl)``, all summed over the batch. ``recon`` already
        includes the weighted edge term.
    """
    probs = torch.sigmoid(x_logits)
    if recon_type == "mse":
        recon = F.mse_loss(probs, x, reduction="sum")
    else:
        recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")

    if edge_lambda > 0:
        target_edges = image_gradients(x)
        pred_edges = image_gradients(probs)
        recon = recon + edge_lambda * F.l1_loss(pred_edges, target_edges, reduction="sum")

    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)).
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl, recon, kl

def kl_warmup(epoch, total_epochs, min_beta=0.0, max_beta=0.5):
    """Cosine ramp of the KL weight from ``min_beta`` to ``max_beta``.

    Progress is ``epoch / max(1, total_epochs)``, capped at 1. The weight follows
    half a cosine period, so it starts flat, rises fastest mid-way and reaches
    ``max_beta`` once ``epoch >= total_epochs``.
    """
    progress = min(epoch / max(1, total_epochs), 1.0)
    ramp = 0.5 * (1 - math.cos(math.pi * progress))
    return min_beta + (max_beta - min_beta) * ramp
# --------- Training (VAE) ----------
def train_vae(model, loader, device, epochs=10, lr=1e-3, out_dir="runs_vae",
              recon="bce", edge_lambda=0.05, beta_max=0.5, beta_min=0.0,
              save_recon_every_epoch=True, weight_decay=1e-4):
    """Train a VAE with AdamW, gradient clipping, CUDA AMP and a cosine KL warm-up.

    Args:
        model: VAE whose forward returns ``(x_logits, mu, logvar, z)``.
        loader: iterable of ``(images, labels)`` batches with a ``len``; labels are ignored.
        device: torch device; mixed precision is enabled only when ``device.type == "cuda"``.
        epochs: number of passes over ``loader``.
        lr, weight_decay: AdamW hyper-parameters.
        out_dir: directory for per-epoch reconstruction grids (created if missing).
        recon: reconstruction loss forwarded to ``vae_loss`` (``"bce"`` or ``"mse"``).
        edge_lambda: weight of the edge term forwarded to ``vae_loss``.
        beta_max, beta_min: range of the KL weight produced by ``kl_warmup``.
        save_recon_every_epoch: if True, save ``recon_ep{ep}.png`` showing up to
            64 inputs from one batch followed by their reconstructions.

    Returns:
        The same ``model`` object, left in training mode.
    """
    os.makedirs(out_dir, exist_ok=True)
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for ep in range(1, epochs + 1):
        model.train()
        beta = kl_warmup(ep, epochs, min_beta=beta_min, max_beta=beta_max)
        pbar = tqdm(loader, desc=f"[VAE*] epoch {ep}/{epochs} (beta={beta:.3f})")
        total, total_rec, total_kl = 0.0, 0.0, 0.0

        for step, (x, _) in enumerate(pbar, start=1):
            x = x.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                x_logits, mu, logvar, _ = model(x)
                loss, rec, kl = vae_loss(x, x_logits, mu, logvar,
                                         recon_type=recon, edge_lambda=edge_lambda, beta=beta)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients,
            # not the loss-scaled ones (unscale_ is a no-op when AMP is disabled).
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()

            total += loss.item()
            total_rec += rec.item()
            total_kl += kl.item()
            # Running means over the batches processed so far in this epoch.
            pbar.set_postfix(loss=f"{total / step:.2f}",
                             rec=f"{total_rec / step:.2f}",
                             kl=f"{total_kl / step:.2f}")

        if save_recon_every_epoch:
            model.eval()  # disable Dropout2d while rendering the reconstruction grid
            with torch.no_grad():
                x_vis = next(iter(loader))[0][:64].to(device)
                x_logits, _, _, _ = model(x_vis)
                x_hat = torch.sigmoid(x_logits)
                grid = vutils.make_grid(torch.cat([x_vis, x_hat], dim=0), nrow=16)
                vutils.save_image(grid, os.path.join(out_dir, f"recon_ep{ep}.png"))
            model.train()

    return model
# --------------------
# DDPM
# --------------------
@dataclass
class DDPMConfig:
    """Hyper-parameters of the latent diffusion process.

    Attributes:
        timesteps: number of diffusion steps T.
        beta_start, beta_end: endpoints of a linear beta schedule. The ``DDPM``
            class in this file builds a cosine schedule and reads only ``timesteps``.
    """
    timesteps: int = 1_000
    beta_start: float = 1.0e-4
    beta_end: float = 2.0e-2

class TimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of integer timesteps.

    With ``half = embedding_dim // 2``, the frequencies are
    ``max_period ** (-k / half)`` for ``k = 0 .. half-1``. The output is
    ``[cos(t * f), sin(t * f)]``, zero-padded by one column when
    ``embedding_dim`` is odd.
    """

    def __init__(self, embedding_dim=64, max_period=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period

    def forward(self, t):
        """Map timesteps ``t`` of shape (batch,) to embeddings of shape (batch, embedding_dim)."""
        half = self.embedding_dim // 2
        log_scale = -math.log(self.max_period)
        freqs = torch.exp(log_scale * torch.arange(0, half, device=t.device) / half)
        phases = t.float()[:, None] * freqs[None, :]
        emb = torch.cat((torch.cos(phases), torch.sin(phases)), dim=-1)
        return F.pad(emb, (0, 1)) if self.embedding_dim % 2 == 1 else emb

class LatentNoisePredictor(nn.Module):
    """Class-conditional MLP that predicts the noise added to a latent vector.

    The input is the concatenation ``[z_t, time embedding, class embedding]``
    followed by two SiLU hidden layers. Class index ``num_classes`` is an extra
    "null" label used for unconditional predictions (``y=None``).

    Args:
        latent_dim: size of the latent vector (input and output).
        num_classes: number of real class labels.
        hidden: width of the hidden layers.
        emb_dim: size of both the time and the class embeddings.
    """

    def __init__(self, latent_dim, num_classes=10, hidden=768, emb_dim=64):
        super().__init__()
        self.te = TimeEmbedding(emb_dim)
        self.null_class_idx = num_classes
        # One extra embedding row acts as the null (unconditional) label.
        self.ce = nn.Embedding(self.null_class_idx + 1, emb_dim)
        nn.init.normal_(self.ce.weight, std=0.02)

        in_features = latent_dim + 2 * emb_dim
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, latent_dim)

    def forward(self, z_t, t, y=None):
        """Predict noise for latents ``z_t`` (batch, latent_dim) at timesteps ``t``.

        ``y=None`` substitutes the null label for every sample.
        """
        if y is None:
            y = torch.full((z_t.size(0),), self.null_class_idx,
                           device=z_t.device, dtype=torch.long)
        features = torch.cat((z_t, self.te(t), self.ce(y)), dim=1)
        hidden = F.silu(self.fc2(F.silu(self.fc1(features))))
        return self.out(hidden)

class DDPM:
    """Denoising diffusion (DDPM) over flat latent vectors with a cosine ᾱ(t) schedule.

    Schedule (Nichol & Dhariwal, 2021):
      * ᾱ(t) = f(t) / f(0), with f(t) = cos²(((t / T + s) / (1 + s)) · π / 2) and s = 0.008.
      * Each per-step β_t = 1 - ᾱ_t / ᾱ_{t-1} is clipped to ``max_beta``.
      * ᾱ is then recomputed as the cumulative product of (1 - β_t).

    This way the forward process (``q_sample``) and the reverse process
    (``p_sample``) share one consistent schedule.

    Latents are 2-D tensors of shape (batch, latent_dim). Timesteps are int64
    tensors of shape (batch,) with values in [0, T).

    Args:
        cfg: configuration object; only ``cfg.timesteps`` (T) is read.
        device: device on which all schedule tensors are stored.
        max_beta: upper bound for every β_t. Without it, the final cosine step has
            β ≈ 1 and the reverse-step factor 1 / sqrt(α_t) blows up.
    """

    def __init__(self, cfg, device, max_beta=0.999):
        self.cfg = cfg
        self.device = device
        T, s = cfg.timesteps, 0.008
        # Built in float64 on the CPU for precision where ᾱ is close to 0 or 1.
        # Results are cast to float32 before moving, since MPS lacks float64.
        steps = torch.arange(T + 1, dtype=torch.float64)
        f = torch.cos(((steps / T + s) / (1 + s)) * math.pi * 0.5) ** 2
        a_bar = f / f[0]
        betas = (1.0 - a_bar[1:] / a_bar[:-1]).clamp(0.0, max_beta)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # ᾱ_t

        def _to_device(values):
            return values.float().to(device)

        self.betas = _to_device(betas)
        self.alphas = _to_device(alphas)
        self.alphas_cumprod = _to_device(alphas_cumprod)
        self.sqrt_alphas_cumprod = _to_device(alphas_cumprod.sqrt())
        self.sqrt_one_minus_alphas_cumprod = _to_device((1.0 - alphas_cumprod).sqrt())

    def q_sample(self, x0, t, noise=None):
        """Draw x_t ~ q(x_t | x_0) = N(sqrt(ᾱ_t)·x_0, (1 - ᾱ_t)·I).

        Args:
            x0: clean latents, shape (B, D).
            t: timesteps, int64 of shape (B,).
            noise: optional standard-normal noise shaped like ``x0``; drawn if None.

        Returns:
            ``(x_t, noise)``; ``noise`` is the regression target for training.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        x_t = (self.sqrt_alphas_cumprod[t].unsqueeze(1) * x0
               + self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1) * noise)
        return x_t, noise

    def _alpha_bar_at(self, t):
        """Gather ᾱ at timesteps ``t``; index -1 (clean data x_0) maps to ᾱ = 1."""
        ab = self.alphas_cumprod[t.clamp(min=0)]
        return torch.where(t >= 0, ab, torch.ones_like(ab))

    @staticmethod
    def _predict_eps(model, x_t, t, y, guidance_w):
        """Noise prediction with classifier-free guidance: (1 + w)·ε(x, y) - w·ε(x, ∅)."""
        if guidance_w == 0.0 or y is None:
            return model(x_t, t, y)
        eps_cond = model(x_t, t, y)
        null_y = torch.full_like(y, model.null_class_idx)
        eps_uncond = model(x_t, t, null_y)
        return (1 + guidance_w) * eps_cond - guidance_w * eps_uncond

    def p_sample(self, model, x_t, t, y=None, guidance_w=0.0, t_prev=None):
        """Take one reverse step from timestep ``t`` to ``t_prev``.

        Mean:     μ = (x_t - β / sqrt(1 - ᾱ_t) · ε_θ) / sqrt(α)
        Variance: σ² = β
        where α = ᾱ_t / ᾱ_{t_prev} and β = 1 - α.

        With the default ``t_prev = t - 1``, α and β are the ordinary per-step
        values. Larger gaps give the strided (respaced) posterior used when
        sampling with fewer steps.

        Args:
            model: noise predictor called as ``model(x_t, t, y)``. It must expose
                ``null_class_idx`` when guidance is used.
            x_t: current latents, shape (B, D).
            t: current timesteps, int64 of shape (B,).
            y: optional class labels of shape (B,); None means unconditional.
            guidance_w: classifier-free guidance weight w.
            t_prev: target timesteps of shape (B,); -1 means "step to x_0".
                Defaults to ``t - 1``.

        Returns:
            Latents at ``t_prev``. No noise is added for samples with ``t == 0``.
        """
        if t_prev is None:
            t_prev = t - 1
        alpha_bar_t = self._alpha_bar_at(t).unsqueeze(1)
        alpha_bar_prev = self._alpha_bar_at(t_prev).unsqueeze(1)
        alpha_step = alpha_bar_t / alpha_bar_prev
        beta_step = 1.0 - alpha_step

        eps_theta = self._predict_eps(model, x_t, t, y, guidance_w)
        # DDPM posterior mean: divide by sqrt(α), not multiply.
        mean = (x_t - (beta_step / torch.sqrt(1.0 - alpha_bar_t)) * eps_theta) / torch.sqrt(alpha_step)

        noise_mask = (t > 0).to(x_t.dtype).unsqueeze(1)
        sigma = torch.sqrt(beta_step.clamp(min=0.0))
        return mean + noise_mask * sigma * torch.randn_like(x_t)

    @torch.no_grad()
    def sample(self, model, shape, decode_fn, n_steps=None, y=None, guidance_w=0.0):
        """Generate outputs by running the reverse chain from pure Gaussian noise.

        Args:
            model: noise predictor (see ``p_sample``); it is switched to eval mode.
            shape: ``(B, latent_dim)`` of the latent batch.
            decode_fn: maps the final latents to outputs; the result is clamped to [0, 1].
            n_steps: number of reverse steps; None uses all T steps. Fewer steps
                use evenly spaced timesteps from T-1 down to 0 with the matching
                strided posterior.
            y: optional class labels of shape (B,).
            guidance_w: classifier-free guidance weight.

        Raises:
            ValueError: if ``n_steps`` is less than 1.
        """
        model.eval()
        B, latent_dim = shape
        T = self.cfg.timesteps
        n = T if n_steps is None else int(n_steps)
        if n < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        x_t = torch.randn(B, latent_dim, device=self.device)
        # Rounded, evenly spaced timesteps; duplicates (possible when n > T) are dropped.
        schedule = torch.linspace(T - 1, 0, n).round().long().unique_consecutive().tolist()
        for t_cur, t_next in zip(schedule, schedule[1:] + [-1]):
            t = torch.full((B,), t_cur, device=self.device, dtype=torch.long)
            t_prev = torch.full((B,), t_next, device=self.device, dtype=torch.long)
            x_t = self.p_sample(model, x_t, t, y=y, guidance_w=guidance_w, t_prev=t_prev)
        return decode_fn(x_t).clamp(0, 1)
# --------- EMA ----------
class EMA:
    """Exponential moving average of a model's ``state_dict`` entries.

    Each ``update`` sets every floating-point entry to
    ``shadow = decay * shadow + (1 - decay) * current``. Non-floating entries
    (e.g. integer counters such as BatchNorm's ``num_batches_tracked``) cannot be
    averaged in place, so they are copied verbatim.

    Args:
        model: module whose ``state_dict`` is snapshotted as the initial average.
        decay: weight kept from the previous average at each update.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        """Blend the current ``state_dict`` of ``model`` into the shadow copy."""
        for k, v in model.state_dict().items():
            avg = self.shadow[k]
            if avg.is_floating_point():
                avg.mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                avg.copy_(v)

    def copy_to(self, model):
        """Load the averaged values into ``model`` (strict key matching)."""
        model.load_state_dict(self.shadow, strict=True)
# --------- Training (DDPM) ----------
def train_ddpm(latent_net, ddpm, vae, loader, device,
               epochs=5, lr=2e-4, num_classes=10, cond_drop_prob=0.2,
               weight_decay=0.0, ema_decay=0.999):
    """Train a latent noise predictor on VAE latents, with label dropout for CFG.

    Each batch is processed as follows:
      1. Images are encoded by the frozen ``vae`` and a latent is sampled via
         ``vae.reparam``.
      2. With probability ``cond_drop_prob``, each label is replaced by
         ``latent_net.null_class_idx``, so the network also learns the
         unconditional prediction used by classifier-free guidance.
      3. A random timestep is noised with ``ddpm.q_sample``, and the predicted
         noise is regressed with MSE.

    Args:
        latent_net: noise predictor called as ``latent_net(z_t, t, y)``.
        ddpm: diffusion object providing ``cfg.timesteps`` and ``q_sample``.
        vae: VAE exposing ``enc`` and ``reparam``; it is put in eval mode.
        loader: iterable of ``(images, labels)`` batches with a ``len``.
        device: torch device; mixed precision is enabled only for CUDA.
        epochs, lr, weight_decay: training hyper-parameters (AdamW).
        num_classes: unused; kept for backward compatibility.
        cond_drop_prob: probability of replacing a label with the null label.
        ema_decay: decay of the exponential moving average of the weights.

    Returns:
        ``(latent_net, ema)``.
    """
    opt = torch.optim.AdamW(latent_net.parameters(), lr=lr, betas=(0.9, 0.999),
                            weight_decay=weight_decay)
    # T_0 / T_mult count epochs: the scheduler is stepped every batch with a
    # fractional epoch value.
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2)
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    vae.eval()
    ema = EMA(latent_net, decay=ema_decay)
    n_batches = len(loader)

    for ep in range(1, epochs + 1):
        latent_net.train()
        pbar = tqdm(loader, desc=f"[DDPM] {ep}/{epochs}")
        running = 0.0
        for i, (x, y) in enumerate(pbar):
            x = x.to(device)
            y = y.to(device)

            with torch.no_grad():
                mu, logvar = vae.enc(x)
                z0 = vae.reparam(mu, logvar)

            if cond_drop_prob > 0:
                drop = torch.rand_like(y.float()) < cond_drop_prob
                y = y.masked_fill(drop, latent_net.null_class_idx)

            t = torch.randint(0, ddpm.cfg.timesteps, (z0.size(0),), device=device)
            z_t, noise = ddpm.q_sample(z0, t)

            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = latent_net(z_t, t, y)
                loss = F.mse_loss(pred, noise)

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Clip the true gradients, not the loss-scaled ones (no-op without AMP).
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(latent_net.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            ema.update(latent_net)
            sched.step(ep - 1 + (i + 1) / n_batches)

            running += loss.item()
            pbar.set_postfix(loss=f"{running / (i + 1):.4f}")

    return latent_net, ema
# --------------------
# execute
# --------------------
if __name__ == "__main__":
    # ---- setup ---------------------------------------------------------------
    seed_everything(42)
    device = get_device()
    print("Device:", device)

    # ---- data: MNIST training split converted with ToTensor ------------------
    to_tensor = transforms.ToTensor()
    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=to_tensor)
    train_loader = DataLoader(
        train_ds,
        batch_size=128,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    # ---- stage 1: image VAE ----------------------------------------------------
    latent_dim = 16
    vae = train_vae(
        VAE(latent_dim=latent_dim, ch=48), train_loader, device,
        epochs=5, lr=1e-3, beta_max=0.5, edge_lambda=0.05,
        out_dir="runs_vae", save_recon_every_epoch=True,
    )

    # ---- stage 2: class-conditional DDPM on the VAE latents --------------------
    cfg = DDPMConfig()
    ddpm = DDPM(cfg, device)
    num_classes = 10
    latent_net = LatentNoisePredictor(
        latent_dim=latent_dim, num_classes=num_classes, hidden=768
    ).to(device)
    latent_net, ema = train_ddpm(
        latent_net, ddpm, vae, train_loader, device,
        epochs=10, lr=2e-4, num_classes=num_classes, cond_drop_prob=0.2,
        weight_decay=0.0, ema_decay=0.999,
    )
    ema.copy_to(latent_net)  # sample with the EMA-averaged weights

    # ---- stage 3: sample one digit class with classifier-free guidance ---------
    B = 64
    target_digit = 5
    y = torch.full((B,), target_digit, device=device, dtype=torch.long)

    def decode_fn(z):
        """Decode latents to pixel probabilities (the decoder returns logits)."""
        return torch.sigmoid(vae.dec(z))

    imgs = ddpm.sample(
        latent_net, shape=(B, latent_dim), decode_fn=decode_fn,
        n_steps=250, y=y, guidance_w=3,
    )

    # ---- show the first 16 samples as a 4x4 grid -------------------------------
    grid = vutils.make_grid(imgs[:16], nrow=4)
    import matplotlib.pyplot as plt
    plt.figure(figsize=(6, 6))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

image

image

image

The quality also depends on the classifier-free guidance (CFG) weight, but overall I'm pretty satisfied with the results.
I recommend training the VAE for 15 epochs. For the DDPM, up to 50 epochs was suggested, but even 10 epochs works quite well.

image

© 2025, blueqat Inc. All rights reserved