VAE + DDPM for MNIST Digit Generation – An Implementation Example of Category-Specified Sampling

Yuichiro Minato

2025/08/11 13:34

This implementation uses a VAE to compress the images into a latent space and then applies a DDPM (Denoising Diffusion Probabilistic Model) to that latent representation. I had ChatGPT work out the whole design.
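
Concretely (matching the code below, where vae.enc / vae.dec are the VAE encoder and decoder and ε_θ is the latent noise predictor), this is the usual latent-diffusion recipe: encode z₀ = Enc(x) via the reparameterization trick, add noise with the forward process z_t = √ᾱ_t·z₀ + √(1−ᾱ_t)·ε where ε ~ N(0, I), train ε_θ(z_t, t, y) to predict ε with an MSE loss, and at sampling time start from z_T ~ N(0, I), denoise step by step, and decode x̂ = Dec(z₀).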

The VAE uses ResidualConvBlock + GroupNorm + SiLU + Dropout for stability and expressive power, and its reconstruction loss combines a pixel term (BCE-with-logits in the code below, with MSE available as an option) with an edge loss to preserve contours.
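
Written out, the objective in vae_loss below is roughly

L = BCEWithLogits(x_logits, x) + edge_lambda · ‖G(σ(x_logits)) − G(x)‖₁ + β · KL(q(z|x) ‖ N(0, I)),

where G(·) is the Sobel gradient magnitude and β is ramped from 0 toward beta_max by the KL warmup.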

In addition, increasing the number of VAE training epochs greatly reduced the jagged artifacts in the generated images.

When you specify a category (a digit from 0 to 9), samples of that digit are generated.
This enables conditional generation in addition to purely random generation.
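
Class conditioning is trained classifier-free style (labels are randomly replaced by a null token with probability cond_drop_prob), and at sampling time p_sample combines the two predictions as

ε̂ = (1 + w)·ε_θ(z_t, t, y) − w·ε_θ(z_t, t, ∅),

with w = guidance_w and ∅ the null class token.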

# =========================================================
# MNIST: VAE (latent) + DDPM (latent) with Class-Conditional CFG
# Summary of improvements:
# - VAE outputs logits + BCEWithLogitsLoss + edge loss (computed on probabilities)
# - KL warmup (beta_max=0.5 recommended)
# - VAE width ch=48 (sharper images)
# - DDPM: cosine noise schedule
# - DDPM: class-conditional (classifier-free: cond_drop_prob)
# - CFG at inference time (guidance_w)
# - AMP / gradient clipping / CosineAnnealingWarmRestarts
# - EMA (shadow weights updated during training, copied in for inference)
# =========================================================

import os, math, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm

# --------------------
# Utilities
# --------------------
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
# ================= Improved CNN VAE for MNIST (28x28x1) =================
# Tricks for better image quality:
# - ResidualConvBlock (GroupNorm + SiLU + Dropout)
# - Blend of BCEWithLogits and edge (gradient) loss (edge computed in probability space)
# - KL warmup (beta depends on the epoch; beta_max=0.5 recommended)

class ResidualConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, groups=8, dropout=0.05):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        )
        self.skip = (nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity())

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.skip(x)

class ConvEncoder(nn.Module):
    # 28x28 -> 14x14 -> 7x7
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.stem = nn.Conv2d(1, ch, 3, padding=1)
        self.res1 = ResidualConvBlock(ch, ch)                 # 28x28
        self.down1 = nn.Conv2d(ch, ch*2, 4, stride=2, padding=1)  # ->14x14
        self.res2 = ResidualConvBlock(ch*2, ch*2)             # 14x14
        self.down2 = nn.Conv2d(ch*2, ch*4, 4, stride=2, padding=1) # ->7x7
        self.res3 = ResidualConvBlock(ch*4, ch*4)             # 7x7
        self.proj  = nn.Linear((ch*4)*7*7, 256)
        self.mu     = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

    def forward(self, x):
        h = self.stem(x)
        h = self.res1(h)
        h = self.down1(h)
        h = self.res2(h)
        h = self.down2(h)
        h = self.res3(h)
        h = h.view(h.size(0), -1)
        h = F.silu(self.proj(h))
        mu, logvar = self.mu(h), self.logvar(h)
        return mu, logvar

class ConvDecoder(nn.Module):
    # 7x7 -> 14x14 -> 28x28
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.fc   = nn.Linear(latent_dim, (ch*4)*7*7)
        self.res3 = ResidualConvBlock(ch*4, ch*4)             # 7x7
        self.up1  = nn.ConvTranspose2d(ch*4, ch*2, 4, stride=2, padding=1) # ->14x14
        self.res2 = ResidualConvBlock(ch*2, ch*2)             # 14x14
        self.up2  = nn.ConvTranspose2d(ch*2, ch, 4, stride=2, padding=1)   # ->28x28
        self.res1 = ResidualConvBlock(ch, ch)                 # 28x28
        self.head = nn.Conv2d(ch, 1, 3, padding=1)

    def forward(self, z):
        h = F.silu(self.fc(z)).view(z.size(0), -1, 7, 7)
        h = self.res3(h)
        h = self.up1(h)
        h = self.res2(h)
        h = self.up2(h)
        h = self.res1(h)
        x_logits = self.head(h)   # return logits (no sigmoid here)
        return x_logits

class VAE(nn.Module):
    def __init__(self, latent_dim=16, ch=48):
        super().__init__()
        self.enc = ConvEncoder(latent_dim, ch)
        self.dec = ConvDecoder(latent_dim, ch)

    def reparam(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.enc(x)
        z = self.reparam(mu, logvar)
        x_logits = self.dec(z)
        return x_logits, mu, logvar, z
# --------- Edge-aware Reconstruction Loss + KL Warmup ----------
def image_gradients(x):
    kx = torch.tensor([[[[-1, 0, 1],
                         [-2, 0, 2],
                         [-1, 0, 1]]]], device=x.device, dtype=x.dtype)
    ky = torch.tensor([[[[-1, -2, -1],
                         [ 0,  0,  0],
                         [ 1,  2,  1]]]], device=x.device, dtype=x.dtype)
    gx = F.conv2d(x, kx, padding=1)
    gy = F.conv2d(x, ky, padding=1)
    return torch.sqrt(gx*gx + gy*gy + 1e-6)

def vae_loss(x, x_logits, mu, logvar, recon_type="bce", edge_lambda=0.05, beta=1.0):
    if recon_type == "mse":
        x_hat = torch.sigmoid(x_logits)
        recon = F.mse_loss(x_hat, x, reduction="sum")
    else:
        recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")

    if edge_lambda > 0:
        x_hat_prob = torch.sigmoid(x_logits)
        gx = image_gradients(x)
        gx_hat = image_gradients(x_hat_prob)
        edge = F.l1_loss(gx_hat, gx, reduction="sum")
        recon = recon + edge_lambda * edge

    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl, recon, kl

def kl_warmup(epoch, total_epochs, min_beta=0.0, max_beta=0.5):
    t = epoch / max(1, total_epochs)
    beta = min_beta + 0.5*(max_beta - min_beta)*(1 - math.cos(math.pi * min(t, 1.0)))
    return beta
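# For example, with total_epochs=15 and max_beta=0.5 this ramps roughly as
# kl_warmup(1, 15) ≈ 0.005, kl_warmup(8, 15) ≈ 0.28, kl_warmup(15, 15) = 0.5,
# i.e. the KL term is almost off early in training and only reaches full weight at the end.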
# --------- Training (VAE) ----------
def train_vae(model, loader, device, epochs=10, lr=1e-3, out_dir="runs_vae",
              recon="bce", edge_lambda=0.05, beta_max=0.5, beta_min=0.0,
              save_recon_every_epoch=True, weight_decay=1e-4):
    os.makedirs(out_dir, exist_ok=True)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))
    model.to(device)

    for ep in range(1, epochs+1):
        model.train()
        beta = kl_warmup(ep, epochs, min_beta=beta_min, max_beta=beta_max)
        pbar = tqdm(loader, desc=f"[VAE*] epoch {ep}/{epochs} (beta={beta:.3f})")
        total, total_rec, total_kl = 0.0, 0.0, 0.0

        for x, _ in pbar:
            x = x.to(device)
            with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                x_logits, mu, logvar, _ = model(x)
                loss, rec, kl = vae_loss(x, x_logits, mu, logvar,
                                         recon_type=recon, edge_lambda=edge_lambda, beta=beta)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping so the norm is computed on real gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()

            total += loss.item(); total_rec += rec.item(); total_kl += kl.item()
            pbar.set_postfix(loss=f"{total/len(loader):.2f}",
                             rec=f"{total_rec/len(loader):.2f}",
                             kl=f"{total_kl/len(loader):.2f}")

        if save_recon_every_epoch:
            model.eval()
            with torch.no_grad():
                x_vis = next(iter(loader))[0][:64].to(device)
                x_logits, _, _, _ = model(x_vis)
                x_hat = torch.sigmoid(x_logits)
                grid = vutils.make_grid(torch.cat([x_vis, x_hat], dim=0), nrow=16)
                vutils.save_image(grid, os.path.join(out_dir, f"recon_ep{ep}.png"))

    return model
# --------------------
# DDPM (cosine schedule + classifier-free guidance)
# --------------------
@dataclass
class DDPMConfig:
    timesteps: int = 1000
    beta_start: float = 1e-4  # unused (overridden by the cosine schedule)
    beta_end: float = 0.02    # unused (overridden by the cosine schedule)

class TimeEmbedding(nn.Module):
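    # Standard sinusoidal, transformer-style timestep embedding; forward() returns (batch, embedding_dim).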
    def __init__(self, embedding_dim=64, max_period=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period
    def forward(self, t):
        half = self.embedding_dim // 2
        device = t.device
        freqs = torch.exp(-math.log(self.max_period) * torch.arange(0, half, device=device) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.embedding_dim % 2 == 1:
            emb = F.pad(emb, (0,1))
        return emb

class LatentNoisePredictor(nn.Module):
    def __init__(self, latent_dim, num_classes=10, hidden=768, emb_dim=64):
        super().__init__()
        self.te = TimeEmbedding(emb_dim)
        self.ce = nn.Embedding(num_classes + 1, emb_dim)  # +1: null token
        nn.init.normal_(self.ce.weight, std=0.02)
        self.null_class_idx = num_classes

        self.fc1 = nn.Linear(latent_dim + emb_dim + emb_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, latent_dim)

    def forward(self, z_t, t, y=None):
        if y is None:
            y = torch.full((z_t.size(0),), self.null_class_idx, device=z_t.device, dtype=torch.long)
        te = self.te(t)
        ce = self.ce(y)
        x = torch.cat([z_t, te, ce], dim=1)
        h = F.silu(self.fc1(x))
        h = F.silu(self.fc2(h))
        return self.out(h)  # predict ε directly (v-parameterization is optional and not used here)

class DDPM:
    def __init__(self, cfg, device):
        self.cfg = cfg
        self.device = device
        # ---- Cosine ᾱ(t) ----
        def _cosine_alphas_cumprod(T, s=0.008):
            steps = torch.arange(T+1, dtype=torch.float32)
            f = torch.cos(((steps/T + s) / (1+s)) * math.pi * 0.5) ** 2
            a_bar = f / f[0]
            return a_bar[1:]  # length T
        alphas_cumprod = _cosine_alphas_cumprod(cfg.timesteps).to(device)  # ᾱ_t

        alphas = torch.cat([alphas_cumprod[0:1],
                            (alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(1e-8, 1.0)], dim=0)
        betas = 1.0 - alphas

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        return (
            self.sqrt_alphas_cumprod[t].unsqueeze(1) * x0 +
            self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1) * noise,
            noise
        )

    def p_sample(self, model, x_t, t, y=None, guidance_w=0.0):
        beta_t = self.betas[t].unsqueeze(1)
        alpha_t = self.alphas[t].unsqueeze(1)
        alpha_bar_t = self.alphas_cumprod[t].unsqueeze(1)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)

        if guidance_w == 0.0 or y is None:
            eps_theta = model(x_t, t, y)
        else:
            eps_cond = model(x_t, t, y)
            null_y = torch.full_like(y, model.null_class_idx)
            eps_uncond = model(x_t, t, null_y)
            eps_theta = (1 + guidance_w) * eps_cond - guidance_w * eps_uncond

        # Standard DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
        mean = (x_t - (beta_t / sqrt_one_minus_ab) * eps_theta) / torch.sqrt(alpha_t)
        if (t > 0).all():
            return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)
        else:
            return mean

    @torch.no_grad()
    def sample(self, model, shape, decode_fn, n_steps=None, y=None, guidance_w=0.0):
        model.eval()
        B, latent_dim = shape
        x_t = torch.randn(B, latent_dim, device=self.device)
        timesteps = self.cfg.timesteps if n_steps is None else n_steps
        t_space = torch.linspace(self.cfg.timesteps-1, 0, timesteps, dtype=torch.long, device=self.device)
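        # Note: when n_steps < cfg.timesteps this strides coarsely over the full schedule
        # while reusing the per-step betas, which is a rough approximation rather than a
        # properly re-derived (e.g. DDIM-style) short schedule.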
        for t_scalar in t_space:
            t = torch.full((B,), int(t_scalar.item()), device=self.device, dtype=torch.long)
            x_t = self.p_sample(model, x_t, t, y=y, guidance_w=guidance_w)
        return decode_fn(x_t).clamp(0, 1)
# --------- EMA (exponential moving average) ----------
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)

    def copy_to(self, model):
        model.load_state_dict(self.shadow, strict=True)
# --------- Training (DDPM) ----------
def train_ddpm(latent_net, ddpm, vae, loader, device,
               epochs=5, lr=2e-4, num_classes=10, cond_drop_prob=0.2,
               weight_decay=0.0, ema_decay=0.999):
    opt = torch.optim.AdamW(latent_net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

    vae.eval()
    ema = EMA(latent_net, decay=ema_decay)

    for ep in range(1, epochs+1):
        latent_net.train()
        pbar = tqdm(loader, desc=f"[DDPM] {ep}/{epochs}")
        for x, y in pbar:
            x = x.to(device); y = y.to(device)

            with torch.no_grad():
                mu, logvar = vae.enc(x)
                z0 = vae.reparam(mu, logvar)

            # classifier-free training: replace labels with the null token with probability cond_drop_prob
            if cond_drop_prob > 0:
                drop = (torch.rand_like(y.float()) < cond_drop_prob)
                y = y.clone()
                y[drop] = latent_net.null_class_idx

            t = torch.randint(0, ddpm.cfg.timesteps, (z0.size(0),), device=device)
            z_t, noise = ddpm.q_sample(z0, t)

            with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                pred = latent_net(z_t, t, y)   # predict ε
                loss = F.mse_loss(pred, noise)

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping
            torch.nn.utils.clip_grad_norm_(latent_net.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            ema.update(latent_net)

        # Step the LR scheduler once per epoch (T_0 is counted in scheduler steps,
        # so stepping per batch would trigger warm restarts every few iterations).
        sched.step()

    return latent_net, ema
# --------------------
# Run
# --------------------
if __name__ == "__main__":
    seed_everything(42)
    device = get_device()
    print("Device:", device)

    # MNIST Data
    tfm = transforms.ToTensor()
    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)

    # Train the VAE
    latent_dim = 16          # must match the latent_dim used on the DDPM side
    vae = VAE(latent_dim=latent_dim, ch=48)
    vae = train_vae(vae, train_loader, device,
                    epochs=5, lr=1e-3, beta_max=0.5, edge_lambda=0.05,
                    out_dir="runs_vae", save_recon_every_epoch=True)

    # Train the DDPM (class-conditional, with label dropout for CFG training)
    cfg = DDPMConfig()
    ddpm = DDPM(cfg, device)
    num_classes = 10
    latent_net = LatentNoisePredictor(latent_dim=latent_dim, num_classes=num_classes, hidden=768).to(device)
    latent_net, ema = train_ddpm(
        latent_net, ddpm, vae, train_loader, device,
        epochs=10, lr=2e-4, num_classes=num_classes, cond_drop_prob=0.2,
        weight_decay=0.0, ema_decay=0.999
    )

The final sampling step:

    # Sampling (inference with the EMA weights)
    ema.copy_to(latent_net)

    # Choose the digit to generate and the CFG strength
    B = 64
    target_digit = 5
    y = torch.full((B,), target_digit, device=device, dtype=torch.long)

    decode_fn = lambda z: torch.sigmoid(vae.dec(z))  # logits -> probabilities
    imgs = ddpm.sample(
        latent_net, shape=(B, latent_dim), decode_fn=decode_fn,
        n_steps=250, y=y, guidance_w=3   # tune the CFG strength around 1.5-2.5
    )

    # Visualization
    grid = vutils.make_grid(imgs[:16], nrow=4)
    import matplotlib.pyplot as plt
    plt.figure(figsize=(6,6)); plt.axis("off")
    plt.imshow(grid.permute(1,2,0).cpu().numpy())
    plt.show()
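
As a small extension (not part of the original run, just a sketch reusing the same objects), you can condition on all ten classes at once and lay out one row per digit:

    # Sketch: one row per digit 0-9, reusing the EMA weights already copied in
    n_per_class = 8
    y_all = torch.arange(10, device=device).repeat_interleave(n_per_class)  # [0]*8, [1]*8, ..., [9]*8
    imgs_all = ddpm.sample(
        latent_net, shape=(10 * n_per_class, latent_dim), decode_fn=decode_fn,
        n_steps=250, y=y_all, guidance_w=2.0
    )
    grid_all = vutils.make_grid(imgs_all, nrow=n_per_class)
    plt.figure(figsize=(6, 8)); plt.axis("off")
    plt.imshow(grid_all.permute(1, 2, 0).cpu().numpy())
    plt.show()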

With too few VAE epochs the outputs came out jagged, so for now it is set to 5.

[image]

Keeping the DDPM at 10 epochs and increasing the VAE from 5 to 15 epochs gives:

[image]

[image]

The quality also varies with the CFG strength, but overall the results feel fairly satisfying.
My recommendation is to raise only the VAE to 15 epochs. Raising the DDPM to 50 epochs was also suggested, but 10 already looks fine.
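
Concretely, the recommended setting only changes the epochs arguments in the script above (a sketch):

    vae = train_vae(vae, train_loader, device,
                    epochs=15, lr=1e-3, beta_max=0.5, edge_lambda=0.05,
                    out_dir="runs_vae", save_recon_every_epoch=True)

    latent_net, ema = train_ddpm(
        latent_net, ddpm, vae, train_loader, device,
        epochs=10, lr=2e-4, num_classes=num_classes, cond_drop_prob=0.2,
        weight_decay=0.0, ema_decay=0.999
    )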

[image]
