CCCMKホールディングス TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

FastAPIを使って機械学習モデルの推論機能をAPIで提供する方法を調べてみました!

こんにちは、技術開発の三浦です。

子どもの学校が夏休みに入り、朝の生活のリズムが少し変わりました。リズムが変わると最初は慣れないのですが、続けていると自然と慣れていくもので、1週間くらい経ってこの変化にも慣れてきました。次に訪れる変化は夏休みが終わる頃かな・・・と考えています。

最近、自分が学習させた機械学習のモデルをチームのメンバーに共有し、フィードバックをもらいたいと考えることがありました。モデルの学習はサーバ上のnotebookで行ったので、notebookの環境にアクセスしてもらえれば使えるのですが、複数人で同じnotebookにアクセスすると間違ってnotebookを編集してしまう可能性があるなど、問題もあります。

出来たらAPIのような形で、モデルの推論機能だけを提供できないかなぁと色々調べていました。

今回はそんな要望を実現するために、Pythonの軽量なWebフレームワークFastAPIを使って機械学習モデルの推論機能をAPIで提供する方法について調べたので、ご紹介します。

FastAPI

FastAPIはPython3.6以上で利用できる軽量なWebフレームワークで、WebAPIを構築するのに長けています。

fastapi.tiangolo.com

FastAPIは簡単で軽く、そしてAPIのドキュメントを自動で生成してくれる。以前誰かに教えてもらって、その時からいつか使ってみようと気になっていたフレームワークです。今回が試すのにとてもいい機会だと思い、実際に試してみたのですが、確かにその通りの印象を受けました。それだけでなく、クライアントから受け取ったデータを検証する機能も提供されていることが分かりました。

まずはFastAPIを理解するために、とてもシンプルなWebアプリケーションを作ってみました。次にモデルの推論機能をAPIで提供するアプリケーションを作りました。順を追ってご紹介します。

シンプルなWebアプリケーション

セットアップ

まずはFastAPIを利用するために必要なライブラリをインストールします。

pip install fastapi uvicorn

UvicornはWebサーバを起動するためのアプリケーションです。UvicornでWebサーバを立ち上げてFastAPIで作ったアプリケーションを動かします。

アプリケーションのコード

FastAPIを使ってアプリケーションを作っていきます。このアプリケーションには2つのエンドポイント"/"(root)"/msg"があります。"/"はGETメソッド、"/msg"はPOSTメソッドでリクエストを受け取ります。

from fastapi import FastAPI

app = FastAPI()

@app.get("/")
def root():
    return {"message":"こんにちは"}

@app.post("/msg")
def message(text):
    return {"message":f"{text},こんにちは"}

app = FastAPI()appの変数名は重要で、この変数名がアプリケーションの名称になります。@app.get("/")@app.post("/msg")でAPIのエンドポイントと、対応するメソッド(GETやPOST)を宣言し、リクエストを受けた後の処理はデコレートされている関数の中で定義します。

それではこのWebアプリケーションをUvicornで起動します。以下のコマンドをターミナル上で実行します。

uvicorn main:app --reload

--reloadはWebサーバ立ち上げ中にコードの変更が行われた時、自動的に再読み込みをしてくれるようにするためのオプションです。アプリケーションが重たくなってしまうようなので、開発の時だけ利用するようにし、公開(デプロイ)する時はこのオプションをつけないようにした方が良いようです。

アプリケーションが問題なく起動出来ると、以下のようなメッセージがターミナルに表示されます。

INFO:     Started reloader process [9176] using StatReload
INFO:     Started server process [4556]
INFO:     Waiting for application startup.
INFO:     Application startup complete.INFO:

ターミナル上にURL(http://127.0.0.1:8000)が表示されるので、そこにブラウザからアクセスすると以下のように表示されます。

ブラウザの表示

先ほどコードの中で見た、@app.get("/")でデコレートされている関数root()の処理が実行されていることが分かります。

次にブラウザで"http://127.0.0.1:8000/docs"にアクセスしてみます。

Swagger UI

Swagger UIが開き、このアプリケーションで提供されるAPIについてのドキュメントが表示されます。FastAPIで自動的に生成してくれるのがうれしいです!

Swagger UIではAPIを確認するだけでなく、なんとAPIの動作確認も出来てしまいます。たとえば以下はエンドポイント"/msg"にリクエストを送信したときの動作確認をした時の様子です。

送信するメッセージと、レスポンスを確認できます。

今度は@app.post("/msg")でデコレートされている関数message()が実行されたことが分かります。

これまでAPIの動作確認をするときはcurlコマンドなどを打ったりしていたのですが、Swagger UIは見やすくてとてもいいと思います。テストであれば、このままアプリケーションのUIとして使ってもらってもいいのでは・・・と思いました。

デプロイ

最後に同じネットワーク上の別のマシンからもアプリケーションを利用できるように、アプリケーションをデプロイしてみます。

uvicorn main:app --host 0.0.0.0 --port 80

これで同じネットワーク上の別のマシンからも利用できるようになります。

機械学習モデルを組み込んだWebアプリケーション

シンプルなWebアプリケーションを作ってみて、FastAPIの大まかな使い方が分かってきました。次は機械学習モデルを組み込んだ、APIでモデルの推論機能を提供するWebアプリケーションを作ってみます。ここではテキストが入力されると、それに続く自然なテキストを生成するモデルの機能をAPIで提供するアプリケーションを作ります。

ここで利用するモデルは、以前この記事でも使用したrinna株式会社様がHuggingFaceで公開されている日本語のGPT-2のテキスト生成モデルです。

huggingface.co

テキスト生成モジュール

まずはモデルを組み込んだテキスト生成モジュールを作ります。そのために必要になるライブラリをインストールします。

pip install torch transformers[ja] sentencepiece

そしてFastAPIのアプリケーションのファイルと同じ階層に、以下のようなファイルを作成します。

from transformers import T5Tokenizer, AutoModelForCausalLM

class ChatFunction:
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.tokenizer.do_lower_case = True
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
    
    def generate_msg(self, text, max_length=30, num_return_sequences=1):
        input = self.tokenizer.encode(text, return_tensors="pt")
        print(input)
        output = self.model.generate(
            input, 
            do_sample=True, 
            max_length=max_length, 
            num_return_sequences=num_return_sequences
        )
        msg_list = self.tokenizer.batch_decode(output,skip_special_tokens=True)
        return [m[len(text):] for m in msg_list]


if __name__ == '__main__':
    chat = ChatFunction('rinna/japanese-gpt2-medium')
    text = 'テストだよ。'
    msg = chat.generate_msg(text)
    print(f'input:{text}')
    print(f'response:{msg}')

ファイルを保存したら、モジュール単体でテストしてみます。

python chat_function.py

以下のように出力されることを確認して、正しくモジュールが動いていることを確認します。

input:テストだよ。
response:['あとは、君がこのスレを使ってくれてるかわかんないけどね(笑 これで安心して帰れます!お']

アプリケーションにモジュールを組み込む

今度はこのテキスト生成モジュールをFastAPIのアプリケーションに組み込んでいきます。テキスト生成をする時は、"/msg"エンドポイントにPOSTでリクエストを送るようにします。リクエストのBodyの中に、続きを生成してほしいテキスト(message)と、生成するテキストの長さ(max_length)、生成するテキストの個数(num_return_sequences)を設定できるようにします。

そしてこのリクエストBodyの検証を、pydanticというライブラリを使って実装しています。pydanticBaseModelを継承したMessageクラスを作り、その中にリクエストBodyのデータフォーマットを書きます。あとはmessage()関数の引数の型をMessageクラスに指定します。

from fastapi import FastAPI
from chat_function import ChatFunction
from pydantic import BaseModel

app = FastAPI()
#chat(推論モジュール)のセッティング
chat = ChatFunction('rinna/japanese-gpt2-medium')

#POSTで受け取るデータフォーマットを定義
class Message(BaseModel):
    message: str
    max_length: int = 30 #デフォルト値
    num_return_sequences: int = 1 #デフォルト値

@app.get("/")
def root():
    return {"message":"こんにちは"}

@app.post("/msg")
def message(msg: Message):
    output = chat.generate_msg(
        msg.message,
        msg.max_length,
        msg.num_return_sequences
    )
    return {"message": output}

それでは最後にアプリケーションを実行してみて、エンドポイント"/msg"にPOSTでリクエストを送信し、正しく動作するかSwagger UIで確認してみます。

生成されたテキストがレスポンスで得られています。

ちゃんと動いているみたいです。max_lengthnum_return_sequencesパラメータも指定してみます。

テキストの長さが変わって複数のテキストが返ってきました。

パラメータによってテキストの長さや生成されるテキストの数がコントロール出来ていることが確認出来ました!

まとめ

ということで、今回はPythonの軽量なWebフレームワークFastAPIについて調べ、シンプルなアプリケーションから機械学習モデルの推論機能を提供するアプリケーションまで作ってみました。普段Webアプリケーションを作る機会があまりない自分にとっても、とてもシンプルで分かりやすいと感じました!FastAPIのドキュメントは引き続き目を通していきたいと思います。