こんにちは、技術開発の三浦です。
朝早くに外に出ると、なんとなく日差しが軽くなった気がします。風も少し冷たくなって、だんだん夏も終わりに近づいている気がします。
今回はAIコミュニティHugging Faceで公開されている事前学習済みモデルを使ってImageCaptioningモデルを組み、PyTorch Lightningでファインチューニングする方法を調べたので、ご紹介したいと思います。
ImageCaptioningとPyTorch Lightning
ImageCaptioningは入力画像に対してそれを説明するテキストを生成する機械学習のタスクです。以前こちらの記事でもImageCaptioningモデルを作った話をご紹介しました。
この記事ではTransformerの構造を持つImageCaptioningモデルを作ったのですが、画像処理の部分は畳み込みニューラルネットワーク(CNN)の構造を持つモデルになっていました。CNNは事前学習済みのResNetを使用していた一方で、テキストを生成するTransformerの部分は事前学習せずに1から学習させていました。
最近よく見ているAIコミュニティ「Hugging Face」では様々な事前学習済みのTransformerベースのモデルが公開されています。Hugging Faceで公開されている事前学習済みモデルを使うことで、複雑な構造を持つImageCaptioningモデルを簡単に組み立てることが出来るのでは、と考え、試してみようと思いました。
Hugging FaceのモデルはPyTorch対応の物が多いので、こちらも最近使っているPyTorch LightningというPyTorchを簡潔に書くことが出来る機械学習フレームワークを使って、どうやってImageCaptioningモデルを学習させることが出来るのか調べてみました。
まとめると、Hugging Faceで公開されている事前学習済みのモデルをベースにImageCaptioningモデルを組み立て、組み立てたモデルをPyTorch Lightningで学習させる、ということを試しました。どのようなことを行ったのかについて、ご紹介していこうと思います。
torch.utils.data.Dataset
によるデータの準備
まずは学習に使用するデータを用意します。今回はCOCO Datasetのうち、2017の学習データを使用しました。
COCO - Common Objects in Context
COCO Datasetは1つの画像に対してキャプションラベルだけでなく、物体の領域を表すセグメンテーションラベルも提供されていますが、今回はキャプションラベルのみを使用しました。また学習用の画像は10万枚以上あるのですが、全部を使用するととても時間がかかります。今回はどうやってモデルを組んで学習させられるのかに興味があったため、使用する画像は1,000枚、キャプションは5,000件に絞り、処理をすぐに完了出来るようにしました。
JSON形式のラベルデータを、以下のようなテーブル形式のデータに加工しました。
このデータを元に、モデル学習用のデータをtorch.utils.data.Dataset
クラスのオブジェクトとして表現し、PyTorchで利用出来るようにします。コードの例を以下に記載します。
class ImageCaptionDataset(Dataset): def __init__(self,dataset,tokenizer,feature_extractor,max_token_length): ''' Parameters -------------------- dataset: DataFrame キャプションと画像パスをカラムに持つDataFrame tokenizer: Tokenizer キャプションをトークン化するHugging Face TransformersのTokenizer feature_extractor: ViTFeatureExtractor 画像をVisionTransformerに入力するための前処理 max_token_length: int 最大トークン数 ''' super().__init__() self.data = dataset self.data_size = dataset.shape[0] self.feature_extractor = feature_extractor self.tokenizer = tokenizer self.max_token_length = max_token_length def __getitem__(self, index): row = self.data.iloc[index] caption = row['caption'] img_path = row['img_path'] img = Image.open(img_path).convert('RGB') label = self.tokenizer( caption, return_tensors='pt', max_length=self.max_token_length, truncation=True, padding="max_length" ).input_ids[0] img_feature = self.feature_extractor( img, return_tensors='pt' ).pixel_values[0] return img_feature, label def __len__(self): return self.data_size
torch.utils.data.Dataset
ではデータセットからサンプルを提供する処理を__getitem__
に書きます。ここではキャプションをHugging FaceのライブラリTransformersのTokeniser
でトークン化、画像をViTFeatureExtractor
で前処理をかけた後、画像とラベルのペアでサンプルとして提供する処理を書きました。
torch.utils.data.Dataloader
とpytorch_lightning.LightningDataModule
の実装
Dataset
からサンプルをバッチで取得するtorch.utils.data.Dataloader
を、PyTorch Lightningの学習フェーズの適切なタイミングで提供するpytorch_lightning.LightningDataModule
を作ります。
class ImageCaptionDataModule(pl.LightningDataModule): def __init__( self, train_data, val_data, tokenizer, feature_extractor, batch_size=32, max_token_length=64 ): ''' Parameters -------------------- train_data: DataFrame キャプションと画像パスをカラムに持つDataFrame(学習用) val_data: DataFrame キャプションと画像パスをカラムに持つDataFrame(検証用) tokenizer: Tokenizer キャプションをトークン化するHugging Face TransformersのTokenizer feature_extractor: ViTFeatureExtractor 画像をVisionTransformerに入力するための前処理 batch_size: int 提供するサンプル数 max_token_length: int 最大トークン数 ''' super().__init__() self.train_data = train_data self.val_data = val_data self.tokenizer = tokenizer self.feature_extractor = feature_extractor self.batch_size = batch_size self.max_token_length = max_token_length def setup(self, stage): self.train_dataset = ImageCaptionDataset( self.train_data, self.tokenizer, self.feature_extractor, self.max_token_length ) self.val_dataset = ImageCaptionDataset( self.val_data, self.tokenizer, self.feature_extractor, self.max_token_length ) def train_dataloader(self): dataloader = DataLoader( self.train_dataset, batch_size=self.batch_size, num_workers=3, shuffle=True ) return dataloader def val_dataloader(self): dataloader = DataLoader( self.val_dataset, batch_size=self.batch_size, num_workers=3, shuffle=True ) return dataloader
pytorch_lightning.LightningModule
の実装
モデルや学習時の処理などを定義するpytorch_lightning.LightningModule
を実装します。
class ImageCaptionModelModule(pl.LightningModule): def __init__(self, model, tokenizer, feature_extractor, lr=0.0001): ''' Parameters -------------------- model: VisonEncoderDecoderModel Hugging Face TransformersのVisionEncoderDecoderModel モデルの本体 tokenizer: Tokenizer キャプションをトークン化するHugging Face TransformersのTokenizer feature_extractor: ViTFeatureExtractor 画像をVisionTransformerに入力するための前処理 lr: float Adam Optimizerにセットするlearning rate ''' super().__init__() self.model = model self.tokenizer = tokenizer self.feature_ext = feature_extractor self.model.config.pad_token_id = self.tokenizer.eos_token_id self.model.config.decoder_start_token_id = self.tokenizer.bos_token_id self.model.config.decoder = self.tokenizer.vocab_size self.lr = lr def training_step(self, batch, batch_idx): x, y = batch loss = self.model(pixel_values=x, labels=y).loss self.log('train_loss',loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return loss def validation_step(self, batch, batch_idx): x, y = batch loss = self.model(pixel_values=x, labels=y).loss self.log('val_loss',loss) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr)
Hugging Faceを活用することで、複雑なモデルも非常に簡潔に定義することが出来ます!
一点詰まったところがmodel.config
を設定する箇所で、ここは使用するmodel
やtokenizer
によって変更が必要になるところだと思います。たとえばあるtokenizer
ではpad_token_id
が定義されているけど別のtokenizer
では定義されていない、といったことがあって、それを知らなかったため、別のプロジェクトで作ったモデルをそのまま使おうとして上手くいかなかった、といったことがありました。
modelやtokenizerなどのダウンロード
次に事前学習済みのモデルVisionEncoderDecoderModel
やTokenizer
, ViTFeatureExtractor
をHugging Faceからダウンロードします。VisionEncoderDecoderModel
には画像をエンコードするモデル、テキストを生成する(デコードする)モデルをそれぞれ別に、事前学習済みのモデルから指定することが出来ます。今回はエンコーダとしてVision Transformer(ViT)を、デコーダとしてGPT2を指定しました。ここは色々なモデルを指定して精度を試してみると面白そうな部分です。
encoder_model_name = "google/vit-base-patch16-224-in21k" decoder_model_name = "gpt2" model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( encoder_model_name, decoder_model_name ) tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name) tokenizer.pad_token = tokenizer.eos_token feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_model_name)
pytorch_lightning.Trainer
のセッティング
モデルの学習を行うpytorch_lightning.Trainer
のセッティングを行います。
学習状況はpl.loggers.TensorBoardLogger
でTensorBoardに出力するようにします。
#tensorboard logger logger = pl.loggers.TensorBoardLogger( save_dir=tfboard_log_dir, name=tfboard_name ) #checkpoint callback checkpoint_callback = pl.callbacks.ModelCheckpoint( monitor='val_loss', dirpath=checkpoint_path, filename='img_captions-{epoch:02d}-{val_loss:.2f}' ) #earliystopping callback earlystopping_callback = pl.callbacks.early_stopping.EarlyStopping( monitor="val_loss", mode="min" ) trainer = pl.Trainer( max_epochs=30, accelerator='gpu', devices=1, logger=logger, callbacks=[ checkpoint_callback, earlystopping_callback ] )
#TensorBoardの起動 %load_ext tensorboard %tensorboard --logdir ./tf_logs/
学習の実行
最後にpytorch_lightning.LightningDataModule
とpytorch_lightning.LightningModule
のインスタンスを生成し、学習を実行します。
data_module = ImageCaptionDataModule( train_data,val_data, tokenizer, feature_extractor ) caption_module = ImageCaptionModelModule( model, tokenizer, feature_extractor ) trainer.fit(caption_module, datamodule=data_module)
学習したモデルは以下のようにして使うことが出来ます。
#test from PIL import Image import requests import io url = 'image-url' img = Image.open(io.BytesIO(requests.get(url).content)) feature_ext = caption_module.feature_ext(img.convert('RGB'), return_tensors='pt').pixel_values output = caption_module.model.generate(feature_ext, eos_token_id=tokenizer.eos_token_id, max_length=64) display(caption_module.tokenizer.decode(output.tolist()[0],skip_special_tokens=True))
滑走路は写っていないのに「airport runway」という単語が入っていますが、ちゃんとしたキャプションを生成することが出来ているようです!
まとめ
今回はImageCaptioningモデルをHugging Faceの事前学習済みモデルで組んで、PyTorch Lightningで学習する方法について調べたことをまとめてみました。ところどころ詰まるところはありましたが、短いコードで高性能なモデルを作ることが出来ました。これからも色々なモデルを作ってみたいと思います!