CCCMKホールディングス TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

ブログタイトル

Hugging FaceとPyTorch LightningでImageCaptioningモデルを作ってみました!

こんにちは、技術開発の三浦です。

朝早くに外に出ると、なんとなく日差しが軽くなった気がします。風も少し冷たくなって、だんだん夏も終わりに近づいている気がします。

今回はAIコミュニティHugging Faceで公開されている事前学習済みモデルを使ってImageCaptioningモデルを組み、PyTorch Lightningでファインチューニングする方法を調べたので、ご紹介したいと思います。

ImageCaptioningとPyTorch Lightning

ImageCaptioningは入力画像に対してそれを説明するテキストを生成する機械学習のタスクです。以前こちらの記事でもImageCaptioningモデルを作った話をご紹介しました。

techblog.cccmk.co.jp

この記事ではTransformerの構造を持つImageCaptioningモデルを作ったのですが、画像処理の部分は畳み込みニューラルネットワーク(CNN)の構造を持つモデルになっていました。CNNは事前学習済みのResNetを使用していた一方で、テキストを生成するTransformerの部分は事前学習せずに1から学習させていました。

最近よく見ているAIコミュニティ「Hugging Face」では様々な事前学習済みのTransformerベースのモデルが公開されています。Hugging Faceで公開されている事前学習済みモデルを使うことで、複雑な構造を持つImageCaptioningモデルを簡単に組み立てることが出来るのでは、と考え、試してみようと思いました。

Hugging FaceのモデルはPyTorch対応の物が多いので、こちらも最近使っているPyTorch LightningというPyTorchを簡潔に書くことが出来る機械学習フレームワークを使って、どうやってImageCaptioningモデルを学習させることが出来るのか調べてみました。

まとめると、Hugging Faceで公開されている事前学習済みのモデルをベースにImageCaptioningモデルを組み立て、組み立てたモデルをPyTorch Lightningで学習させる、ということを試しました。どのようなことを行ったのかについて、ご紹介していこうと思います。

torch.utils.data.Datasetによるデータの準備

まずは学習に使用するデータを用意します。今回はCOCO Datasetのうち、2017の学習データを使用しました。

COCO - Common Objects in Context

COCO Datasetは1つの画像に対してキャプションラベルだけでなく、物体の領域を表すセグメンテーションラベルも提供されていますが、今回はキャプションラベルのみを使用しました。また学習用の画像は10万枚以上あるのですが、全部を使用するととても時間がかかります。今回はどうやってモデルを組んで学習させられるのかに興味があったため、使用する画像は1,000枚、キャプションは5,000件に絞り、処理をすぐに完了出来るようにしました。

JSON形式のラベルデータを、以下のようなテーブル形式のデータに加工しました。

テーブル形式で取りまとめました

このデータを元に、モデル学習用のデータをtorch.utils.data.Datasetクラスのオブジェクトとして表現し、PyTorchで利用出来るようにします。コードの例を以下に記載します。

class ImageCaptionDataset(Dataset):
  def __init__(self,dataset,tokenizer,feature_extractor,max_token_length):
    '''
    Parameters
    --------------------
    dataset: DataFrame
      キャプションと画像パスをカラムに持つDataFrame
    tokenizer: Tokenizer
      キャプションをトークン化するHugging Face TransformersのTokenizer
    feature_extractor: ViTFeatureExtractor
      画像をVisionTransformerに入力するための前処理
    max_token_length: int
      最大トークン数
    '''
    super().__init__()
    self.data = dataset
    self.data_size = dataset.shape[0]
    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer
    self.max_token_length = max_token_length

  def __getitem__(self, index):
    row = self.data.iloc[index]
    caption = row['caption']
    img_path = row['img_path']
    
    img = Image.open(img_path).convert('RGB')
    label = self.tokenizer(
        caption,
        return_tensors='pt',
        max_length=self.max_token_length, 
        truncation=True, 
        padding="max_length"
    ).input_ids[0]
    
    img_feature = self.feature_extractor(
        img,
        return_tensors='pt'
    ).pixel_values[0]
    
    return img_feature, label
  
  def __len__(self):
    return self.data_size

torch.utils.data.Datasetではデータセットからサンプルを提供する処理を__getitem__に書きます。ここではキャプションをHugging FaceのライブラリTransformersTokeniserでトークン化、画像をViTFeatureExtractorで前処理をかけた後、画像とラベルのペアでサンプルとして提供する処理を書きました。

torch.utils.data.Dataloaderpytorch_lightning.LightningDataModuleの実装

Datasetからサンプルをバッチで取得するtorch.utils.data.Dataloaderを、PyTorch Lightningの学習フェーズの適切なタイミングで提供するpytorch_lightning.LightningDataModuleを作ります。

class ImageCaptionDataModule(pl.LightningDataModule):
  def __init__(
        self, 
        train_data, 
        val_data, 
        tokenizer, 
        feature_extractor, 
        batch_size=32,
        max_token_length=64
    ):
    '''
    Parameters
    --------------------
    train_data: DataFrame
      キャプションと画像パスをカラムに持つDataFrame(学習用)
    val_data: DataFrame
      キャプションと画像パスをカラムに持つDataFrame(検証用)
    tokenizer: Tokenizer
      キャプションをトークン化するHugging Face TransformersのTokenizer
    feature_extractor: ViTFeatureExtractor
      画像をVisionTransformerに入力するための前処理
    batch_size: int
        提供するサンプル数
    max_token_length: int
      最大トークン数
    '''
    super().__init__()
    self.train_data = train_data
    self.val_data = val_data
    self.tokenizer = tokenizer
    self.feature_extractor = feature_extractor
    self.batch_size = batch_size
    self.max_token_length = max_token_length
  
  def setup(self, stage):
    self.train_dataset = ImageCaptionDataset(
              self.train_data,
              self.tokenizer,
              self.feature_extractor,
              self.max_token_length
    )
    self.val_dataset = ImageCaptionDataset(
              self.val_data,
              self.tokenizer,
              self.feature_extractor,
              self.max_token_length
    )
  
  def train_dataloader(self):
    dataloader = DataLoader(
        self.train_dataset, 
        batch_size=self.batch_size,
        num_workers=3,
        shuffle=True
)
    return dataloader
  
  def val_dataloader(self):
    dataloader = DataLoader(
        self.val_dataset, 
        batch_size=self.batch_size,
        num_workers=3,
        shuffle=True
  )
    return dataloader

pytorch_lightning.LightningModuleの実装

モデルや学習時の処理などを定義するpytorch_lightning.LightningModuleを実装します。

class ImageCaptionModelModule(pl.LightningModule):
  def __init__(self, model, tokenizer, feature_extractor, lr=0.0001):
    '''
    Parameters
    --------------------
    model: VisonEncoderDecoderModel
      Hugging Face TransformersのVisionEncoderDecoderModel
      モデルの本体
    tokenizer: Tokenizer
      キャプションをトークン化するHugging Face TransformersのTokenizer
    feature_extractor: ViTFeatureExtractor
      画像をVisionTransformerに入力するための前処理
    lr: float
        Adam Optimizerにセットするlearning rate

    '''
    super().__init__()
    self.model = model
    self.tokenizer = tokenizer
    self.feature_ext = feature_extractor
    self.model.config.pad_token_id = self.tokenizer.eos_token_id
    self.model.config.decoder_start_token_id = self.tokenizer.bos_token_id
    self.model.config.decoder = self.tokenizer.vocab_size
    self.lr = lr

  def training_step(self, batch, batch_idx):
    x, y = batch
    loss = self.model(pixel_values=x, labels=y).loss
    self.log('train_loss',loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
  
  def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = self.model(pixel_values=x, labels=y).loss
    self.log('val_loss',loss)

  def configure_optimizers(self):
      return torch.optim.Adam(self.parameters(), lr=self.lr)

Hugging Faceを活用することで、複雑なモデルも非常に簡潔に定義することが出来ます!

一点詰まったところがmodel.configを設定する箇所で、ここは使用するmodeltokenizerによって変更が必要になるところだと思います。たとえばあるtokenizerではpad_token_idが定義されているけど別のtokenizerでは定義されていない、といったことがあって、それを知らなかったため、別のプロジェクトで作ったモデルをそのまま使おうとして上手くいかなかった、といったことがありました。

modelやtokenizerなどのダウンロード

次に事前学習済みのモデルVisionEncoderDecoderModelTokenizer, ViTFeatureExtractorをHugging Faceからダウンロードします。VisionEncoderDecoderModelには画像をエンコードするモデル、テキストを生成する(デコードする)モデルをそれぞれ別に、事前学習済みのモデルから指定することが出来ます。今回はエンコーダとしてVision Transformer(ViT)を、デコーダとしてGPT2を指定しました。ここは色々なモデルを指定して精度を試してみると面白そうな部分です。

encoder_model_name = "google/vit-base-patch16-224-in21k"
decoder_model_name = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model_name,
    decoder_model_name
)

tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name)
tokenizer.pad_token = tokenizer.eos_token
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_model_name)

pytorch_lightning.Trainerのセッティング

モデルの学習を行うpytorch_lightning.Trainerのセッティングを行います。 学習状況はpl.loggers.TensorBoardLoggerでTensorBoardに出力するようにします。

#tensorboard logger
logger = pl.loggers.TensorBoardLogger(
    save_dir=tfboard_log_dir, 
    name=tfboard_name
)

#checkpoint callback
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath=checkpoint_path,
    filename='img_captions-{epoch:02d}-{val_loss:.2f}'
)

#earliystopping callback
earlystopping_callback = pl.callbacks.early_stopping.EarlyStopping(
    monitor="val_loss", 
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=30,
    accelerator='gpu', 
    devices=1,
    logger=logger,
    callbacks=[
        checkpoint_callback,
        earlystopping_callback
    ]
)
#TensorBoardの起動
%load_ext tensorboard
%tensorboard --logdir ./tf_logs/

学習の実行

最後にpytorch_lightning.LightningDataModulepytorch_lightning.LightningModuleのインスタンスを生成し、学習を実行します。

data_module = ImageCaptionDataModule(
    train_data,val_data,
    tokenizer,
    feature_extractor
)
caption_module = ImageCaptionModelModule(
    model, 
    tokenizer, 
    feature_extractor
)

trainer.fit(caption_module, datamodule=data_module)

学習したモデルは以下のようにして使うことが出来ます。

#test
from PIL import Image
import requests
import io

url = 'image-url'
img = Image.open(io.BytesIO(requests.get(url).content))

feature_ext = caption_module.feature_ext(img.convert('RGB'), return_tensors='pt').pixel_values
output = caption_module.model.generate(feature_ext, eos_token_id=tokenizer.eos_token_id, max_length=64)
display(caption_module.tokenizer.decode(output.tolist()[0],skip_special_tokens=True))

COCO Datasetの画像を入力してみました

滑走路は写っていないのに「airport runway」という単語が入っていますが、ちゃんとしたキャプションを生成することが出来ているようです!

まとめ

今回はImageCaptioningモデルをHugging Faceの事前学習済みモデルで組んで、PyTorch Lightningで学習する方法について調べたことをまとめてみました。ところどころ詰まるところはありましたが、短いコードで高性能なモデルを作ることが出来ました。これからも色々なモデルを作ってみたいと思います!