Sarashina-Embedding-v1-1B: 日本語LLMをベースにしたテキスト埋め込み(2/2)~発展編~

基本編はこちら→Sarashina-Embedding-v1-1B: 日本語LLMをベースにしたテキスト埋め込み(1/2)~基本編~

TL;DR

  • 10億パラメータの日本語言語モデルであるSarashina2.1-1Bを事前学習モデルとした日本語テキスト埋め込みモデルを学習。
  • 弱教師あり学習と教師あり学習の二段階の対照学習。
  • JMTEBベンチマークで最高水準のスコアを達成。特にRetrieval、Classification等でハイスコア。
  • 弱教師あり学習と教師あり学習の二段階学習によってモデル性能が向上。
  • 事前学習のトークン数を増やすと後段のテキスト埋め込みモデルの性能が向上。
  • 弱教師あり学習のデータが増えると検索などのタスクの性能が向上。
  • 公開ページ: https://huggingface.co/sbintuitions/sarashina-embedding-v1-1b

目次

概要

こんにちは、RAGコア構築チームの福地です。

Sarashina-Embedding-v1-1Bという日本語テキスト埋め込みモデルを開発・公開しました。このブログでは、学習方法や性能向上の要因など少し発展的な内容について解説します。

公開したモデルがどんなモデルなのか、どんな性能なのかについてはSarashina-Embedding-v1-1B: 日本語LLMをベースにしたテキスト埋め込み(1/2)~基本編~をご覧ください。

Sarashina-Embedding-v1-1Bの学習方法

Sarashina-Embedding-v1-1Bは対照学習という埋め込みモデルの学習によく使われる手法を二段階に分けて行うことで構築しました。

対照学習

対照学習(contrastive learning)は、テキストや画像の埋め込みモデルの学習に使われる手法の一つです。対照学習の基本的な考え方は、「似たものは潜在空間上で近くに、違うものは潜在空間上で遠くに配置する」というものです。例えば、似た意味のテキスト同士(例:"「彼は非凡なプログラマーだ」と「彼は良い計算機科学者です」") に対応するベクトル同士を近づけ、同時に意味的に似ていないペア(例: "「彼は非凡なプログラマーだ」と「ソフトバンク株式会社は竹芝にある」")のベクトル同士を遠ざけるような学習を行います。

Sarashina-Embedding-v1-1Bも対照学習によって学習されています。

Sarashina-Embedding-v1-1Bの学習では、対照学習を二段階に分けて行うことで、広範なドメインにおける高い適応性と、検索や意味的類似度タスクでの優れた性能を実現しています。以下に、その二段階学習の詳細を説明します。

図1: 二段階対照学習の概要

Stage 1: 弱教師あり学習

第一段階では、弱教師あり学習(weakly-supervised-learning、WSL)によりモデルを対照学習で訓練します。具体的には、入力テキストとそれに意味的に関連するテキスト(例えば、質問応答ペアや、ブログのタイトルと本文のペアなど)を学習データとして用います。弱教師あり学習では、これらのペアは必ずしも完璧に正確なものである必要はなく、多少のノイズや不整合が含まれるデータを許容する(ゆえに教師と呼ばれます)ことにより学習データセットの大規模化・多様化が可能になります。

Sarashina-Embedding-v1-1Bの弱教師あり学習では、モデルは入力テキストと対応するテキストの間の類似度や関連性を学習し、テキストの意味を捉える能力を向上させます。具体的には、Sarashina 2.1-1Bを重みの初期値として、意味的に関連性のあるテキストのペアを入力として与え、そのペア間の意味的な関連性を最大化するようにモデルを訓練します。

幅広いテキストに対して適切な埋め込みベクトルを出力できるように、独自のWebクロールデータと公開されているオープンデータを組み合わせ、弱教師データセットを構築しました。以下に、弱教師データの構成を示します。1

データセット テキストペア数
Auto Wiki QA/NLI 50,521,135
独自Webクロールデータ 47,370,649
MQA 12,941,472
llm-japanese-dataset 9,074,340
Wikipedia 5,555,212
独自のQuiz dataset 988,478
Natural Questions2 132,796
JSQuAD 62,859
SNOW(T15+T23) 62,758
JaQuAD 31,746
MKQA 3,318
total 126,744,763

この第一段階の訓練により、Sarashina-Embedding-v1-1Bは入力テキストの意味的な特徴を捉えるための基本的な能力を獲得しました。しかし、これだけではまだ不十分であり、次のステップで、より正確な意味理解を可能にするために教師あり学習を行います。

Stage 2: 教師あり学習

Sarashina-Embedding-v1-1Bの更なる性能向上を目指し、第二段階では教師ありデータによるファインチューニング(supervised fine-tuning: SFT)を行います。学習方法は、弱教師あり学習と同じく対照学習です。この段階では、ラベル付きの4種類のデータでモデルをさらに訓練します。これにより、Sarashina-Embedding-v1-1Bはより高品質なテキスト埋め込みを生成し、クエリとドキュメントの一致度や、意味的類似度の評価において優れた性能を発揮するようになります。

教師あり学習に使用するデータセットは、検索や意味的類似度タスクの精度向上に効果が期待できできそうな高品質データセットで構成されています。日本語に特化した検索データセットが少ないため、補完的に英語のデータセットも取り入れています。選定されたデータセットには以下のようなものがあります。

  1. JSNLI: 約14万文のペアからなる自然言語推論データセットです。
  2. NU-MNLI: 名古屋大学が提供する、マルチジャンル自然言語推論データセットの日本語翻訳バージョンであり、約6万8千ペアが含まれます。
  3. Mr. TyDi (Japanese subset): 多言語データセットTyDiから抽出した日本語部分。3,697ペアが含まれます。
  4. Natural Questions3 (Sampled): 英語検索データセットですが、一部のサンプルを使用しました。高品質なデータであるため、日本語検索能力の向上に寄与するのではという期待から一部のサンプルを使用しました。

これらのデータセットは、入力テキスト(検索であれば検索クエリ)とその正解テキスト(検索であれば関連文書)、hard negative(正解と似ているが実は不正解なテキスト。例えば「梅干しの発祥はどこ?」というクエリに対して「『家伝秘法調合録』-黒田玄仙には、「梅をおろして汁をとり、瀬戸物に入れ天日で干す」といった内容が記されており、現在の菓子としてのし梅とほぼ製法も変わらないことからその原型とみなされている。」という文がhard negativeとして挙げられます。)の3つ組で構成されています。特に、hard negativeを用いることで、「表面上似ているが実は意味が異なるペア」などの難しい事例を識別できるような埋め込みを構成できるようになり、より高品質な埋め込みを生成できるようになります。

学習時の工夫

安定かつ高効率に学習を実行するため、いくつかの工夫をしています。まず、対照学習においては大きなバッチサイズが性能に寄与するという先行研究の知見を踏まえて、stage 1の弱教師あり学習ではバッチサイズを32,768という巨大な値に設定しています。 この大規模なバッチサイズを実現するために、複数GPUを使用してバッチを分散し、テンソルをall gatherすることで効率的に対照学習のlossを計算しています。また、GradCacheなども活用してバッチサイズを拡張しています。

JMTEBによる性能評価

Sarashina-Embedding-v1-1B: 言語モデルをベースにした日本語テキスト埋め込み(1)~基本編~でも精度評価を載せましたが、この記事でもJMTEBによる評価結果を再掲します。

詳しくは、Sarashina-Embedding-v1-1B: 言語モデルをベースにした日本語テキスト埋め込み(1)~基本編~をご覧ください。

Model Model Parameters Max Tokens Avg. Retrieval STS Classification Reranking Clustering PairClassification
OpenAI/text-embedding-3-large4 unknown 8191 74.05 74.48 82.52 77.58 93.58 53.32 62.35
cl-nagoya/ruri-large 337M 512 73.31 73.02 83.13 77.43 92.99 51.82 62.29
pkshatech/GLuCoSE-base-ja-v2 133M 512 72.23 73.36 82.96 74.21 93.01 48.65 62.37
pkshatech/RoSEtta-base-ja 190M 1024 72.04 73.21 81.39 72.41 92.69 53.23 61.74
intfloat/multilingual-e5-large 560M 512 70.90 70.98 79.70 72.89 92.96 51.24 62.15
Sarashina-Embedding-v1-1B(This model) 1.22B 8192 75.50 77.61 82.71 78.37 93.74 53.86 62.00

Sarashina-Embedding-v1-1Bは、JMTEBの16データセットの平均スコアでOpenAI/text-embedding-3-largeを上回り、全体で75.50を達成しました。特に、Retrieval(検索)タスク、Classification(文書分類)タスク、Rerankingタスク、Clustering(文書クラスタリング)タスクで最高のスコアを記録しました。

Ablation Study

Sarashina-Embedding-v1-1Bには、

  • 弱教師あり学習(WSL)と教師あり学習(SFT)の二段階の対照学習を行った
  • 巨大な事前学習コーパスで学習した言語モデルを事前学習モデルにした
  • 弱教師あり学習で比較的大規模なテキストペアデータを学習した

という3点のポイントがあります。このセクションでは、この3点のポイントの効果をAblation Studyとして実験的に確かめます。

二段階対照学習の影響

Sarashina-Embedding-v1-1Bでは、弱教師あり学習(WSL)と教師あり学習(SFT)の二段階の対照学習を行っていますが、二段階の学習方法は本当にテキスト埋め込み性能向上に寄与するのでしょうか?

E5GTEなどのEncoderベースの先行手法では、10億以上のテキストペアを用いて弱教師あり学習を行っている一方、Mistral-E5などのDecoderベースの手法では弱教師あり学習を行わずに教師あり学習のみを行っています。

Decoderモデルに対する二段階対照学習の効果を確かめるために、以下の3つの学習方法のモデルをJMTEBで評価しました。

  • Sarashina-Embedding-v1-1B: 今回公開したモデル。弱教師あり学習(WSL)と教師あり学習(SFT)の二段階の対照学習を行った
  • Sarashina-Embedding-v1-1B-WSL: 弱教師あり学習(WSL)のみを行ったモデル。教師あり学習(SFT)はしていない。
  • Sarashina-Embedding-v1-1B-SFT: 弱教師あり学習(WSL)は行わず、教師あり学習(SFT)のみを行ったモデル。

3つのモデルのJMTEBのスコアを下記の表に示します。二段階学習を行ったモデル(Sarashina-Embedding-v1-1B)が、弱教師あり学習(WSL)のみのモデル、教師あり学習(SFT)のみを行ったモデルよりも総合的に高性能であることが確認できます。特に、Retrieval(検索)タスクでの二段階学習の効果が大きいことも観察されました。

Model WSL SFT Ave. Retrieval STS Classification Reranking Clustering PairClassification
Sarashina-Embedding-v1-1B(This Model) + + 75.50 77.61 82.71 78.37 93.74 53.86 62.00
Sarashina-Embedding-v1-1B-WSL + - 72.55 73.14 77.68 77.71 93.30 50.22 62.09
Sarashina-Embedding-v1-1B-SFT - + 71.91 69.76 79.73 77.55 92.45 53.69 62.42

ベースモデルの学習トークン数による影響

今回、ベースモデルとしてSarashina2.1-1Bを用いました。このモデルは、日本語・英語・コード混合コーパスで10Tトークン分学習し、次に日本語重視コーパスで1Tトークン分学習、最後に100Bの高品質コーパスで学習を行っています。つまり、合計で11.1Tトークン分学習をしています。

このモデルは非常に巨大なテキストで学習されていますが、このモデルをテキスト埋め込みモデルのベースにした場合、学習トークン数に対してテキスト埋め込みの性能はスケールするのでしょうか? そこで、実験的に事前学習の学習トークン数とJMETBのベンチマーク結果との間の関係性を調べました。

Sarashina2.1-1Bの学習途中の3つのチェックポイント(1T、5T、10Tトークン)を事前学習モデルにして、Sarashina-Embedding-v1-1Bと同じレシピで二段階対照学習を行いました。学習後の埋め込みモデルをJMTEBで評価を行いました。

結果を図2、図3に示します。横軸が学習トークン数で、縦軸がJMTEBのスコアです。横軸が対数軸になっていることに注意してください。図の一番右側の点は、Sarashina2.1-1BをベースモデルにしたSarashina-Embedding-v1-1Bに対応しています。x軸の10~11.1Tトークンまでの日本語重視コーパスで学習した区間をグレーに着色しています。

まず、JMTEBの16データセットの平均を見ると、事前学習の学習トークン数が増えれば増えるほどスコアが向上していることがわかります。特に、10Tトークンで学習した後、さらに日本語重視コーパスと高品質コーパスで学習を行うことでスコアが大きく上昇していることが確認できます。

図2: 事前学習トークン数とJMTEB平均スコアの関係

次に、JMTEBのタスクごとの評価を見ていきましょう(図3)。6つのタスクの中で、Retrieval(検索)、Classification(分類)、Reranking(リランキング)の3つのタスクは、事前学習トークン数を増やせば増やすほどスコアが向上しています。他のタスクでは、事前学習トークン数の増加に対してスコアの明確な上昇傾向はありませんでした。

Retrieval(検索)/Reranking(リランキング)のタスクは、クエリをもとに関連度の高いテキストを検索/並べ替えするタスクです。事前学習したトークン数が増えることで、モデルに内在する事物・事柄の関係等の知識が増えていくので、モデルの検索/並べ替え能力も向上すると考えられます。Retrieval(検索)タスクでは、日本語重視コーパスと高品質コーパスによる継続学習の効果がはっきりと確認できます。高品質な日本語データには有用な知識が多く含まれているため、検索タスクのスコア向上に効果があったのではないかと考えられます。

図3: 事前学習トークン数とJMTEBの各タスクのスコアの関係

弱教師あり学習のデータ規模の影響

Sarashina-Embedding-v1-1Bは、二段階対照学習により学習されています。Stage 1の弱教師あり学習のデータ規模が、性能にどの程度寄与したのかを実験的に調べました。

実験では、弱教師あり学習のデータサンプル数の変化が、二段階学習後のJMTEBの平均スコアにどのように影響するかを調査しました。stage 1の複数のチェックポイントから、それぞれでstage 2の教師あり学習を行いました。それぞれのチェックポイントのstep数と学習ペアのサンプル数は、400 steps = 13M samples、800 steps = 26M samples、1600 steps = 53M samples、3200 steps = 105M samples、3800 steps = 125M samples に対応しています。

図4は、stage 1の学習サンプル数とJMTEBの平均スコアの関係を示しています。横軸は学習サンプル数、縦軸はJMTEBの平均スコアです。横軸が対数軸であることに注意してください。弱教師あり学習のサンプル数が26Mのポイントを除けば、学習サンプル数の増加に伴いJMTEBの平均スコアが上昇する傾向が見られます。

図4: 弱教師あり学習とJMTEBの各タスクのスコアの関係

もう少し細かく分析するために、図5で弱教師あり学習のサンプル数と各タスクのスコアの関係を見ていきましょう。Retrieval(検索)とSTS(意味的類似性)、Reranking(リランキング)のタスクで、弱教師あり学習のサンプル数を増やすことで各タスクのスコアが上がっていきそうに見えます。逆にClusteringタスクでは、弱教師あり学習のサンプル数によってスコアが大きく上下しています。特に、サンプル数が26Mで Clusteringタスクのスコアが大きくなっており、この影響によって全体スコアも学習サンプル数26Mで大きく伸びているのではないかと考えられます。

図5: 弱教師あり学習のサンプル数とJMTEBの各タスクのスコアの関係

まとめ

まず、目論見通り二段階学習がJMTEBスコアの向上に大きく貢献していることが確認できました。

次に、事前学習トークン数を増やすことで、特に検索/リランキングタスクにおいてモデルのスコアが向上する傾向にあることがわかりました。これは、事前学習トークン数の増加により、モデルがより多くの知識を獲得し、文と文の関連性を認識する能力が向上したためだと考えられます。また、日本語重視コーパスによるカリキュラム学習も、検索タスクのスコア向上に効果的でした。

弱教師あり学習のデータ規模に関しては、学習サンプル数の増加に伴い、JMTEBの平均スコアが上昇する傾向が見られました。特に、検索、STS、リランキングタスクにおいて、弱教師あり学習のサンプル数を増やすことでスコアが向上しました。

Sarashina-Embedding-v1-1Bの性能向上には、二段階学習、事前学習トークン数の増加と弱教師あり学習のデータ規模の拡大が寄与しており、特に検索タスクやリランキングタスクにおいてその効果が顕著であることがわかりました。テキスト埋め込みモデルの対照学習という文脈で、一部のタスクに限定すれば学習データのスケーリング則のようなものが成立するのかもしれません。

課題と展望

Sarashina-Embedding-v1-1Bの開発においては、まだいくつかの課題があります。

まず、Matryoshka Representation Learning (MRL)の適用が挙げられます。MRLはOpenAI/text-embedding-3-largeでも取り入れられている表現学習手法の一つです。MRLを行うことで、出力された埋め込み表現を次元削減した際の性能劣化を抑えることができるようになり、ベクトル検索のコスト低減やメモリ消費量の削減が期待できます。今後、Sarashina-Embedding-v1-1BもMRLによる学習を検討する予定です。

InstructionやPrefixへの対応も課題です。Mistral-E5Geckoなど英語圏のSOTAモデルは、InstructionやPrefixを利用することでタスクごとに適した埋め込み表現を作成できるように訓練し、精度向上を実現しています。Sarashina Embeddingの開発においても、学習データセットの多様化や合成データの活用により、InstructionやPrefixによる埋め込み表現の制御と性能向上を図りたいと思っています。

これらの課題に取り組むため、一緒に研究していただけるリサーチャー、エンジニア、インターンの方を募集しています。ご興味のある方は、ぜひご応募ください。

ライセンス

Sarashina-Embedding-v1-1BはSarashina Model NonCommercial License Agreementでライセンスされており、商用利用には制限があります。

本稿をご覧いただいた方の中で、事業やプロジェクトに Sarashina-Embedding-v1-1B を使ってみたいという場合、SB Intuitions にお問い合わせください。

<コンタクトページ>



注釈


  1. 先行研究(Wang et al.、2022など)と異なる点として、本モデルの弱教師あり学習のデータセットに、llm-japanese-dataset、Natural Questions、JSQuADなどの人手でアノテーション済みのデータが含まれている点がある。今回のデータセットの選択においては、日本語特化のテキスト埋め込みモデルの教師あり学習の学習データとしては、タスク・言語の関係で関連性が強くないと推定される人手アノテーションデータは弱教師あり学習のデータとして使うことにした。
  2. 英語の検索データセット。Sarashina-Embedding-v1-1Bは事前学習時に英語データセットを学習しており、英語の言語モデルとしても比較的性能が高いので、その英語性能を活かすためNatural Questionsをデータセットに加えた。
  3. 教師あり学習でも、日本語と英語のバランスを取るためにNatural Questions(NQ)をダウンサンプリングしてデータセットに加えている。NQのような英語データについては、かなりad hokな取り扱いをしているため最適な戦略とは限らないと考えている。今後、英語データと日本語データの最適な混合戦略を検討していきたい。
  4. 2024/4/23に社内でベンチマークを行った。