Stable Diffusionなどの拡散モデルを調べているとU-Net実装に行き着いて、転置畳み込みのアップサンプリングに行き着きますよね。調べてみます。
転置畳み込み (Transposed Convolution, Deconolution)
https://cvml-expertguide.net/terms/dl/layers/convolution/transposed-convolution/
「転置畳み込みでは,畳み込み層(基本型)と逆の操作を行う(2節,図1).これにより「(1)カーネルサイズに沿った空間サイズのアップサンプリング」と「(2)空間畳み込み処理」の2つを,同時に効率的に実行できる」
とのこと。
「このように,元の重み行列の転置行列を用いる演算に(ニューラルネットワーク的にも)相当するので,「転置」畳み込み層と呼ばれるようになった」
名前の由来も書いてありました。
実装方法については2つあるようです。
1、転置畳み込み
畳み込みの逆操作をすることに相当しています。
2、上昇畳み込み
サイズ拡大をしてから畳み込みをすることで同じ効果が得られるようです。
とてもわかりやすい記事でした。
次の記事です。
転置畳み込みは(逆畳み込み)とは?画像生成に使われる手法を分かりやすく解説
https://nisshingeppo.com/ai/whats-deconvolution/
この記事では、転置畳み込みは逆の操作ではないと紹介がされていました。上の記事では二つの手法が別々にあって結果は同じという感じでした。この辺りは後で検証でもいいかな。。。
strideの大きさとpaddingの大きさによって幾つかの例題を解いていてとてもわかりやすいです。
次の記事は、
ニューラルネットワークにおけるDeconvolution
https://qiita.com/shngt/items/9c86e69e16ce6d61a0c6
転置畳み込みや逆畳み込みと呼ばれるので、ここでは、Deconvolutionの名前で紹介されています。
転置畳み込みは、Transposed Convolutionですね。
こちらも基本的には畳み込みはサイズ拡大からの畳み込みの手法で紹介されていました。
次の記事は、
<学習シリーズ>Pytorchの転置畳み込み(ConvTranspose2d)の確認
https://note.com/kiyo_ai_note/n/ne4d78a36de04
最後の記事はPyTorchのConvTranspose2dという機能を使って実践的に転置畳み込みを行っていました。大変興味深いですね。
上の記事での逆畳み込みに関して、単に畳み込みの逆をするという発想のほかに、やはり拡大してから畳み込みの方が一般的のようでした。せっかくなので、続いて自分でもPyTorchで実装をしてみたいです。
その前に、逆畳み込みと上昇畳み込みを確認してみます。
まずは通常の畳み込み。入力が3*3で、フィルタが2*2、ストライドが1で出力は2*2
次に転置畳み込みですが、
入力が2*2、フィルタは2*2、ストライドが1で、入力の各コマにフィルタの数字をかけてずらしていくといけました。
こんな感じの出力です。
ただ、実際には、下記のように最初に入力を拡大してからフィルタをかけて畳み込みをするのと変わりませんでした。
基本的には同じことのようですので、拡大してから出力に直す方が形式的にできそうです。
次にこちらの記事を参考にPyTorchで転置畳み込みの機能を使ってみました。
<学習シリーズ>Pytorchの転置畳み込み(ConvTranspose2d)の確認
https://note.com/kiyo_ai_note/n/ne4d78a36de04
問題は上記と同じで、2202を1001のカーネルで逆畳み込みします。
import torch
import torch.nn as nn
#入力テンソルを設定
tensor = torch.tensor([[[2, 2],[0, 2 ]]]).float()
#様々な設定値があったが、チャネルは1、カーネルサイズは2、paddingはデフォルトで1、biasをFalseにする必要があり。
conv_trans = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
#こちらでカーネルを変更
with torch.no_grad():
conv_trans.weight = nn.Parameter(torch.tensor([[[[1, 0],[0, 1]]]]).float())
#結果の取得
conv_trans(tensor)
結果として、
tensor([[[2., 2., 0.],
[0., 4., 2.],
[0., 0., 2.]]], grad_fn=<SqueezeBackward1>)
うまく逆畳み込みができました。
今回はU-Netで使われる転置畳み込み、逆畳み込みについて実行の仕方と最後PyTorchでの簡単な確認を行いました。以上です。