CCCMKホールディングス TECH LABの Tech Blog

TECH LABのエンジニアが技術情報を発信しています

ブログタイトル

StreamlitとAzure OpenAI Serviceを使って自然言語でグラフを描けるアプリを作ってみました。

CCCMKホールディングス TECH LAB三浦です。

10月になりました。ちょうど過ごしやすい時期で、一年の中で一番好きな時期かもしれません。外を歩いていても気持ちがいいので、今のうちに色々なところに出かけたいな、と思っています。

以前Streamlitを使ってChatアプリを作成する方法について調べてまとめました。

techblog.cccmk.co.jp

今回はその内容を少し発展させて、次のようなアプリを作ってみました。

アプリケーションを立ち上げてログインすると、次のような画面が表示され、

📈グラフを描画するAI📊

CSVファイルをアップロードすることが出来ます。

CSVをアップロードします。

ここで使用したデータは、Hugging FaceでDatasetsに公開されているこちらのdatasetです。

huggingface.co

Chatメッセージの入力欄に、datasetに関する質問を入力することが出来て、

Chatメッセージの入力

結果を表示することが出来ます。

プレビュー表示

積み上げ棒グラフを表示したり、

積み上げ棒グラフの表示

円グラフを表示したり、

円グラフの表示

ヒートマップを表示することが出来ます。

ヒートマップの表示

面白いですよね。私はデータ可視化のコードを書くのが苦手で、いつもものすごく時間がかかってしまうのですが、これなら簡単なグラフならすぐに描画出来るし、描画用のコードも表示しているので本当にグラフが間違っていないのかの確認も出来ます。

今回はこのアプリケーションをどのように作成したのか、ご紹介したいと思います。

仕組み

仕組みはそれほど複雑ではなく、ほぼChatGPT(Azure OpenAI Serviceの"gpt-35-turbo-16k")の汎用性を活用したものになっています。アップロードされたCSVファイルの中身についての簡単な情報と、ユーザーからの質問にplotlyのグラフを描画するpythonのコードで回答してほしい旨を含めたベースのプロンプトを用意し、質問文を埋め込んでgpt-35-turbo-16kに問合せ、返ってきた回答に含まれるコードと、そのコードをpythonの組み込み関数exec()で実行した結果を描画する、という流れです。

以下がその流れを簡単にまとめた図になります。

簡単な処理の流れ

プロンプトテンプレートの紹介

最初に今回作成したプロンプトテンプレートをご紹介します。2つ作成しており、1つはChatGPTに与えるシステムメッセージ用、もう1つがユーザー入力メッセージ用です。

# system message用
SYSTEM_PROMPT="""あなたはユーザーを支援するアシスタントです。
pythonのpandas dataframeを加工してグラフの生成を支援したり、会話に対して回答してください。
"""

# human message用
PROMPT = """pythonのpandas dataframeがdfという変数に格納されています。
このdfを使って与えられた質問に対してPythonのライブラリplotlyを使用してグラフを作成してください。
import plotly.express as px
import plotly.graph_objects as go
はすでに実行済みなので、コードに含めてはいけません。import文を追加するコードに含めてはいけません。

もしdfとまったく関係のない質問の場合は自然な会話をしてください。
回答は以下の形式で作成してください。

type: もし回答がpythonのコードの場合は"code"を、自然な会話の場合は"talk"を指定してください。
content: pythonのコードか自然な会話のテキストを入力してください。コードの場合はimport文を含めてはいけません。

dfに含まれるカラムと先頭の10行についての情報は以下の通りです。

# カラム一覧(df.columns)
{columns}

# 先頭10行(df.head().to_json())
{header}

# 例
質問: is_redの値で棒グラフを作成してください。
回答:
type:code
content:
temp = df.copy()
temp["count"] = 1
temp = temp.groupby("is_red").sum()[["count"]].reset_index()
fig = px.bar(temp, x='is_red', y='count', height=400)


質問:こんにちは!
回答:
type:talk
content:
こんにちは!なんでも聞いてください!

それでははじめてください!

import文を、追加するコードに含めてはいけません。

質問: {question}
回答:

"""

システムメッセージ用は特筆すべき点はないのですが、ユーザーメッセージ用はいくつかポイントがあります。

import文実行の禁止

このプロンプトで生成されるコードをpythonのexec()で実行することにしているのですが、その際意図していないコードが実行されることを防ぐため、不必要なモジュールのインポートを禁止します。exec()の実行時に使用可能なモジュールの制限をかけ、import文の実行も禁止するようにするため、生成されるコードにimport文が含まれているとエラーが発生してしまいます。これを防ぐため、import文の使用禁止を何度もプロンプトの中で言及しています。

質問文に対するモードの切り替え

プロンプトの中に次のような指示を加え、与えられたデータに対する質問か、それ以外かでモードを切り替えるようにしました。

もしdfとまったく関係のない質問の場合は自然な会話をしてください。

どちらのモードで生成された回答かは、回答に含まれる"type"の部分を見ることで分かるようにしました。モードは"code"と"talk"の2つがあり、codeはPythonのコードを生成した場合でtalkは会話の場合です。

talkモードでの会話の様子

Chatメッセージ描画用の処理

今回のアプリケーションで一番悩んだのがChatメッセージの描画処理です。特に"gpt-35-turbo-16k"からの応答(assistantメッセージ)の描画の処理に色々悩みました。描画処理は関数にまとめました。

def draw_chat_history(df):
    """
    UI上にこれまでのChat履歴を描画する。
    df: 現在解析対象のpandas dataframe
    """
    chat_history = st.session_state["chat_history"]

    for chat in chat_history:
        if chat["role"] == "user":
            # Userのメッセージの場合はテキストでメッセージを描画
            message = st.chat_message("user")
            message.write(chat["content"])
        elif chat["role"] == "assistant":
            # Assistantのメッセージの場合、typeがcodeかそれ以外かで描画方法を分ける
            message = st.chat_message("assistant")
            if chat["content"].startswith('type:code'):
                # codeの場合、ソースコードの部分を抜き出す。
                code = chat["content"].split("content:")[-1]
                message.code(code) #ソースコードを描画
                
                # ソースコードをexec()で実行する
                globals_for_exec = {"df":df,"px":px,"go":go,"dict":dict,"list":list,"__builtins__": {}}
                locals_from_exec = {}
                exec(code,globals_for_exec,locals_from_exec)
                
                # exec()内でfig変数にグラフが格納されるため、取り出してグラフ描画
                message.plotly_chart(locals_from_exec["fig"], use_container_width=True)
            else:
                # typeがcode以外の時はメッセージをテキストとして描画
                text = chat["content"].split("content:")[-1].replace('"','')    
                message.write(text)
        else:
            continue

この中で以下の箇所が生成されたコードを実行するところです。

globals_for_exec = {"df":df,"px":px,"go":go,"dict":dict,"list":list,"__builtins__": {}}
locals_from_exec = {}
exec(code,globals_for_exec,locals_from_exec)

exec()の第二引数に使用可能なモジュールを指定し、それ以外をexec()で実行される処理の中では使用できないようにしています。"__builtins__": {}を指定することでpythonの組み込みモジュールについても基本的に使用できないように制限をかけました。これでimport文実行時に呼ばれる組み込み関数__import__()も実行不可になります。そしてexec()の中で生成されたグラフオブジェクトfigは 辞書型変数locals_from_execにキー"fig"に紐づいて格納されます。

メインの処理

アプリケーションのメインの処理は以下の様に実装しました。思っていた以上にシンプルに書けました。LLMsの強力なパワーのおかげだと思います。

def main():
    """
    アプリケーションのメインモジュール
    """
    st.title("📈グラフを描画するAI📊")
    
    upload_file = st.file_uploader(
        label="CSVファイルをアップロード",
        type="csv")
    
    if upload_file and "df" not in st.session_state:
        # ファイルがアップロードされたらpandas dataframeにロードする
        df = pd.read_csv(upload_file)
        
        # PROMPTのロード
        base_prompt = PromptTemplate.from_template(PROMPT)
        base_prompt = base_prompt.partial(
            columns=str(df.columns),
            header=str(df.head().to_json()))
        base_prompt = HumanMessagePromptTemplate(prompt=base_prompt)
        system_prompt = SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT)
        
        # 各種変数をst.session_stateにセットする
        st.session_state["system_messages"] = system_prompt
        st.session_state["chat_history"] = []
        st.session_state["base_prompt"] = base_prompt
        st.session_state["df"] = df

    user_input = st.chat_input() # Chatメッセージ入力欄
    
    if "df" in st.session_state and user_input:
        # データのロードがすみ、Chatメッセージが入力された
        
        # st.session_stateからの変数のロード
        input_prompt = st.session_state["base_prompt"]
        df = st.session_state["df"]
        chat_history = st.session_state["chat_history"]
        system_prompt = st.session_state["system_messages"]
        
        # LLMs(gpt-35-turbo-16k)からのレスポンスを取得
        chat_prompt = ChatPromptTemplate.from_messages([system_prompt, input_prompt])
        chain = LLMChain(llm=llm,prompt=chat_prompt)
        assistant_msg = chain.run(question=user_input)
        
        # chat履歴に追加
        chat_history.append({"role":"user","content":user_input})
        chat_history.append({"role":"assistant","content":assistant_msg})
        st.session_state["chat_history"] = chat_history
        
        # chat履歴の描画
        draw_chat_history(df)

まとめ

ということで、今回はCSVファイルをアップロードし、質問を入力するとその回答をグラフで描画してくれるアプリケーションをStreamlitとAzure OpenAI Serviceで作ってみました。プロンプトに含めた例示は本当に最小限のものだったのですが、それでも様々なグラフ描画に対応出来ることが分かり、びっくりしました。これでいろんなグラフをどんどん描くことが出来そうです。

それからStreamlitのchat関連の機能は直感的に使えるのにも関わらず、UIのデザインがきれいでいいな、と思いました!