CCCMKホールディングス TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

Parquet+Petastormを使って画像分類モデルをSparkで学習させてみました!

こんにちは、CCCMKホールディングス技術開発の三浦です。

最近寒い日が続いています。寒いと温かい飲み物が欲しくなりますが、近ごろは緑茶を飲むようになりました。お湯を入れたらすぐに飲むことが出来る粉末タイプのものもあって、気軽に楽しむことが出来ます。

今回の記事は分散処理フレームワークSpark周りについて調べた内容です。普段深層学習モデルの分散学習をDatabricksを通じ、Sparkクラスタで行っています。その中で最近少し引っかかっていたのが画像やテキストなどのモデル学習用データを読み取る処理がボトルネックになっている点でした。この部分をどう改善すれば良いのかなかなか分かりませんでした。

今のデータの入力処理は特にSparkの特徴を活かしきれているとは言えず、TensorFlowやPyTorchのDataLoaderを通じて都度画像ファイルやテキストファイルを読み込んでモデルに入力させています。Spark周辺について色々調べてみると、画像やテキストファイルをそのまま扱うのではなく、Sparkの処理に適した"Parquet"というファイルフォーマットに変換することで効率的にデータの入力処理を行うことが出来るようです。

深層学習モデルにParquet形式のデータをDataLoader経由で入力させる方法として、"Petastorm"というライブラリを利用する方法があります。今回はまず画像データをParquetファイルに変換し、Petastormを使ってPyTorchのDataloaderで取り込めるようにし、PyTorch-Lightningでモデルを学習させてみるところまで試してみました。

まだ調査不足なところもあり、とりあえず動かせた!レベルに留まっていますので、ご参考までにご覧いただければ幸いです

データの一連の流れ

使用したデータセット

今回の検証に使用したデータセットはFood-101という、101種類の料理の画像で構成されるデータセットです。

data.vision.ee.ethz.ch

それぞれのクラスごとに1,000枚ずつの画像が含まれています。うち750枚が学習用に設定されているので、学習用のデータは合計750,750枚75,750枚になります。

torchvisionのDatasetsの中にもこのデータセットは含まれおり、ライブラリを通じてダウンロードして使用することが出来ます。

from torchvision import datasets
food101_dataset = datasets.Food101(root='./',download=True)

rootで指定したディレクトリに"food-101"というディレクトリが生成され、その中の"images"にクラスごとに画像が格納されます。

imagesディレクトリの中

このデータを使って以降進めていきます。

SparkのDataFrameとしてデータセットを読み込む

まずこの画像データセットをSparkのDataFrameとして読み込んでみます。

import pyspark
ds = spark.read.format('binaryFile').load('./food-101/images/*')
display(ds)

DataFramedspathに画像ファイルのパス、contentに画像のバイナリデータが格納されます。

databricks notebook上で表示させています

とりあえずデータセットをそのままSpark DataFrameに読み込んでみましたが、処理に必要になるのは画像データとその画像が対応するクラスのインデックスです。データセットを読み込み、必要な加工もDataFrameを生成するタイミングで施すことが出来ます。

import pyspark
from pyspark.sql.functions import col, udf
from pyspark.sql.types import LongType

'''
パスに含まれるクラス名からクラスのインデックスに変換する
'''
def path_to_index(path):
  return labels.index(path.split('/')[-2])

#Sparkで使用するユーザー定義関数にする
path2indexUDF = udf(lambda x:path_to_index(x),LongType()) 

train_dataset, valid_dataset, test_dataset = spark.read.format('binaryFile')\
    .load('./food-101/images/*')\
    .withColumn('label',path2indexUDF(col('path')))\
    .select('label','content')\
    .randomSplit([0.7, 0.2,0.1], seed=13)

train_dataset = train_dataset.repartition(4)
valid_dataset = valid_dataset.repartition(4)
test_dataset = test_dataset.repartition(4)

上記の処理の中で70%, 20%, 10%の割合で学習用(train_dataset)と検証用(valid_dataset)、そしてテスト用(test_dataset)にデータセットを分割する処理も行っています。 DataFrameのカラムpathに適用するユーザー定義関数(UDF)path2indexUDFの返す値の型はLongType()で指定しています。IntType()なども指定できるのですが、IntType()を指定すると学習時にPyTorchで損失を計算する時にエラーが発生しました。

repartition()を実行することで生成されるParquetファイルの数が調整出来るようです。小さいサイズのファイルをたくさん作るのは処理効率に悪影響を及ぼすようですが、具体的にどれくらいの値を目安にしたらよいのかはまだ分かっておらず、今後色々試していこうと思います。

Parquetファイルが出力されるCacheディレクトリを指定する

SparkのDataFrameをParquetファイルに出力する際の、出力先となるCacheディレクトリを指定します。

from petastorm.spark import SparkDatasetConverter
CACHE_DIR = "file:///dbfs/path/to/cache"
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, CACHE_DIR)

Spark DataFrameからParquetへの変換

Petastormのmake_spark_converterを使ってDataFrameをParquetファイルに変換します。

from petastorm.spark import make_spark_converter
train_converter = make_spark_converter(train_dataset)
valid_converter = make_spark_converter(valid_dataset)
test_converter = make_spark_converter(test_dataset)

ここの処理はかなり時間がかかります。

PyTorch-Lightningを使った学習処理の実装

ここからはPetastormを通してデータを読み取りながら深層学習モデルを学習する処理を実装していきます。こちらのdatabricksのブログを参考にしました。

www.databricks.com

このブログの中にリンクが貼られているdatabricks notebook "Building the PyTorch Lightning Modules"と" Main Execution notebook"を見ながら処理を作っていきました。

今回学習させるモデルは事前学習済みのResNet50を使用したFood-101による料理画像の分類モデルです。

LightningDataModule

PyTorch-LightningのLightingDataModuleを継承したクラスを実装します。このクラスはインスタンス化する際に先ほど生成したtrain_converterなどのSparkDatasetConverterオブジェクトを受け取り、内部でPyTorchのDataLoaderを生成します。

class Food101DataModule(pl.LightningDataModule):
  
  def __init__(self, train_converter,valid_converter,test_converter, batch_size, device_id=0, device_count=1):
    super().__init__()
    self.train_converter = train_converter
    self.valid_converter = valid_converter
    self.test_converter = test_converter
    self.train_dataloader_context = None
    self.valid_dataloader_context = None
    self.test_dataloader_context = None

    self.batch_size = batch_size
    self.device_id = device_id
    self.device_count = device_count
    
  def train_dataloader(self):
    if self.train_dataloader_context:
      self.train_dataloader_context.__exit__(None, None, None)
    self.train_dataloader_context = self.train_converter.make_torch_dataloader(
          batch_size=self.batch_size, 
          num_epochs=None,
          transform_spec=self.get_transform_spec(),
          cur_shard=self.device_id,
          shard_count=self.device_count,
    )
    return self.train_dataloader_context.__enter__()
    
  def val_dataloader(self):
    if self.valid_dataloader_context:
      self.valid_dataloader_context.__exit__(None, None, None)
    self.valid_dataloader_context = self.valid_converter.make_torch_dataloader(
          batch_size=self.batch_size, 
          num_epochs=None,
          transform_spec=self.get_transform_spec(),
          cur_shard=self.device_id,
          shard_count=self.device_count,
    )
    return self.valid_dataloader_context.__enter__()
    
  def test_dataloader(self):
    if self.test_dataloader_context:
      self.test_dataloader_context.__exit__(None, None, None)
    self.test_dataloader_context = self.test_converter.make_torch_dataloader(
          batch_size=self.batch_size, 
          num_epochs=None,
          transform_spec=self.get_transform_spec(),
          cur_shard=self.device_id,
          shard_count=self.device_count,
    )
    return self.test_dataloader_context.__enter__()
    
  def teardown(self, stage=None):
    if self.train_dataloader_context:
      self.train_dataloader_context.__exit__(None, None, None)
    if self.valid_dataloader_context:
      self.valid_dataloader_context.__exit__(None, None, None)
    if self.test_dataloader_context:
      self.test_dataloader_context.__exit__(None, None, None)
    
  def transform_row(self, data):
    transformers = [transforms.Lambda(lambda x: Image.open(io.BytesIO(x)))]
    transformers.extend([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    trans = transforms.Compose(transformers)
    data['X'] = data['content'].map(lambda x:trans(x).numpy())
    data = data.drop(labels=['content'],axis=1)
    return data

  def get_transform_spec(self):
    return TransformSpec(
      self.transform_row,
      edit_fields=[('X',np.float32,(3,224,224),False)], #(name, numpy_dtype, shape, is_nullable)
      selected_fields=['X','label']
    )

SparkDatasetConverterにはmake_torch_dataloader()というメソッドがあり、このメソッドでTorchDatasetContextManagerクラスのオブジェクトが得られます。 そしてこのオブジェクトの__enter__()を呼び出すとPyTorchのDataLoaderが得られます。使用が終わったら__exit__()を呼び出してリセットする必要があります。 Petastormのドキュメントに掲載されている使用例ではwith句で主に処理しているため、明示的に__enter__()___exit__()を呼び出す必要はないのですが、この実装では少し処理が入り組んでいるので明示的に呼び出すようにしました。

make_torch_dataloader()では様々なパラメータを指定することが出来ます。特にポイントになるのがnum_epochs=Noneの指定です。このように指定して生成したDataLoaderはepochの制限がなく、無限にデータを生成出来るようになります。 この指定が特に重要なのが分散学習をする時で、もしこの指定が無いとDataLoaderの最後のバッチに含まれるサンプル数がバッチサイズよりも小さくなってしまう可能性があり、それによって分散学習時に不整合が発生してしまうようです。

LightningModule

モデルの学習ステップなどを実装するLightningModuleについては入力処理を特に意識せずに作ることが出来ます。

import torch
from torch import nn
from torch.nn.functional import cross_entropy

class Food101Module(pl.LightningModule):
  def __init__(self, model, lr):
    super().__init__()
    self.model = model
    self.lr = lr
  
  def forward(self, x):
    return self.model(x)
  
  def training_step(self, batch, batch_idx):
    x = batch['X']
    y = batch['label']
    y_hat = self.model(x)
    loss = cross_entropy(y_hat, y)
    self.log('train_loss',loss, prog_bar=True, on_step=True)
    print('train_loss:{}'.format(loss))
    return loss
  
  def validation_step(self, batch, batch_idx):
    x = batch['X']
    y = batch['label']
    y_hat = self.model(x)
    loss = cross_entropy(y_hat, y)
    print('valid_loss:{}'.format(loss))
    self.log('valid_loss',loss, prog_bar=True, on_step=False, on_epoch=True)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)

Trainer

学習処理を担当するPyTorch-LightningのTrainerに関する設定ですが、こちらはいくつかパラメータの設定に注意点があります。

trainer = pl.Trainer(
  accelerator="gpu", 
  devices=2,
  max_epochs=5,
  strategy='dp',
  limit_train_batches=10,       
  log_every_n_steps=1,
  val_check_interval=10,
  num_sanity_val_steps=0,
  limit_val_batches=5,
  reload_dataloaders_every_n_epochs=1
)

DataLoaderが無限にデータを生成することからlimit_train_batcheslimit_val_batchesを使ってステップ数でepochを区切るようにします。val_check_intervalは検証が実行される間隔をステップ数で指定しますが、epochの終わりに実行するよう、limit_train_batchesと同じ値を設定します。

num_sanity_val_stepsは学習前に行う検証処理のステップ数を指定するのですが、ここで実行したステップは本番の検証ステップでリセットされない(事前検証で進んだステップの後から開始される)ようなので、実行しないように0を設定します。

また今回の内容とは関係がないのですが、PyTorch-Lightningで学習を行うと、よく以下のエラーを発生させてしまいます。

ProcessExitedException: process 1 terminated with signal SIGSEGV

このエラーはdevicesに1より大きい値を設定しているにも関わらず、strategyを指定していない場合に発生するのですが、よく原因を忘れて何度も調べてしまうので忘れないようにこちらに記載しておこうと思います・・・。

学習実行

あとは学習を実行します。事前学習済みのResNet50をダウンロードし、出力層をカスタマイズ、MLflowのExperimentに関する設定など行っています。

from torchvision import models
import mlflow.pytorch

resnet50 = models.resnet50(pretrained=True)
resnet50.fc = nn.Linear(resnet50.fc.in_features, len(labels))

food101_model = Food101Module(resnet50, 0.02)
food101_datamodule = Food101DataModule(
    train_converter, 
    valid_converter,
    test_converter,
    batch_size=64
)

experiment_id = 'xxxxx'

with mlflow.start_run(experiment_id=experiment_id) as run:
  mlflow.pytorch.autolog()
  trainer.fit(model=food101_model, datamodule=food101_datamodule)

今回はいったんここまで試してみました。気になる入力処理にかかる時間の測定は、次回以降でまたご紹介させて頂きます。

まとめ

ということで、今回はSpark上で効率よく深層学習モデルを学習させるため、データの入力部分に見直しをかけている話をご紹介しました。具体的にはSparkで効率的に扱うことが出来るParquetフォーマットでファイルを出力し、PetastormというライブラリでPyTorchのDataLoaderで読み込む、という方法です。なんとか動かすところまで出来たので、画像データをそのまま読み込ませた場合とParquetに変換した場合とでの処理時間の比較など、これから調べていきたいと思います!