CCCMKホールディングス TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

ブログタイトル

Hugging FaceのDatasetsとTransformersで作ったテキスト分類モデルをSHAPで可視化してみました。

こんにちは、CCCMKホールディングス技術開発の三浦です。2月も終わり、明日から3月になります。つい先日、お正月を過ごしていた気がするのですが、あっという間にそれから2か月も経ってしまいました。時間が早く過ぎていくことは決して悪いことではないと思いますが、もう少しゆっくりでもいいのになぁとも感じます。色々夢中になっているとあっという間に時間が過ぎてしまうので、時々手を止めてぼんやりする時間をもっと作ってみようかな、と考えています。

最近、説明可能なAI(Explainable AI;XAI)について調べており、その中でSHAPというXAIのフレームワークを知り、前回画像分類モデルに適用してみた話をご紹介しました。

techblog.cccmk.co.jp

今回は自然言語モデルにSHAPを適用してみたいと思い、その方法について調べてみました。モデルはHugging FaceのTransformersの事前学習済みモデルを使わせて頂きました。Transformersでは色々な学習済みのモデルをダウンロードして使うことが出来ますが、そのモデルをファインチューニングして別のタスクに利用しようとすると、モデルだけでなく学習用のデータも必要になります。Hugging FaceではDatasetsというライブラリを通じて様々なデータセットをダウンロードして使うことが出来るようで、今回はそちらの方法についても調べてみました。

Hugging Face Datasets

DatasetsはHugging Face社が公開しているライブラリです。

huggingface.co

このライブラリを通じ、Hugging Face Hubでシェアされている各種データセットを簡単にダウンロードして使うことが出来ます。今回はテキスト分類モデルを構築したいと考え、それに合うデータを探してみました。以下のデータセットは様々な言語のテキストに対して3つのラベル"positive"("肯定"), "neutral"("中立"), "negative"("否定")が付与されたデータセットになっており、日本語のデータも含まれています。今回の用途にはぴったりなので、こちらのデータセットを使わせて頂きました。

huggingface.co

事前学習済みモデル

事前学習済みのモデルは同じくHugging FaceのTransformersのモデルを使わせて頂きました。

huggingface.co

データセットの用意

それではDatasetsを使ってデータセットの準備をします。

ダウンロード

今回使用するデータセットはさらにSubsetという形で言語ごとに分割されています。ダウンロードするときはデータセット名をpathで、Subset名をnameで指定してload_dataset関数を使用します。

from datasets import load_dataset

dataset = load_dataset(
    path='tyqiangz/multilingual-sentiments',
    name='japanese'
)

datasetを表示すると、ダウンロードしたデータセットはさらに"train", "validation", "test"に分かれており、それぞれ"text", "source", "label"というフィールドを持っていること、データの件数を確認することが出来ます。

DatasetDict({
    train: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 120000
    })
    validation: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 3000
    })
})

"label"には0, 1, 2の値が格納されていて、0が"肯定"、1が"中立"、2が"否定"を表しています。

加工する

ダウンロードしたデータセットをモデルに入力できるように加工していきます。トークン化の処理を行い、データのサンプリングを行っていきます。こちらはHugging Faceのチュートリアルを参考に組みました。

huggingface.co

# tokenizerのダウンロード
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_path = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = AutoTokenizer.from_pretrained(model_path)

# トークン化を行う関数
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True,max_length=64)

# datasetsの加工
tokenized_dataset = dataset.map(tokenize_function,batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['text','source']).rename_column('label','labels')
tokenized_dataset.set_format('torch')

train_shuffled = tokenized_dataset['train'].shuffle(seed=12).flatten_indices()
train_dataset = train_shuffled.select(range(8000))
valid_dataset = train_shuffled.select(range(8000,10000))

これで学習用に8,000件、検証用に2,000件のトークン化済みのデータセットを得ることが出来、データセットの準備は完了です。とてもシンプルなので、"ちょっと何か試してみたい"といった時にとても便利だと思いました!

PyTorch-LightningのLightningDataModule, LightningModule, Trainerの準備

今回もPyTorch-Lightningを使ってモデルを学習させました。PyTorch-Lightningのいつものメンバー、LightningDataModule, LightningModule, Trainerの3人の登場です。

まずはデータセットの管理供給を担うLightningDataModuleです。

# LightningDataModule

from pytorch_lightning import LightningDataModule,
from torch.utils.data import DataLoader

class SentenceData(LightningDataModule):
  def __init__(self,batch_size):
    super().__init__()
    self.batch_size = batch_size
  
  def setup(self, stage):
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
  
  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
  
  def val_dataloader(self):
    return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)

次にモデルの管理とその振る舞いを制御するLightningModuleです。

# LightningModule

from pytorch_lightning import LightningModule
from torch.optim import Adam

class SentenceModel(LightningModule):
  def __init__(self, model, lr=0.02):
    super().__init__()
    self.lr = lr
    self.model = model

  def forward(self, x):
    self.model.eval()
    return self.model(**x)

  def training_step(self,batch, batch_idx):
    output = self.model(**batch)
    loss = output.loss
    self.log('train_loss',loss)
    return loss
  
  def validation_step(self, batch, batch_idx):
    output = self.model(**batch)
    loss = output.loss
    self.log('val_loss',loss)
  
  def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.lr)

そして学習処理を担当するTrainerです。今回はGoogle colaboratoryで動かしており、学習のログにはTensorBoardを使用しますEarlyStoppingTensorBoardLoggerを使って過学習を避け学習時間短縮を図るのと、学習結果の記録付けを行いました。

  
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Callback/Logger
es_callback = EarlyStopping(monitor='val_loss',mode='min')
tf_logger = TensorBoardLogger('/path/to/logger', 'sentence_clsf_log')

# Trainer

# 事前学習済みモデルのダウンロード
model = AutoModelForSequenceClassification.from_pretrained(model_path,num_labels=3)
model.train()
plmodel = SentenceModel(model)
pldata = SentenceData(batch_size=32)
trainer = Trainer(
    max_epochs=3,
    accelerator='gpu',
    devices=1,
    default_root_dir='/path/to/logger',
    callbacks=[es_callback],
    logger=[tf_logger]
)

また、学習率の探索をテストの意味も込めて実施してみました。

lr_finder = trainer.tuner.lr_find(model=plmodel,datamodule=pldata)
new_lr = lr_finder.suggestion()

今回は0.0001という学習率が提示されました。

学習の実行

あとは学習を実行します。

plmodel = SentenceModel(model,lr=new_lr)
trainer.fit(model=plmodel,datamodule=pldata)

動作確認

自分で入力したテキストに対して、どのようにモデルが判定するか、試してみます。

import torch
text = 'これは使いやすくて買ってよかったです'

tokenized_text = tokenizer(
    text,
    return_tensors='pt',
    padding='max_length', 
    truncation=True,
    max_length=64
)

trained_model = plmodel.model
trained_model.eval()

with torch.no_grad():
  print(plmodel.forward(tokenized_text).logits[0].tolist())

結果は以下のようになり、0番目のクラスに対するスコアが最も高くなりました。このデータセットでは0は"肯定"なので、結果は正しいです。

[2.1547601222991943, -0.27284368872642517, -2.4547019004821777]

今度は"これは使いにくく、買って失敗しました"というテキストを入力すると、

[-2.044304370880127, 0.2790968716144562, 1.871802568435669]

となり、(0から数えて)2番目のクラスのスコアが最も高くなりました。2は"否定"なので、これも正しい結果です。

Pipelinesを使った推論処理の簡略化

Transformersのモデルを使ってテキストに対して推論を行う時、トークン化してモデルに入力、といった2つのステップが必要ですが、この処理はTransformersのPipelinesを使うことで以下のように簡単に書くことが出来ます。

from transformers import pipeline

pipe = pipeline('sentiment-analysis',model=trained_model, tokenizer=tokenizer)
print(pipe('これは使いやすくて買ってよかったです!'))

結果です。

[{'label': 'LABEL_0', 'score': 0.936873733997345}]

pipelineに渡す最初のパラメータは対応するタスクを表す文字列を指定します。これは任意の文字列ではなく、Transformersが対応したタスクの中から選び指定する必要があります。

最後にSHAPを使って入力テキストのどこにモデルが注目して結果を出力したのかを可視化してみます。これもpipelineを使うことで簡単に実行することが出来ます。

SHAPによる可視化

前回は画像に対してでしたが、今回はテキストに対してSHAPを実行します。前回はExplainerのパラメータに自作の関数を指定しましたが、今回は先ほどのpipelineを指定します。

# SHAP
import shap
explainer = shap.Explainer(pipe)

shap_value = explainer(['これは使いやすくて買ってよかったです!'])
# 0("肯定")スコアに対する出力の根拠の可視化
shap.plots.text(shap_value[0,:,0])

すると以下のような図が表示されます。

推計に対する各トークンの寄与度が表示されます。

これは肯定的な内容のテキストに対するラベル0("肯定)スコアへの各トークンの寄与を表しています。このテキストにおいては"!"がもっともモデルの"肯定"という判断にプラスに影響をしているようです。

また、いくつかのテキストをまとめて入力して評価することも出来ます。"肯定", "中立", "否定"を意識して自分で書いてみたテキストを入力してみました。

shap_value = explainer(
    [
        '小腹が空いたときのおやつように買いました!歯ごたえがあって、ちょっと食べただけでお腹が膨れるので大満足です!',
        '味はとても良く気に入っているが、量が少なくてコスパはあまりよくないと思う。',
        'ピリ辛とのことで試してみたのですが、私には辛すぎて食べられませんでした・・・。'
    ]
)
shap.plots.text(shap_value[:,:,:])

複数のテキストに対するSHAPの結果

各ラベルを切り替えて、その出力に対するモデルが注目したトークンを見ることが出来ます。一番下のテキストは"否定"のつもりで書いた文章なのですが、モデルは"LABEL_1":"中立"と判断しています。上の図では"LABEL_2":"否定"のスコアに対する各トークンの寄与を表示していますが、これによれば"すぎ"というトークンが"否定"の出力にプラスに働いているようで、これは納得感があります。一方"辛"というトークンは"否定"の予測にはマイナスに働いているようです。"ピリ辛", "辛すぎ"の2つの表現で"辛"を使ってみたのですが、この辺りの意味のとらえ方は確かに難しいのかもしれません。

まとめ

ということで、今回はHugging FaceのDatasetsとTransformersでテキスト分類モデルを作り、SHAPでモデルの出力を可視化する方法を調べてみた話をまとめてみました。SHAPをテキスト分類モデルで試せたことも良かったのですが、今回は特にHugging FaceのDatasetsやTransformersのPipelinesの使い方を知ることが出来たのがとても大きな収穫になりました。これからも活用していきたいと思います!

ちなみに・・・

今回モデルを学習している中で、"モデルのlossが全然下がらない", "モデルにどんなテキストを入力しても同じ出力になる"という現象が発生しました。色々試し、最終的には上手く学習が出来るようになったのですが、恐らく以下の2点のいずれかが問題だったのでは・・・と考えています。自分への備忘録も兼ね、こちらに残しておきます。

  • バッチサイズが小さ過ぎる?最初は8で実行していたのですが、32まで増やしてみました。その分トークン化の時にmax_lengthを調整してデータが大きくなり過ぎないようにしました。
  • PyTorchのtrainモードとevalモードがなんらかの原因で入れ替わってしまった?model.eval()のように実行するとevalモードになり、ネットワーク上のDropoutなどのいくつかのレイヤが実行されなくなります。なぜかevalモードで学習処理を行ってしまい、上手く学習が進まなかったのかもしれません。