こんにちは。テックラボの高橋です。
pytorchにtorch.compileという機能があることをご存知でしょうか?
torch 2.0から導入されたこの機能を利用することで、推論処理や学習処理を高速化できるとのことです。
今回はNVIDIA A100を用いて、torch.compileがどのくらい効果があるか検証してみました。
環境
- pytorch 2.6
- GPU NVIDIA A100 80G
- ubuntu 20.04.6
- nvidia-docker 24.0.9-1
- モデル tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3
torch.compileとは
はじめに、torch.compileについて簡単に説明します。
torch.compileは実行中にモデルを最適化し高速化を図る技術です。
上図中にあるTorchDynamo、TorchInductor、Tritonについて各処理を見ていきます。
torch.compileを実行すると、まずTorchDynamoがPythonバイトコードをFXグラフと呼ばれる計算グラフに変換します。
この際の内部構造はdypyfというライブラリを用いると可視化するこができます。
例えば以下のような処理があるとします。
def toy_example(a, b): x = a / (torch.abs(a) + 1) if b.sum() < 0: b = b * -1 return x * b
この関数をコンパイルしたものをdepyfで可視化すると、個々のグラフのモジュールが人間に読みやすい形式で、コメント付きで生成されます。 グラフモジュールのひとつを見てみましょう。
from __future__ import annotations import torch class GraphModule(torch.nn.Module): def forward(self, L_a_: "f32[10]", L_b_: "f32[10]"): l_a_ = L_a_ l_b_ = L_b_ # File: /workspace/check_depyf.py:8 in toy_example, code: x = a / (torch.abs(a) + 1) abs_1: "f32[10]" = torch.abs(l_a_) add: "f32[10]" = abs_1 + 1; abs_1 = None x: "f32[10]" = l_a_ / add; l_a_ = add = None # File: /workspace/check_depyf.py:9 in toy_example, code: if b.sum() < 0: sum_1: "f32[]" = l_b_.sum(); l_b_ = None lt: "b8[]" = sum_1 < 0; sum_1 = None return (x, lt)
生成されたコメントを読むと、元の関数の以下の箇所がモジュール化されていることがわかります。
x = a / (torch.abs(a) + 1) if b.sum() < 0:
その後、このFXグラフをTorchInductorが最適化し、TritonがNVIDIAで処理を実行するための関数であるCUDAカーネルを生成します。
実装
torch.compileのチュートリアルを参考に、以下のようなコードを実行してみます。
import os import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer import matplotlib.pyplot as plt os.environ["TOKENIZERS_PARALLELISM"] = "false" # https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups def timed(fn): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() result = fn() end.record() torch.cuda.synchronize() return result, start.elapsed_time(end) / 1000 device = "cuda" ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3" model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16) model.to(device) tokenizer = AutoTokenizer.from_pretrained(ckpt) model.generation_config.max_length = 128 prompts = ["なぜ犬はこんなにもかわいいの?"] * 10 # without torch.compile timings_without_compile = [] for prompt in prompts: inputs = tokenizer(prompt, return_tensors="pt").to("cuda") _, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id)) print(elapsed_time) timings_without_compile.append(elapsed_time) torch._dynamo.reset() # compile model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) model.generation_config.cache_implementation = "static" # with torch.compile torch.compiler.cudagraph_mark_step_begin() timings_with_compile = [] for prompt in prompts: inputs = tokenizer(prompt, return_tensors="pt").to("cuda") _, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id)) print(elapsed_time) timings_with_compile.append(elapsed_time)
結果は以下のようになります。
橙色がコンパイル有りの結果ですが、最初の2回だけコンパイルのために時間がかかっています。
そこで、下記のようにして最初の2回を抜いた実行時間の平均を取ると、
print(f"{np.array(timings_without_compile[2:]).mean()=}") print(f"{np.array(timings_with_compile[2:]).mean()=}")
コンパイル無しの平均値: 3.12秒
コンパイル有りの平均値: 1.79秒
となりました。1.7倍程度のスピードアップですね。
上記は同じプロンプトを10回実行していましたが、 試しにpromptsの内容を以下のように10種類に変更してみます。
prompts = [ "なぜ犬はこんなにもかわいいの?", "なぜ猫は夜行性なの?", "宇宙の果てには何があるの?", "人間はなぜ夢を見るの?", "海の深さはどれくらい?", "なぜ空は青いの?", "恐竜はどのようにして絶滅したの?", "なぜ音楽を聴くと感動するの?", "人類はいつ火星に住むことができるの?", "なぜ植物は光合成をするの?", "どうして鳥は飛ぶの?" ]
結果はこちらです。
コンパイル有りの場合、文章を変更すると再コンパイルが走る場合があるようです。
vLLMとの比較
vLLMはLLMを高速に実行するためのライブラリです。 PagedAttentionという仕組み等を利用して高速に推論処理を行うことができます。
こちらを以下のコードで実行時間を計測してみます。
import os import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import LLM, SamplingParams os.environ["TOKENIZERS_PARALLELISM"] = "false" def timed(fn): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() result = fn() end.record() torch.cuda.synchronize() return result, start.elapsed_time(end) / 1000 device = "cuda" ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3" prompts = ["なぜ犬はこんなにもかわいいの?"] * 10 model = LLM( model=ckpt, tokenizer=ckpt, dtype="float16" ) sampling_params = SamplingParams( temperature=0.8, top_p=0.95, max_length=128 ) timings_with_vllm = [] for prompt in prompts: _, elapsed_time = timed(lambda: model.generate(prompt, sampling_params)) print(elapsed_time) timings_with_vllm.append(elapsed_time)
処理の平均値は1.61秒になりました。
デフォルトのtorch、torch.compileと比較すると以下のようになります。
GPU:NVIDIA A100 80G CUDA:12.6 pytorch:2.6 モデル:Llama-3.1-Swallow-8B-Instruct-v0.3 量子化:float16 試行:同一プロンプト x 10回
若干vLLMのほうが早いみたいですね。
今回はtorch.compile側がtorch2.6をほぼデフォルトパラメータで実行しており、FlashAttention等を無効化していません。 また、vLLMもパラメータ最適化は行っていないので、設定によってかなり結果は変わりうると思われます。
他の方の記事等を読むと、vLLMがより高速になる例があるようですので、 あくまでこのコード・環境での場合の値ということでご了承いただければと思います。
おわりに
本記事ではNVIDIA A100でのtorch.compileの検証と、vLLMとの比較を行ってみました。
vLLMなどの推論フレームワークのベンチマークを取った"LLM-Inference-Bench: Inference Benchmarking of Large Language Models on AI Accelerators"という論文によると、バッチ推論において特にvLLMは性能が良さそうです。他の条件についても、引き続き調査していきたいと思います。
参考
PyTorch 2.0の新機能「torch.compile」使ってみた - まったり勉強ノート
Introduction to torch.compile — PyTorch Tutorials 2.6.0+cu124 documentation