CCCMKホールディングス TECH LABの Tech Blog

TECH LABのエンジニアが技術情報を発信しています

ブログタイトル

NVIDIA A100でのtorch.compileの効果を検証

こんにちは。テックラボの高橋です。

pytorchにtorch.compileという機能があることをご存知でしょうか?

torch 2.0から導入されたこの機能を利用することで、推論処理や学習処理を高速化できるとのことです。

今回はNVIDIA A100を用いて、torch.compileがどのくらい効果があるか検証してみました。

環境

  • pytorch 2.6
  • GPU NVIDIA A100 80G
  • ubuntu 20.04.6
  • nvidia-docker 24.0.9-1
  • モデル tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3

torch.compileとは

はじめに、torch.compileについて簡単に説明します。

pytorch.org

torch.compileは実行中にモデルを最適化し高速化を図る技術です。

上図中にあるTorchDynamo、TorchInductor、Tritonについて各処理を見ていきます。

torch.compileを実行すると、まずTorchDynamoがPythonバイトコードをFXグラフと呼ばれる計算グラフに変換します。

この際の内部構造はdypyfというライブラリを用いると可視化するこができます。

github.com

例えば以下のような処理があるとします。

def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

この関数をコンパイルしたものをdepyfで可視化すると、個々のグラフのモジュールが人間に読みやすい形式で、コメント付きで生成されます。 グラフモジュールのひとつを見てみましょう。

from __future__ import annotations
import torch
class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10]", L_b_: "f32[10]"):
        l_a_ = L_a_
        l_b_ = L_b_
        
         # File: /workspace/check_depyf.py:8 in toy_example, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10]" = torch.abs(l_a_)
        add: "f32[10]" = abs_1 + 1;  abs_1 = None
        x: "f32[10]" = l_a_ / add;  l_a_ = add = None
        
         # File: /workspace/check_depyf.py:9 in toy_example, code: if b.sum() < 0:
        sum_1: "f32[]" = l_b_.sum();  l_b_ = None
        lt: "b8[]" = sum_1 < 0;  sum_1 = None
        return (x, lt)
        

生成されたコメントを読むと、元の関数の以下の箇所がモジュール化されていることがわかります。

   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:

その後、このFXグラフをTorchInductorが最適化し、TritonがNVIDIAで処理を実行するための関数であるCUDAカーネルを生成します。

実装

pytorch.org

torch.compileのチュートリアルを参考に、以下のようなコードを実行してみます。

import os

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

device = "cuda"
ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(ckpt)

model.generation_config.max_length = 128

prompts = ["なぜ犬はこんなにもかわいいの?"] * 10

# without torch.compile
timings_without_compile = []
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    _, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id))
    print(elapsed_time)
    timings_without_compile.append(elapsed_time)


torch._dynamo.reset()

# compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"

# with torch.compile
torch.compiler.cudagraph_mark_step_begin()

timings_with_compile = []
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    _, elapsed_time = timed(lambda: model.generate(**inputs, do_sample=False, pad_token_id=tokenizer.eos_token_id))
    print(elapsed_time)
    timings_with_compile.append(elapsed_time)

結果は以下のようになります。

橙色がコンパイル有りの結果ですが、最初の2回だけコンパイルのために時間がかかっています。

そこで、下記のようにして最初の2回を抜いた実行時間の平均を取ると、

print(f"{np.array(timings_without_compile[2:]).mean()=}")
print(f"{np.array(timings_with_compile[2:]).mean()=}")
  • コンパイル無しの平均値: 3.12秒

  • コンパイル有りの平均値: 1.79秒

となりました。1.7倍程度のスピードアップですね。

上記は同じプロンプトを10回実行していましたが、 試しにpromptsの内容を以下のように10種類に変更してみます。

prompts = [
    "なぜ犬はこんなにもかわいいの?",
    "なぜ猫は夜行性なの?",
    "宇宙の果てには何があるの?",
    "人間はなぜ夢を見るの?",
    "海の深さはどれくらい?",
    "なぜ空は青いの?",
    "恐竜はどのようにして絶滅したの?",
    "なぜ音楽を聴くと感動するの?",
    "人類はいつ火星に住むことができるの?",
    "なぜ植物は光合成をするの?",
    "どうして鳥は飛ぶの?"
]

結果はこちらです。

コンパイル有りの場合、文章を変更すると再コンパイルが走る場合があるようです。

vLLMとの比較

github.com

vLLMはLLMを高速に実行するためのライブラリです。 PagedAttentionという仕組み等を利用して高速に推論処理を行うことができます。

こちらを以下のコードで実行時間を計測してみます。

import os

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

device = "cuda"
ckpt = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"

prompts = ["なぜ犬はこんなにもかわいいの?"] * 10

model = LLM(
    model=ckpt,
    tokenizer=ckpt,
    dtype="float16"
    )

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_length=128
    )

timings_with_vllm = []
for prompt in prompts:
    _, elapsed_time = timed(lambda: model.generate(prompt, sampling_params))
    print(elapsed_time)
    timings_with_vllm.append(elapsed_time)

処理の平均値は1.61秒になりました。

デフォルトのtorch、torch.compileと比較すると以下のようになります。

GPU:NVIDIA A100 80G​
CUDA:12.6​
pytorch:2.6​
モデル:Llama-3.1-Swallow-8B-Instruct-v0.3​
量子化:float16​
試行:同一プロンプト x 10回

若干vLLMのほうが早いみたいですね。

今回はtorch.compile側がtorch2.6をほぼデフォルトパラメータで実行しており、FlashAttention等を無効化していません。 また、vLLMもパラメータ最適化は行っていないので、設定によってかなり結果は変わりうると思われます。

他の方の記事等を読むと、vLLMがより高速になる例があるようですので、 あくまでこのコード・環境での場合の値ということでご了承いただければと思います。

おわりに

本記事ではNVIDIA A100でのtorch.compileの検証と、vLLMとの比較を行ってみました。

vLLMなどの推論フレームワークのベンチマークを取った"LLM-Inference-Bench: Inference Benchmarking of Large Language Models on AI Accelerators"という論文によると、バッチ推論において特にvLLMは性能が良さそうです。他の条件についても、引き続き調査していきたいと思います。

参考

ローカルLLMの推論速度を高速化する5つの手法と比較評価

PyTorch 2.0の新機能「torch.compile」使ってみた - まったり勉強ノート

https://huggingface.co/meta-llama/Llama-3.1-8B/blob/a2ebf37472e2da04a7f54fcf8fc48ee78e36ec51/torch_compile_example

Introduction to torch.compile — PyTorch Tutorials 2.6.0+cu124 documentation