こんにちは、CCCMKホールディングス技術開発の三浦です。
最近寒い日が続いています。寒いと温かい飲み物が欲しくなりますが、近ごろは緑茶を飲むようになりました。お湯を入れたらすぐに飲むことが出来る粉末タイプのものもあって、気軽に楽しむことが出来ます。
今回の記事は分散処理フレームワークSpark周りについて調べた内容です。普段深層学習モデルの分散学習をDatabricksを通じ、Sparkクラスタで行っています。その中で最近少し引っかかっていたのが画像やテキストなどのモデル学習用データを読み取る処理がボトルネックになっている点でした。この部分をどう改善すれば良いのかなかなか分かりませんでした。
今のデータの入力処理は特にSparkの特徴を活かしきれているとは言えず、TensorFlowやPyTorchのDataLoaderを通じて都度画像ファイルやテキストファイルを読み込んでモデルに入力させています。Spark周辺について色々調べてみると、画像やテキストファイルをそのまま扱うのではなく、Sparkの処理に適した"Parquet"というファイルフォーマットに変換することで効率的にデータの入力処理を行うことが出来るようです。
深層学習モデルにParquet形式のデータをDataLoader経由で入力させる方法として、"Petastorm"というライブラリを利用する方法があります。今回はまず画像データをParquetファイルに変換し、Petastormを使ってPyTorchのDataloaderで取り込めるようにし、PyTorch-Lightningでモデルを学習させてみるところまで試してみました。
まだ調査不足なところもあり、とりあえず動かせた!レベルに留まっていますので、ご参考までにご覧いただければ幸いです
使用したデータセット
今回の検証に使用したデータセットはFood-101という、101種類の料理の画像で構成されるデータセットです。
それぞれのクラスごとに1,000枚ずつの画像が含まれています。うち750枚が学習用に設定されているので、学習用のデータは合計750,750枚75,750枚になります。
torchvisionのDatasetsの中にもこのデータセットは含まれおり、ライブラリを通じてダウンロードして使用することが出来ます。
from torchvision import datasets food101_dataset = datasets.Food101(root='./',download=True)
root
で指定したディレクトリに"food-101"というディレクトリが生成され、その中の"images"にクラスごとに画像が格納されます。
このデータを使って以降進めていきます。
SparkのDataFrameとしてデータセットを読み込む
まずこの画像データセットをSparkのDataFrameとして読み込んでみます。
import pyspark ds = spark.read.format('binaryFile').load('./food-101/images/*') display(ds)
DataFrameds
のpath
に画像ファイルのパス、content
に画像のバイナリデータが格納されます。
とりあえずデータセットをそのままSpark DataFrameに読み込んでみましたが、処理に必要になるのは画像データとその画像が対応するクラスのインデックスです。データセットを読み込み、必要な加工もDataFrameを生成するタイミングで施すことが出来ます。
import pyspark from pyspark.sql.functions import col, udf from pyspark.sql.types import LongType ''' パスに含まれるクラス名からクラスのインデックスに変換する ''' def path_to_index(path): return labels.index(path.split('/')[-2]) #Sparkで使用するユーザー定義関数にする path2indexUDF = udf(lambda x:path_to_index(x),LongType()) train_dataset, valid_dataset, test_dataset = spark.read.format('binaryFile')\ .load('./food-101/images/*')\ .withColumn('label',path2indexUDF(col('path')))\ .select('label','content')\ .randomSplit([0.7, 0.2,0.1], seed=13) train_dataset = train_dataset.repartition(4) valid_dataset = valid_dataset.repartition(4) test_dataset = test_dataset.repartition(4)
上記の処理の中で70%, 20%, 10%の割合で学習用(train_dataset
)と検証用(valid_dataset
)、そしてテスト用(test_dataset
)にデータセットを分割する処理も行っています。
DataFrameのカラムpath
に適用するユーザー定義関数(UDF)path2indexUDF
の返す値の型はLongType()
で指定しています。IntType()
なども指定できるのですが、IntType()
を指定すると学習時にPyTorchで損失を計算する時にエラーが発生しました。
repartition()
を実行することで生成されるParquetファイルの数が調整出来るようです。小さいサイズのファイルをたくさん作るのは処理効率に悪影響を及ぼすようですが、具体的にどれくらいの値を目安にしたらよいのかはまだ分かっておらず、今後色々試していこうと思います。
Parquetファイルが出力されるCacheディレクトリを指定する
SparkのDataFrameをParquetファイルに出力する際の、出力先となるCacheディレクトリを指定します。
from petastorm.spark import SparkDatasetConverter CACHE_DIR = "file:///dbfs/path/to/cache" spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, CACHE_DIR)
Spark DataFrameからParquetへの変換
Petastormのmake_spark_converter
を使ってDataFrameをParquetファイルに変換します。
from petastorm.spark import make_spark_converter train_converter = make_spark_converter(train_dataset) valid_converter = make_spark_converter(valid_dataset) test_converter = make_spark_converter(test_dataset)
ここの処理はかなり時間がかかります。
PyTorch-Lightningを使った学習処理の実装
ここからはPetastormを通してデータを読み取りながら深層学習モデルを学習する処理を実装していきます。こちらのdatabricksのブログを参考にしました。
このブログの中にリンクが貼られているdatabricks notebook "Building the PyTorch Lightning Modules"と" Main Execution notebook"を見ながら処理を作っていきました。
今回学習させるモデルは事前学習済みのResNet50を使用したFood-101による料理画像の分類モデルです。
LightningDataModule
PyTorch-LightningのLightingDataModuleを継承したクラスを実装します。このクラスはインスタンス化する際に先ほど生成したtrain_converter
などのSparkDatasetConverterオブジェクトを受け取り、内部でPyTorchのDataLoaderを生成します。
class Food101DataModule(pl.LightningDataModule): def __init__(self, train_converter,valid_converter,test_converter, batch_size, device_id=0, device_count=1): super().__init__() self.train_converter = train_converter self.valid_converter = valid_converter self.test_converter = test_converter self.train_dataloader_context = None self.valid_dataloader_context = None self.test_dataloader_context = None self.batch_size = batch_size self.device_id = device_id self.device_count = device_count def train_dataloader(self): if self.train_dataloader_context: self.train_dataloader_context.__exit__(None, None, None) self.train_dataloader_context = self.train_converter.make_torch_dataloader( batch_size=self.batch_size, num_epochs=None, transform_spec=self.get_transform_spec(), cur_shard=self.device_id, shard_count=self.device_count, ) return self.train_dataloader_context.__enter__() def val_dataloader(self): if self.valid_dataloader_context: self.valid_dataloader_context.__exit__(None, None, None) self.valid_dataloader_context = self.valid_converter.make_torch_dataloader( batch_size=self.batch_size, num_epochs=None, transform_spec=self.get_transform_spec(), cur_shard=self.device_id, shard_count=self.device_count, ) return self.valid_dataloader_context.__enter__() def test_dataloader(self): if self.test_dataloader_context: self.test_dataloader_context.__exit__(None, None, None) self.test_dataloader_context = self.test_converter.make_torch_dataloader( batch_size=self.batch_size, num_epochs=None, transform_spec=self.get_transform_spec(), cur_shard=self.device_id, shard_count=self.device_count, ) return self.test_dataloader_context.__enter__() def teardown(self, stage=None): if self.train_dataloader_context: self.train_dataloader_context.__exit__(None, None, None) if self.valid_dataloader_context: self.valid_dataloader_context.__exit__(None, None, None) if self.test_dataloader_context: self.test_dataloader_context.__exit__(None, None, None) def transform_row(self, data): transformers = [transforms.Lambda(lambda x: Image.open(io.BytesIO(x)))] transformers.extend([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) trans = transforms.Compose(transformers) data['X'] = data['content'].map(lambda x:trans(x).numpy()) data = data.drop(labels=['content'],axis=1) return data def get_transform_spec(self): return TransformSpec( self.transform_row, edit_fields=[('X',np.float32,(3,224,224),False)], #(name, numpy_dtype, shape, is_nullable) selected_fields=['X','label'] )
SparkDatasetConverterにはmake_torch_dataloader()
というメソッドがあり、このメソッドでTorchDatasetContextManager
クラスのオブジェクトが得られます。
そしてこのオブジェクトの__enter__()
を呼び出すとPyTorchのDataLoaderが得られます。使用が終わったら__exit__()
を呼び出してリセットする必要があります。
Petastormのドキュメントに掲載されている使用例ではwith
句で主に処理しているため、明示的に__enter__()
や___exit__()
を呼び出す必要はないのですが、この実装では少し処理が入り組んでいるので明示的に呼び出すようにしました。
make_torch_dataloader()
では様々なパラメータを指定することが出来ます。特にポイントになるのがnum_epochs=None
の指定です。このように指定して生成したDataLoaderはepochの制限がなく、無限にデータを生成出来るようになります。
この指定が特に重要なのが分散学習をする時で、もしこの指定が無いとDataLoaderの最後のバッチに含まれるサンプル数がバッチサイズよりも小さくなってしまう可能性があり、それによって分散学習時に不整合が発生してしまうようです。
LightningModule
モデルの学習ステップなどを実装するLightningModuleについては入力処理を特に意識せずに作ることが出来ます。
import torch from torch import nn from torch.nn.functional import cross_entropy class Food101Module(pl.LightningModule): def __init__(self, model, lr): super().__init__() self.model = model self.lr = lr def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): x = batch['X'] y = batch['label'] y_hat = self.model(x) loss = cross_entropy(y_hat, y) self.log('train_loss',loss, prog_bar=True, on_step=True) print('train_loss:{}'.format(loss)) return loss def validation_step(self, batch, batch_idx): x = batch['X'] y = batch['label'] y_hat = self.model(x) loss = cross_entropy(y_hat, y) print('valid_loss:{}'.format(loss)) self.log('valid_loss',loss, prog_bar=True, on_step=False, on_epoch=True) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr)
Trainer
学習処理を担当するPyTorch-LightningのTrainerに関する設定ですが、こちらはいくつかパラメータの設定に注意点があります。
trainer = pl.Trainer( accelerator="gpu", devices=2, max_epochs=5, strategy='dp', limit_train_batches=10, log_every_n_steps=1, val_check_interval=10, num_sanity_val_steps=0, limit_val_batches=5, reload_dataloaders_every_n_epochs=1 )
DataLoaderが無限にデータを生成することからlimit_train_batches
とlimit_val_batches
を使ってステップ数でepochを区切るようにします。val_check_interval
は検証が実行される間隔をステップ数で指定しますが、epochの終わりに実行するよう、limit_train_batches
と同じ値を設定します。
num_sanity_val_steps
は学習前に行う検証処理のステップ数を指定するのですが、ここで実行したステップは本番の検証ステップでリセットされない(事前検証で進んだステップの後から開始される)ようなので、実行しないように0を設定します。
また今回の内容とは関係がないのですが、PyTorch-Lightningで学習を行うと、よく以下のエラーを発生させてしまいます。
ProcessExitedException: process 1 terminated with signal SIGSEGV
このエラーはdevices
に1より大きい値を設定しているにも関わらず、strategy
を指定していない場合に発生するのですが、よく原因を忘れて何度も調べてしまうので忘れないようにこちらに記載しておこうと思います・・・。
学習実行
あとは学習を実行します。事前学習済みのResNet50をダウンロードし、出力層をカスタマイズ、MLflowのExperimentに関する設定など行っています。
from torchvision import models import mlflow.pytorch resnet50 = models.resnet50(pretrained=True) resnet50.fc = nn.Linear(resnet50.fc.in_features, len(labels)) food101_model = Food101Module(resnet50, 0.02) food101_datamodule = Food101DataModule( train_converter, valid_converter, test_converter, batch_size=64 ) experiment_id = 'xxxxx' with mlflow.start_run(experiment_id=experiment_id) as run: mlflow.pytorch.autolog() trainer.fit(model=food101_model, datamodule=food101_datamodule)
今回はいったんここまで試してみました。気になる入力処理にかかる時間の測定は、次回以降でまたご紹介させて頂きます。
まとめ
ということで、今回はSpark上で効率よく深層学習モデルを学習させるため、データの入力部分に見直しをかけている話をご紹介しました。具体的にはSparkで効率的に扱うことが出来るParquetフォーマットでファイルを出力し、PetastormというライブラリでPyTorchのDataLoaderで読み込む、という方法です。なんとか動かすところまで出来たので、画像データをそのまま読み込ませた場合とParquetに変換した場合とでの処理時間の比較など、これから調べていきたいと思います!