U-NetからTransformerへ:MNISTとCIFAR-10で試す拡散モデルの新アプローチ
拡散モデル(Diffusion Model)の生成ネットワークといえば、U-Net構造が定番です。私もこれまで主にU-Netを使ってきましたが、今回の実験ではあえてU-Netを外し、TransformerベースのモデルであるTransformer2DModel
(DiT: Diffusion Transformer)を採用してみました。
実験構成
- データセット:MNIST / CIFAR-10
- 生成モデル:
Transformer2DModel
(U-Net非使用) - スケジューラ:DDPM + DPMSolver++
- 潜在空間版(VAE + Transformer)も実施
- 省メモリ対応:AMP(自動混合精度)+ 勾配チェックポイント
MNISTではTransformerを直接ピクセル空間で学習。CIFAR-10では
- ピクセル空間Transformer版
- VAEで潜在空間(8×8)に圧縮したLatent Diffusion + Transformer版
の両方を試しました。
結果
最初のサンプルはほぼノイズのまま。
U-Netに比べ、Transformer単体での拡散モデル学習はやや難易度が高く、初期学習では形状が出にくい印象でした。
MNISTでは多少ノイズは抑えられたけど数字までは出ず。
CIFAR-10ではさらに多くの学習時間とパラメータ調整が必要でした。
その後、潜在空間版(VAE+Transformer)でも全然ダメですね。
動画へのチャレンジ
さらには動画生成もやりましたが、全然ダメでした。
まとめ
全然ダメだったので、まずはDiTできちんと画像が出ることを目標とします。U-Netでももう少し大きい画像をトライしてみたいです。