こんにちは、CCCMKホールディングスTECH LABの三浦です。
今年もあと1か月ですね。振り返ってみると色々とあった1年ですが、個人的には夏の厳しい暑さが印象に残っています。残りわずかですが、2024年の最後まで頑張ろうと思います。
今回は最近読んで印象に残った論文の内容についてまとめてみたいなと思います。次の論文です。
【Title】Rho-1: Not All Tokens Are What You Need
【Authors】Zhenghao Lin, Zhibin Gou, Yeyun Gong, Xiao Liu, Yelong Shen, Ruochen Xu, Chen Lin, Yujiu Yang, Jian Jiao, Nan Duan, Weizhu Chen
【Submitted】11 Apr 2024 (v1), last revised 23 May 2024
【arXivURL】https://arxiv.org/abs/2404.07965
一般的な機械学習モデルの学習において、学習に使うデータの品質はとても重要です。たとえば外れ値が含まれているとモデルの学習に悪影響を与えるので、学習前に取り除いておく、などの工夫が行われます。Large Language Model(LLM)の学習においてもデータの品質は重要で、これまでも文章単位でのフィルタリングによってLLMの性能が向上することが確認されているそうです。
この論文ではさらに踏み込んで、文章の中のトークン単位までフィルタリングを行うことで、モデルの性能が向上することが述べられています。性能が改善するだけでなく、結果的に学習時に使うトークン数が削減されることにより学習処理が効率化する、というメリットもあります。
以降で論文で述べられている、どうやって学習に使うトークンをフィルタリングするのか、それによってどのような改善が見られるのかについてまとめていきたいと思います。
学習中のトークンごとの損失の変化
最初に論文では、トークンに対する学習中の損失の変化の仕方によって4つのグループに分類しています。具体的にはTinyllama-1Bという事前学習済みのLLMをOpenWebMathという数学の文章のデータセットを用いて継続事前学習を行った際の、学習の進行具合(学習に使ったトークン数の推移)によって各トークンごとの損失の変化を調べています。
それによると、全体の26%のトークンは学習が進むにつれ損失が減少し(H→L)、大部分の51%は常に低い損失(L→L)を示したそうです。H→Lのトークンは継続事前学習においてモデルが学習出来たトークン、L→Lのトークンはすでに事前学習の段階でモデルが学習済みのトークンと捉えることが出来ます。一方11%のトークンは損失が下がらず(H→H)、さらに12%のトークンは逆に損失が上がってしまう(L→H)傾向が見られたそうです。
さらに損失の変動(ブレ具合)にも注目がされていて、L→LおよびH→Hのトークンは学習が経過しても損失が収束する傾向が見られないことが示されています。モデル全体の学習を考慮した時にこれらのトークン、つまりL→LとH→Hの傾向を示すトークンが損失関数の最適化の妨げになることが予想されます。
不要なトークンを学習中に除くSelective Language Modelingというテクニック
では学習に悪影響を及ぼしかねないトークンをどうやって除外するのでしょうか?論文で提案されているのが"Selective Language Modeling"(SLM)というテクニックです。
このテクニックでは最初に学習させたいモデルとは別に、同じ構造のモデルを高品質なデータで学習させた"Reference Model"を構築します。
Reference Modelを学習したら、継続事前学習用のデータに含まれるトークンに対し、次の計算式で定義されるReference Lossを計算します。
ここでは 番目のトークンより前のすべてのトークンを表しています。そして継続事前学習の際には、学習対象のモデルによる損失 とReference Loss によって定まる以下のスコア
に基づいて損失計算に使用するトークンを選択します。このスコアはexcess lossと呼ばれています。学習時、バッチ内のトークンに対してexcess lossを計算し、excess lossの上位 %が損失計算に使用されます。 はハイパーパラメータになりますが、論文の検証によると60%付近の値を設定するとモデルの精度が良い結果になるようです。
excess lossの意味合いとしては、学習対象のモデルによる損失が高く、かつReference Modelによる損失が低い場合に高い値になることから、高品質なデータで学習したReference Modelが学習済み(つまり高品質なデータセットと同分布のトークン)で、かつ学習対象のモデルが未学習のトークンが優先的に選択されるように設定されたスコアであることが伺えます。
SLMによる効果
全てのトークンを用いて継続事前学習(Continual Pretraining: CT)とSLMによるCTの場合でどのような改善が見られたのか、論文では様々な視点から評価をしています。そのうち、数学のデータセットで継続事前学習をしたあとのFew-shot CoT(Chain of Thought) reasoningの精度検証の結果を転載します。
モデルは2タイプで検証されていて、1つがTinyllama-1B, もう一つがMistral-7Bです。注目すべき点はTynyllama-CTとRHO-1-MATH、Mistral-CTとRHO-1-MATHの精度の向上と使用トークン数の削減です。いずれも6Bまたは4.5Bほど学習時のトークンを削減しつつ、かつ全てのベンチマークにおいて精度が向上していることが分かります。この結果を見て、ここまで精度が上がるものなんだ、とちょっとびっくりしました。トークン単位でのデータの選別が、いかに重要な前処理であるかが分かる結果だと思います。
学習の効率化、という観点では次の結果も興味深いです。学習の推移とFew shot reasoningタスクの精度の推移を、通常のCTとSLMでの比較が行われた結果です。同じトークン数を処理しているにも関わらず、SLMの方がはるかに速いスピードで精度の向上が実現出来ていることが分かります。
まとめ
今回は事前学習時にデータセット内の全てのトークンを使うのではなく、Reference Modelによる損失を加味したexcess lossによってトークンをフィルタリングすることで、効率的にモデルの学習が行えることを示した論文を読み、内容をまとめてみました。
この論文、まもなく開催される機械学習の国際会議"NeurIPS 2024"のAccepted Papersの1つです。実は私もこれからNeurIPS 2024に現地参加する予定です!現地参加レポートをまた別の記事でまとめたいと思いますので、お楽しみにお待ちください!