以前拡散モデルの記事を見ていましたら、拡散過程、逆拡散過程といいのですが、Stable DiffusionでのU-Net実装の詳細について把握できなかったので、個別に調べます。探してみたところ、やはりメジャーな実装なのか、結構解説記事がありました。みてみましょう。
世界に衝撃を与えた画像生成AI「Stable Diffusion」を徹底解説!
https://qiita.com/omiita/items/ecf8d60466c50ae8295b
どうやらStable Diffusionはこちらの論文の実装がベースとなっているようです!
High-Resolution Image Synthesis with Latent Diffusion Models
Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer
https://arxiv.org/pdf/2112.10752.pdf
こちらのモデルは、
1、拡散モデル
2、VAE
3、Transformer
の組み合わせとなっているそうです。ベースとなるのは、潜在拡散モデルと呼ばれる、VAEのエンコーダーを通して潜在空間に落とした後に、拡散モデルを適用して、デコーダーで戻してるようです。
拡散過程
拡散過程ではガウシアンノイズが加えられ、T=1000ステップくらいまで実行されるようです。ノイズの加え方に関してはノイズスケジューラによって、線形だったりコサインだったりと変わるようです。線形だとノイズが初期に強くなりすぎるので、コサインが良さそうです。この辺りは量子計算での量子断熱計算や量子アニーリングのスケジューリングと似てますね。
逆拡散過程
逆拡散過程でのデノイズにはU-Netが利用されることが多いようです。ノイズ付きの画像を入れるとノイズが出力されます。ノイズがついた位置エンコーディングが入力に付帯されます。
損失関数
普通にMSEでやるようです。
順伝搬
ノイズ画像から画像を生成するにはU-Netを通してノイズを獲得してノイズを除去すればOKだけど、工夫されてT=1000ではなく、50stepほどで完了するようです。
VAEを導入
拡散モデルの計算量削減のために、VAEのエンコーダを通して潜在空間でのU-Netの学習を行うということで、損失関数も拡散過程のノイズ付き画像ではなく、潜在表現に置き換えられたパラメータになります。
U-Netアーキテクチャ
U-Netのアーキテクチャは複数のブロック1とブロック2の繰り返しとなっているようです。
ブロック1
時間tの取り込みを行う層で、
「「GNorm -> Swish -> Conv」を1つのかたまりとして見ると、ResBlockはこのかたまりが2つで構成されているものということが分かります」
ということで、Group Normalization -> Swish関数 -> 畳み込みが二回繰り返されているようです。tの取り込みがGNorm直前直後の実装があるようですが、細かいことは触れません。
ブロック2
テキストプロンプトの取り込みを行う層で、Linear,Normalization層の次にセルフアテンションとクロスアテンション層が入ってるみたいです。
CLIP
テキストからベクトルへの変換はCLIPが使われているようですが、CLIPについては別途で学習します。
次です。
DiffusionモデルをPyTorchで実装する② ~ U-Net編
https://data-analytics.fun/2022/08/27/diffusion-model-pytorch-2-unet/
「U-Netはもともと医療用の画像のセグメンテーションのためのモデルとして2015年に提案されたモデルです。」知りませんでした。
「図の左側の部分でインプット画像をダウンサンプルしていきます。」
「そして、右側の部分で画像をアップサンプルしていきます。」
モデルの内部でダウンサンプルとアップサンプルがあります。
後半はひたすら実装がPyTorchで続きますが、実装が載っているのは大変ありがたいです。
時間を見て学習したいと感じました。
次です。
Stable Diffusion、UNetのすべて
https://note.com/gcem156/n/nf2672cd16a9d
こちらもよくわかりやすかったです。最後の方にLoRAの説明がありました。
「cloneofsimo氏の最初の実装ではAttention層内のLinear層のみに適用されていましたが、kohya氏はそれをTransformer内のLinearと1×1Convまで拡張しました。その後ResNet部分の3×3Convまで拡張する方法が実装されたり、KohakuBlueLeaf氏がアダマール積(LoHA)だのクロネッカー積(Lokr)だのを使うやつを実装したり」
ということで、LoRAの差し込み場所も気になります。アダマールだのクロネッカーだのは使う行列の種類が違うだけでしょうか。
コントロールネットについても言及があります。
「テキストの他にエッジや深度などの情報を画像として入力することで、生成画像をコントロールする方法です。」
見た感じは、各ブロックに画像情報の付帯情報を加える感じですね。
解説自体もすごいわかりやすかったし、派生技術でLoRAやControlNetなども知れて大変助かります。
次です。
誰でもわかるStable Diffusion その4:U-Net
https://hoshikat.hatenablog.com/entry/2023/03/31/022605
図と言葉で説明があります。これもわかりやすいです。
正直どの記事も良かったです。かつ説明の方法もばらけてていろんな角度からの説明が読めます。
Diffusion Modelの説明だけでなく、具体的なLatent Diffusion ModelやU-Net実装も見れて大変助かりますね。
以上です。