こんにちは。論文を読み始めたものです。今日はWhisperと言う文字起こし向けのモデルを高速形があると言うことなので読んでみたいと思います。
自分たちの開発した手法のベンチマークを全世界にシェアして業界発展というすごい分野ですね。
努力が紙になってるのを感じます。
Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling
Sanchit Gandhi, Patrick von Platen, Alexander M. Rush
https://arxiv.org/abs/2311.00430v1
Githubはこちら
https://github.com/huggingface/distil-whisper
まずはアブストラクトから。翻訳と自分の意見がごっちゃになって進むのでちょっと読みづらいかもしれませんが、、、
「事前訓練された音声認識モデルのサイズが増加するにつれて、これらの大規模モデルを低遅延またはリソース制約環境で実行することが難しくなります。本研究では、擬似ラベリングを利用して大規模なオープンソースデータセットを組み立て、それを使用してWhisperモデルを小型のバリアントであるDistil-Whisperに蒸留します。シンプルな単語誤り率(WER)ヒューリスティックを使用し、トレーニングのために最高品質の擬似ラベルのみを選択します。蒸留されたモデルは、パラメータが51%少ないにもかかわらず、ゼロショット転送設定での分布外テストデータで1%以内のWERを達成するほど、Distil-Whisperは、Whisperモデルの困難な音響条件に対する堅牢性を維持しつつ、長尺のオーディオでの幻覚誤りに対しても耐性があります。Distil-Whisperは、推測的デコーディングのためにWhisperとペアになるように設計されており、元のモデルと同じ出力を数学的に保証しつつ、2倍の速度向上を実現します。この分野のさらなる研究を促進するために、私たちは訓練コード、推論コード、およびモデルを公開しています。5.8倍速くなります。」
ということで、あんまり省略するところがなかったですが、モデルパラメータを51%軽量化してパフォーマンスをほぼ維持してるそうです。
近年の自動音声認識 Automatic Speech Recognition(ASR)システムは、多くの学術的ベンチマークで人間レベルの精度を超えていて、さまざまなアプリに利用されていて、OpenAIのWhisper(Radford et al., 2022)は、1.5億パラメータのシーケンス・ツー・シーケンス(Seq2Seq)トランスフォーマーモデル(Vaswani et al., 2017)で、68万時間の音声認識データで事前トレーニングされており、強い一般化能力を示しています。一方で容量が大きくなり過ぎているのはちょっと扱いが難しいそうです。
別の記事でも触れましたが、ニューラルネットワークのモデルの圧縮手法として、枝刈り、量子化、蒸留とあり、今回は蒸留がテーマです。
自然言語処理の最近の取り組みでは、トランスフォーマーベースのモデルの圧縮において、**ナレッジディスティレーション/知識蒸留(Knowledge distillation / KD)**は、BERTなどのモデルのサイズを減らすために成功してます。機械翻訳方法に触発された疑似ラベリング(pseudo-labelling / PL)アプローチもSeq2Seq要約において生成タスクの大幅な圧縮の可能性を示しています。KDはオーディオ分類において有望な結果を示していますが、より困難な音声認識のタスクに対してはまだ同様の結果は得られていないようです。
本論文では、Seq2Seq ASRの文脈でWhisperモデルに対して蒸留を適用しています。データに疑似ラベリングを行っているようです。Distil-Whisperは、Whisperがさまざまなオーディオドメインや騒がしい音響条件に対して持つロバスト性を維持することを実証し、最高のモデルはオリジナルのWhisperチェックポイントに対して1%以内のWERで性能を発揮し、5.8倍の速度で、51%少ないパラメーターで実行されます。逆に0.1%の性能向上を示すモデルもあり、オリジナルのWhisperモデルよりもハルシネーション(幻覚)を起こす傾向が低いためであるみたいです。
さらに、Whisperと同じエンコーダーの重みを共有することにより、Distil-Whisperは、Whisperの推測デコードのアシスタントモデルとして効率的に使用できます。この推測デコードでは、わずか8%のパラメーター数の増加で、推論速度を2倍に改善します。推測デコードはアルゴリズム的に、メインモデルの予測が変わらないことを保証するため、Whisperを使用する既存の音声認識パイプラインに対するドロップインの置き換えとして使用することができます。
まだ擬似ラベルは完成していないみたいで、継続的な研究のためにコードが下記に公開されているということです。
https://github.com/huggingface/distil-whisper
類似研究
NLP分野では、モデルの蒸留は、結構うまく行くらしく、パフォーマンスの低下も最小限に抑えられています。DistilBERTは、BERTの6層蒸留バージョンで、40%のモデルサイズ減少、60%の速度向上、およびGLUEベンチマークでの言語理解能力の97%保持を実現。DistilBARTというモデルは、XSUMおよびCNN/Daily Mailデータセットで元のモデルを上回り、37%のモデル圧縮と48%の速度向上を達成。しかし、蒸留されたモデルが分布内(ID)評価データでうまく機能する一方で、分布外(OOD)テストセットで事前トレーニングされた対応物よりも大幅に悪いパフォーマンスを示すようで、多様で大規模な疑似ラベル付けされたデータセットでトレーニングすることにより精度低下を回避することが試されています。
ASRタスクにも適用されていますが、エンコーダのみのモデルに焦点を当てています。Wav2Vec 2.0モデルに蒸留を適用し、79%のモデル圧縮と59%の速度向上を達成。ただし、ASRに対する蒸留モデルはしばしばWERスコアの増大を招き、つまり精度低下を招くということです。
Whisperモデルの蒸留に関する以前の研究は、主にモデルサイズとメモリフットプリントの削減に中心的でした。Shao et al. (2023) は、KDを量子化認識トレーニング(QAT)と組み合わせて適用し、わずかなパフォーマンスの低下で大幅なパラメータ削減が可能であることを示しました。しかし、このモデルは限られたデータでしか性能評価されていないみたいで、一般的には利用が出来なさそうということでした。
ここで途中ですが、擬似ラベリングと蒸留について再確認。。。
知識蒸留(Knowledge Distillation)
知識蒸留とか蒸留とか言われますが、今回の論文のメインの手法となっていて、基本的には大きなモデルでの入力出力を使って、小さなモデルを学習させる方法となっていて、その際に、大きなモデルの答えを小さいモデルに覚え込ませるような手法になっています。教師ありデータに対して、正解ラベルがないようなデータも、モデルの出力データをラベルとして扱うことを擬似ラベリングと呼び、そのラベルを擬似ラベルというそうです。
次にモデルの概要です。Whisperモデルですが、seq2seqトランスフォーマーモデルということです。エンコーダーとデコーダーから成り立ち、T個の要素を持つ特徴ベクトルXをNのトークンを持つyに変換するみたいです。
エンコーダーは、Xをある隠れ状態ベクトルHに変換するようです。
入力状態はダウンサンプリングされるため、HのパラメータMのサイズはTの半分になるようです。デコーダーはyiより前のトークンとHを使って、yiの確率分布を予測するそうです。
Whisperモデルをトレーニングするためには、各例(X1:T , y1:N)が(オーディオ、テキスト)のペアであるデータセットを想定しています。モデルは標準のクロスエントロピー(CE)損失を使用してトレーニングされ、ターゲットクラスラベルの推定確率を最大化することでインスタンスクラスを予測するようにトレーニングされます。
モデルの説明はかなり簡潔でした。ここからは知識蒸留の詳しい話に。
ナレッジディスティレーション(KD)(Hinton et al., 2015)は、より小さなモデルをトレーニングして、より大きな教師モデルの振る舞いを再現する圧縮技術です。学生モデルの予測とトレーニングラベルの間のCE損失を最小化することと比較して、KDは学生モデルに、与えられた文脈での次のトークンの可能な予測分布全体から学ぶことを可能にします。
収縮とファインチューニング
最も基本的な蒸留方法には、教師モデルをより小さな学生モデルに縮小し、学生にCEターゲットでトレーニングを行うことが含まれます。Shleifer & Rush(2020)に続いて、学生モデルを教師モデルの最大間隔層からの重みをコピーして初期化することにより、層ベースの圧縮を行います。例えば、32層の教師モデルから2層の学生モデルを初期化する場合、1番目と32番目の層を教師から学生にコピーします。Seq2Seq要約設定(Shleifer & Rush, 2020; Li et al., 2022)でのこの戦略の単純さと効果性を考慮して、すべての蒸留方法にこれを使用します。
疑似ラベリング
疑似ラベル設定(Kim & Rush, 2016)では、基底真理のテキスト転写 \( y_{1:N} \) を、対応する入力オーディオ \( X_{1:T} \) に対する教師の生成 \( \hat{y}_{1:N'} \) に置き換えます。
カルバック・ライブラーの発散
「KL発散(Kullback & Leibler, 1951)の設定では、学生モデルの完全な確率分布 \( P_i \) が、位置 \( i \) における次の可能なトークンの全セットにわたってKL発散を最小化することにより、教師モデルの完全な分布 \( Q_i \) に一致するようにトレーニングされます。これは「単語レベル」のKDと見なすことができます。この方法では、知識が可能なトークンに対するロジットを介して教師モデルから学生モデルに移転されます(Kim & Rush, 2016)。KL発散は、全クラスにわたる情報を提供し、CE損失よりも勾配のバリアンスが少ないため魅力的です。」
すいません、KLの勉強不足なので解説は飛ばします。。。
目標
最終的なKDトレーニング目標は、KL項とPL項の加重和です。ここで、\( \alpha_{KL} \) と \( \alpha_{PL} \) はそれぞれKL項と損失項のスカラー重みです。Shleifer & Rush (2020)に従って、\( \alpha_{KL} = 0.8 \) と \( \alpha_{PL} = 1.0 \) に設定します。
擬似ラベルの修正
Whisperモデルによって生成された疑似ラベルは、書き起こしエラーや幻覚の影響を受けています。正確な疑似ラベルのみでトレーニングを行うために、疑似ラベル付けされたトレーニングデータをフィルタリングするための単純なヒューリスティックを実行。
チャンク化された長形式の書き起こし
Whisperモデルは30秒の入力オーディオに対応する固定の受容野を持っており、一度により長いオーディオ入力を処理することができません。ほとんどの学術的なデータセットは、30秒未満の短い発話で構成されており、これは問題ではありません。しかし、会議の書き起こしなどの実世界の応用では、多くの分や時間にわたる長いオーディオファイルの書き起こしを必要とします。
元のWhisperの論文では、30秒間のオーディオセグメントを順次書き起こし、モデルによって予測されたタイムスタンプに従ってスライドウィンドウをシフトする長形式の書き起こしアルゴリズムを提示しています。この自動回帰アルゴリズムには、ビームサーチと温度のフォールバックが必要で、正確な長形式の書き起こしを保証します(Radford et al., 2022)。
長いオーディオファイルを隣接するセグメント間にわずかなオーバーラップを持つより小さなセグメントにチャンク化します。モデルは各チャンク上で実行され、推測されたテキストはオーバーラップ間の最長共通シーケンスを見つけることでストライドで結合されます。このストライドにより、チャンクを順次書き起こす必要なく、チャンク間の正確な書き起こしが可能になります。
推測デコーディング
推測デコーディング(SD)(Leviathan et al., 2023)は、自動回帰型トランスフォーマーモデルの推論を加速するために、より速いアシスタントモデルを使用する方法です。アシスタントモデルは、候補トークンのシーケンスを生成し、それらすべてがメインモデルによって一度のフォワードパスで検証されます。より速いアシスタントモデルで生成し、メインモデルでのみ検証フォワードパスを実行することにより、デコーディングプロセスは大幅に高速化されます。
簡単なモデルの図があります。
図1
あとは、データセットや評価、結果などでした。詳しい性能面に関しては今回の解説では触れませんので、論文をご参照ください。(量子コンピュータの論文読んでるとメソッドの理解が中心であまり結果を見ないのがよくない癖ですが、今後少しずつ知識が出てきたらパフォーマンス中心での解説に切り替えていきます)
感想
あまり機械学習に関しては、学習を始めてからそんなに時間が経っていないのですが、やはり文言や手法などを覚えるのがまずは先決かと思いました。今回の蒸留を使ったWhisperモデルに関してはかなりわかりやすい内容でした。シーモデルそのもの自体持つseq2seqと呼ばれる普通のモデルだったので、そうしたものに関してはあまり複雑な理解が必要なかったので助かりました。以上です。