In this implementation, images are compressed into a latent space with a VAE, and a DDPM (Denoising Diffusion Probabilistic Model) is applied to those latent representations. I had ChatGPT design the whole thing.
The VAE uses ResidualConvBlock + GroupNorm + SiLU + Dropout for stability and expressiveness, and the reconstruction loss blends a pixel term (BCE-with-logits by default; MSE is also supported) with an edge loss to preserve contours.
Increasing the number of VAE training epochs also greatly reduced the jagged look of the generated images.
When you specify a class (a digit from 0 to 9), samples of that digit are generated.
This enables conditional generation in addition to purely random generation.
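As a quick reference, the training objective and the guidance rule implemented below can be summarized as follows (my own notation, not quoted from the code: $\bar\alpha_t$ corresponds to alphas_cumprod, $y$ is the class label with a null token $\varnothing$ for the unconditional branch, and $w$ is guidance_w):

$$
z_t = \sqrt{\bar\alpha_t}\,z_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\qquad
\mathcal{L}_{\mathrm{DDPM}} = \mathbb{E}_{z_0,\,\epsilon,\,t}\big[\lVert \epsilon - \epsilon_\theta(z_t, t, y)\rVert^2\big]
$$

$$
\tilde\epsilon = (1+w)\,\epsilon_\theta(z_t, t, y) - w\,\epsilon_\theta(z_t, t, \varnothing)
$$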
# =========================================================
# MNIST: VAE(latent) + DDPM(latent) with Class-Conditional CFG
# Summary of improvements:
# - VAE outputs logits + BCEWithLogitsLoss + edge loss (computed on probabilities)
# - KL warmup (beta_max=0.5 recommended)
# - Wider VAE, ch=48 (sharper outputs)
# - DDPM: cosine noise schedule
# - DDPM: class-conditional (classifier-free: cond_drop_prob)
# - CFG at inference time (guidance_w)
# - AMP / gradient clipping / CosineAnnealingWarmRestarts
# - EMA (shadow weights updated during training, copied into the model for inference)
# =========================================================
import os, math, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm
# --------------------
# Utilities
# --------------------
def seed_everything(seed=42):
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ================= Improved CNN VAE for MNIST (28x28x1) =================
# Tricks for better image quality:
# - ResidualConvBlock (GroupNorm + SiLU + Dropout)
# - BCEWithLogits + edge (gradient) loss blend (edge term computed in probability space)
# - KL warmup (beta depends on the epoch; beta_max=0.5 recommended)
class ResidualConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, groups=8, dropout=0.05):
super().__init__()
self.block1 = nn.Sequential(
nn.GroupNorm(groups, in_ch),
nn.SiLU(),
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
)
self.block2 = nn.Sequential(
nn.GroupNorm(groups, out_ch),
nn.SiLU(),
nn.Dropout2d(dropout),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
)
self.skip = (nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity())
def forward(self, x):
h = self.block1(x)
h = self.block2(h)
return h + self.skip(x)
class ConvEncoder(nn.Module):
# 28x28 -> 14x14 -> 7x7
def __init__(self, latent_dim=16, ch=48):
super().__init__()
self.stem = nn.Conv2d(1, ch, 3, padding=1)
self.res1 = ResidualConvBlock(ch, ch) # 28x28
self.down1 = nn.Conv2d(ch, ch*2, 4, stride=2, padding=1) # ->14x14
self.res2 = ResidualConvBlock(ch*2, ch*2) # 14x14
self.down2 = nn.Conv2d(ch*2, ch*4, 4, stride=2, padding=1) # ->7x7
self.res3 = ResidualConvBlock(ch*4, ch*4) # 7x7
self.proj = nn.Linear((ch*4)*7*7, 256)
self.mu = nn.Linear(256, latent_dim)
self.logvar = nn.Linear(256, latent_dim)
def forward(self, x):
h = self.stem(x)
h = self.res1(h)
h = self.down1(h)
h = self.res2(h)
h = self.down2(h)
h = self.res3(h)
h = h.view(h.size(0), -1)
h = F.silu(self.proj(h))
mu, logvar = self.mu(h), self.logvar(h)
return mu, logvar
class ConvDecoder(nn.Module):
# 7x7 -> 14x14 -> 28x28
def __init__(self, latent_dim=16, ch=48):
super().__init__()
self.fc = nn.Linear(latent_dim, (ch*4)*7*7)
self.res3 = ResidualConvBlock(ch*4, ch*4) # 7x7
self.up1 = nn.ConvTranspose2d(ch*4, ch*2, 4, stride=2, padding=1) # ->14x14
self.res2 = ResidualConvBlock(ch*2, ch*2) # 14x14
self.up2 = nn.ConvTranspose2d(ch*2, ch, 4, stride=2, padding=1) # ->28x28
self.res1 = ResidualConvBlock(ch, ch) # 28x28
self.head = nn.Conv2d(ch, 1, 3, padding=1)
def forward(self, z):
h = F.silu(self.fc(z)).view(z.size(0), -1, 7, 7)
h = self.res3(h)
h = self.up1(h)
h = self.res2(h)
h = self.up2(h)
h = self.res1(h)
        x_logits = self.head(h)  # return logits here (no sigmoid)
return x_logits
class VAE(nn.Module):
def __init__(self, latent_dim=16, ch=48):
super().__init__()
self.enc = ConvEncoder(latent_dim, ch)
self.dec = ConvDecoder(latent_dim, ch)
def reparam(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = self.enc(x)
z = self.reparam(mu, logvar)
x_logits = self.dec(z)
return x_logits, mu, logvar, z
# --------- Edge-aware Reconstruction Loss + KL Warmup ----------
def image_gradients(x):
kx = torch.tensor([[[[-1, 0, 1],
[-2, 0, 2],
[-1, 0, 1]]]], device=x.device, dtype=x.dtype)
ky = torch.tensor([[[[-1, -2, -1],
[ 0, 0, 0],
[ 1, 2, 1]]]], device=x.device, dtype=x.dtype)
gx = F.conv2d(x, kx, padding=1)
gy = F.conv2d(x, ky, padding=1)
return torch.sqrt(gx*gx + gy*gy + 1e-6)
def vae_loss(x, x_logits, mu, logvar, recon_type="bce", edge_lambda=0.05, beta=1.0):
if recon_type == "mse":
x_hat = torch.sigmoid(x_logits)
recon = F.mse_loss(x_hat, x, reduction="sum")
else:
recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")
if edge_lambda > 0:
x_hat_prob = torch.sigmoid(x_logits)
gx = image_gradients(x)
gx_hat = image_gradients(x_hat_prob)
edge = F.l1_loss(gx_hat, gx, reduction="sum")
recon = recon + edge_lambda * edge
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon + beta * kl, recon, kl
def kl_warmup(epoch, total_epochs, min_beta=0.0, max_beta=0.5):
t = epoch / max(1, total_epochs)
beta = min_beta + 0.5*(max_beta - min_beta)*(1 - math.cos(math.pi * min(t, 1.0)))
return beta
# --------- Training (VAE) ----------
def train_vae(model, loader, device, epochs=10, lr=1e-3, out_dir="runs_vae",
recon="bce", edge_lambda=0.05, beta_max=0.5, beta_min=0.0,
save_recon_every_epoch=True, weight_decay=1e-4):
os.makedirs(out_dir, exist_ok=True)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))
model.to(device)
for ep in range(1, epochs+1):
model.train()
beta = kl_warmup(ep, epochs, min_beta=beta_min, max_beta=beta_max)
pbar = tqdm(loader, desc=f"[VAE*] epoch {ep}/{epochs} (beta={beta:.3f})")
total, total_rec, total_kl = 0.0, 0.0, 0.0
for x, _ in pbar:
x = x.to(device)
with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
x_logits, mu, logvar, _ = model(x)
loss, rec, kl = vae_loss(x, x_logits, mu, logvar,
recon_type=recon, edge_lambda=edge_lambda, beta=beta)
opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale gradients so clipping sees their true norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()
total += loss.item(); total_rec += rec.item(); total_kl += kl.item()
pbar.set_postfix(loss=f"{total/len(loader):.2f}",
rec=f"{total_rec/len(loader):.2f}",
kl=f"{total_kl/len(loader):.2f}")
if save_recon_every_epoch:
with torch.no_grad():
x_vis = next(iter(loader))[0][:64].to(device)
x_logits, _, _, _ = model(x_vis)
x_hat = torch.sigmoid(x_logits)
grid = vutils.make_grid(torch.cat([x_vis, x_hat], dim=0), nrow=16)
from torchvision.utils import save_image
save_image(grid, os.path.join(out_dir, f"recon_ep{ep}.png"))
return model
# --------------------
# DDPM (cosine schedule + classifier-free guidance)
# --------------------
@dataclass
class DDPMConfig:
timesteps: int = 1000
    beta_start: float = 1e-4  # unused (overridden by the cosine schedule)
    beta_end: float = 0.02    # unused (overridden by the cosine schedule)
class TimeEmbedding(nn.Module):
def __init__(self, embedding_dim=64, max_period=10000):
super().__init__()
self.embedding_dim = embedding_dim
self.max_period = max_period
def forward(self, t):
half = self.embedding_dim // 2
device = t.device
freqs = torch.exp(-math.log(self.max_period) * torch.arange(0, half, device=device) / half)
args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if self.embedding_dim % 2 == 1:
emb = F.pad(emb, (0,1))
return emb
class LatentNoisePredictor(nn.Module):
def __init__(self, latent_dim, num_classes=10, hidden=768, emb_dim=64):
super().__init__()
self.te = TimeEmbedding(emb_dim)
self.ce = nn.Embedding(num_classes + 1, emb_dim) # +1: null token
nn.init.normal_(self.ce.weight, std=0.02)
self.null_class_idx = num_classes
self.fc1 = nn.Linear(latent_dim + emb_dim + emb_dim, hidden)
self.fc2 = nn.Linear(hidden, hidden)
self.out = nn.Linear(hidden, latent_dim)
def forward(self, z_t, t, y=None):
if y is None:
y = torch.full((z_t.size(0),), self.null_class_idx, device=z_t.device, dtype=torch.long)
te = self.te(t)
ce = self.ce(y)
x = torch.cat([z_t, te, ce], dim=1)
h = F.silu(self.fc1(x))
h = F.silu(self.fc2(h))
        return self.out(h)  # predict ε directly (v-parameterization is optional and not used here)
class DDPM:
def __init__(self, cfg, device):
self.cfg = cfg
self.device = device
# ---- Cosine ᾱ(t) ----
def _cosine_alphas_cumprod(T, s=0.008):
steps = torch.arange(T+1, dtype=torch.float32)
f = torch.cos(((steps/T + s) / (1+s)) * math.pi * 0.5) ** 2
a_bar = f / f[0]
            return a_bar[1:]  # length T
alphas_cumprod = _cosine_alphas_cumprod(cfg.timesteps).to(device) # ᾱ_t
alphas = torch.cat([alphas_cumprod[0:1],
(alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(1e-8, 1.0)], dim=0)
betas = 1.0 - alphas
self.betas = betas
self.alphas = alphas
self.alphas_cumprod = alphas_cumprod
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
def q_sample(self, x0, t, noise=None):
if noise is None:
noise = torch.randn_like(x0)
return (
self.sqrt_alphas_cumprod[t].unsqueeze(1) * x0 +
self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1) * noise,
noise
)
def p_sample(self, model, x_t, t, y=None, guidance_w=0.0):
beta_t = self.betas[t].unsqueeze(1)
alpha_t = self.alphas[t].unsqueeze(1)
alpha_bar_t = self.alphas_cumprod[t].unsqueeze(1)
sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)
if guidance_w == 0.0 or y is None:
eps_theta = model(x_t, t, y)
else:
eps_cond = model(x_t, t, y)
null_y = torch.full_like(y, model.null_class_idx)
eps_uncond = model(x_t, t, null_y)
eps_theta = (1 + guidance_w) * eps_cond - guidance_w * eps_uncond
        mean = (1.0 / torch.sqrt(alpha_t)) * (x_t - (beta_t / sqrt_one_minus_ab) * eps_theta)
if (t > 0).all():
return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)
else:
return mean
@torch.no_grad()
def sample(self, model, shape, decode_fn, n_steps=None, y=None, guidance_w=0.0):
model.eval()
B, latent_dim = shape
x_t = torch.randn(B, latent_dim, device=self.device)
timesteps = self.cfg.timesteps if n_steps is None else n_steps
t_space = torch.linspace(self.cfg.timesteps-1, 0, timesteps, dtype=torch.long, device=self.device)
for t_scalar in t_space:
t = torch.full((B,), int(t_scalar.item()), device=self.device, dtype=torch.long)
x_t = self.p_sample(model, x_t, t, y=y, guidance_w=guidance_w)
return decode_fn(x_t).clamp(0, 1)
# --------- EMA (exponential moving average) ----------
class EMA:
def __init__(self, model, decay=0.999):
self.decay = decay
self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
@torch.no_grad()
def update(self, model):
for k, v in model.state_dict().items():
self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
def copy_to(self, model):
model.load_state_dict(self.shadow, strict=True)
# --------- Training (DDPM) ----------
def train_ddpm(latent_net, ddpm, vae, loader, device,
epochs=5, lr=2e-4, num_classes=10, cond_drop_prob=0.2,
weight_decay=0.0, ema_decay=0.999):
opt = torch.optim.AdamW(latent_net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))
vae.eval()
ema = EMA(latent_net, decay=ema_decay)
for ep in range(1, epochs+1):
latent_net.train()
pbar = tqdm(loader, desc=f"[DDPM] {ep}/{epochs}")
for x, y in pbar:
x = x.to(device); y = y.to(device)
with torch.no_grad():
mu, logvar = vae.enc(x)
z0 = vae.reparam(mu, logvar)
            # classifier-free training: replace labels with the null token at rate cond_drop_prob
if cond_drop_prob > 0:
drop = (torch.rand_like(y.float()) < cond_drop_prob)
y = y.clone()
y[drop] = latent_net.null_class_idx
t = torch.randint(0, ddpm.cfg.timesteps, (z0.size(0),), device=device)
z_t, noise = ddpm.q_sample(z0, t)
with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                pred = latent_net(z_t, t, y)  # predict ε
loss = F.mse_loss(pred, noise)
opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(latent_net.parameters(), 1.0)
            scaler.step(opt); scaler.update()
ema.update(latent_net)
sched.step()
return latent_net, ema
# --------------------
# Main script
# --------------------
if __name__ == "__main__":
seed_everything(42)
device = get_device()
print("Device:", device)
# MNIST Data
tfm = transforms.ToTensor()
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    # Train the VAE
    latent_dim = 16  # must match the DDPM side
vae = VAE(latent_dim=latent_dim, ch=48)
vae = train_vae(vae, train_loader, device,
epochs=5, lr=1e-3, beta_max=0.5, edge_lambda=0.05,
out_dir="runs_vae", save_recon_every_epoch=True)
    # Train the DDPM (class-conditional, with label drop for classifier-free guidance)
cfg = DDPMConfig()
ddpm = DDPM(cfg, device)
num_classes = 10
latent_net = LatentNoisePredictor(latent_dim=latent_dim, num_classes=num_classes, hidden=768).to(device)
latent_net, ema = train_ddpm(
latent_net, ddpm, vae, train_loader, device,
epochs=10, lr=2e-4, num_classes=num_classes, cond_drop_prob=0.2,
weight_decay=0.0, ema_decay=0.999
)
The final sampling step looks like this:
# Sampling (inference with the EMA weights)
ema.copy_to(latent_net)
# Choose the digit to generate and the CFG strength
B = 64
target_digit = 5
y = torch.full((B,), target_digit, device=device, dtype=torch.long)
decode_fn = lambda z: torch.sigmoid(vae.dec(z))  # logits -> probabilities
imgs = ddpm.sample(
latent_net, shape=(B, latent_dim), decode_fn=decode_fn,
    n_steps=250, y=y, guidance_w=3  # tune the CFG strength, roughly in the 1.5-2.5 range
)
# Visualize
grid = vutils.make_grid(imgs[:16], nrow=4)
import matplotlib.pyplot as plt
plt.figure(figsize=(6,6)); plt.axis("off")
plt.imshow(grid.permute(1,2,0).cpu().numpy())
plt.show()
When the VAE is not trained for enough epochs, the outputs come out jagged; the code above uses 5 epochs.
Keeping the DDPM at 10 epochs and raising the VAE from 5 to 15 epochs (the CFG strength also affects quality) gave results I was roughly satisfied with.
My recommendation is to raise only the VAE to 15 epochs. Raising the DDPM to 50 epochs was also suggested, but 10 already looks fine.
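For reference, here is a minimal sketch for comparing a few guidance_w values with the objects defined above (ddpm, latent_net, vae, latent_dim, device); the digit, the sweep values, and the file names are arbitrary choices for illustration:

# Sweep a few CFG strengths for one digit and save a grid per setting.
# Assumes ddpm, latent_net, vae, latent_dim, device from the code above are in scope,
# and that the EMA weights have already been copied into latent_net.
import torch
from torchvision.utils import save_image

decode_fn = lambda z: torch.sigmoid(vae.dec(z))  # same logits -> probabilities wrapper as above
digit = 7  # any class from 0 to 9
for w in [0.0, 1.5, 2.5, 3.0]:
    y = torch.full((16,), digit, device=device, dtype=torch.long)
    imgs = ddpm.sample(latent_net, shape=(16, latent_dim), decode_fn=decode_fn,
                       n_steps=250, y=y, guidance_w=w)
    save_image(imgs, f"cfg_digit{digit}_w{w}.png", nrow=4)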