生成AIの処理が高速になる「FlashAttention-2」 米スタンフォード大の研究者が開発
https://www.techno-edge.net/article/2023/07/24/1634.html
「Transformerの入力シーケンス長を拡大することは困難です。なぜなら、Transformerの核心であるAttention層は、入力シーケンスの長さに対して2次関数的な増加で処理時間とメモリの要求が増えるからです。
1年前、論文の著者であるTri Dao氏がFlashAttentionをリリースしました。FlashAttentionは、Attentionを高速化し、近似なしでメモリ使用量を削減する新しいアルゴリズムです(2次関数ではなく線形の特性を持っています)。これにより、FlashAttentionはベースラインよりも2-4倍高速になります。
そして、今回、FlashAttentionをさらに改良した次期バージョン「FlashAttention-2」を発表しました。A100 80GB SXM4 GPU上で、さまざまな設定の異なるAttentionメソッドの実行時間を測定しました。その結果、FlashAttention-2はFlashAttentionよりも約2倍高速であることがわかりました。さらに、PyTorchの標準的なAttention実装と比較すると、FlashAttention-2は最大9倍の高速化を実現しました。」
ということで見てみます。
Github
https://github.com/Dao-AILab/flash-attention
論文
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Tri Dao
https://arxiv.org/abs/2307.08691
「過去数年間、トランスフォーマーをより長いシーケンス長にスケーリングすることは、言語モデリングや高解像度の画像理解における性能向上、コード、音声、ビデオ生成などの新しいアプリケーションの開発を約束する重要な課題でした。シーケンスの長さをスケーリングする際の主要なボトルネックは、アテンションレイヤーであり、そのランタイムとメモリ使用量はシーケンス長の2乗に比例して増加します。FlashAttention [5] は、非対称なGPUメモリ階層を活用して、線形(2乗ではなく)のメモリ削減とランタイムの高速化(最適化されたベースラインと比較して2-4倍速く)を実現し、近似なしで動作します。ただし、FlashAttentionはまだ最適化された行列乗算(GEMM)操作ほど速くありませんでした。理論的な最大FLOPs/sの25-40%しか達成していませんでした。
この効率の低さは、GPU上の異なるスレッドブロックとワープ間での最適でない作業の分割が原因であり、低い占有率または不必要な共有メモリの読み書きが発生しています。この問題を解決するために、これらの問題に対処するためのより良い作業分割を持つFlashAttention-2を提案します。具体的には、(1)非行列乗算FLOPsの数を減らすためにアルゴリズムを調整し、(2)アテンション計算を並列化し、1つのヘッドでも異なるスレッドブロック間で実行し、占有率を増加させ、(3)各スレッドブロック内でワープ間で作業を分散させ、共有メモリを介した通信を減少させます。これにより、FlashAttentionと比較して約2倍の高速化が実現され、A100で理論的な最大FLOPs/sの50-73%に達し、GEMM操作の効率に近づくことができます。我々は経験的に検証し、GPTスタイルのモデルをトレーニングする際に、FlashAttention-2をエンドツーエンドで使用すると、A100 GPUあたり最大225 TFLOPs/s(モデルFLOPsの72%の利用率)のトレーニングスピードに達することを確認しました。」
ちょっと内容が難しいので今回はあまり詳しい説明はできなさそうです。具体的にはGPUでの実行を最適化することで、FlashAttentionの2倍の高速化をした具体的な手法が書いてあります。