CCCMKホールディングス TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

Dashを使って画像アノテーションツールを自作しました!

こんにちは、技術開発の三浦です。

ビジネス英会話を春から続けているのですが、レッスンごとに自分自身のことを英文で話すパートがあります。ビジネス英会話なので、「あなたの今のタスクはどんな役に立ちますか」といった内容の質問が多いのですが、普段の業務でも意識しないとなかなか考えないことなので、そういう点でもいい勉強になっています。

今回はPythonのDashというライブラリを使って簡単な画像アノテーションツールを自作した話をご紹介します!

アノテーションツール自作のきっかけ

機械学習、とりわけ深層学習で扱うような画像やテキストデータに対して、それをモデルにどのように認識してほしいかをラベル付けする行程をアノテーションと呼びます。最近作りたいと考えているモデルは画像に対し、その内容を説明する文章(キャプション)を生成するというものです。以前Transformerを使って画像キャプショニングをした話をこのブログでもご紹介しました。

techblog.cccmk.co.jp

こちらで紹介したモデルはCOCO Datasetというアノテーション済みのデータを使って学習したものですが、元データが英語なので、生成されるキャプションも英語です。このキャプションを日本語で生成出来るモデルを作りたいと考えており、そのためにも日本語のキャプションでアノテーションした学習データを用意しようと思っていました。

英語から日本語への翻訳は、機械翻訳などのサービスを利用して実現出来そうですが、機械翻訳された日本語を元の英語と画像を見ながらさっと確認し、必要に応じて修正できるようなツールが欲しいな・・・と考えていました。アノテーション用のツールは色々と公開されているのですが、今回の自分のニーズに合うものがなかなか見つからず、ならば自分で作ろう、と思い立ちました。

フロントエンド開発の壁・・・

思い立ったのはいいものの、普段あまりフロントエンド開発をしない自分にとって、HTML・CSS・JavaScriptを思い出しつつ書くのは結構時間がかかる作業です。しかもユーザーは自分を含めて社内の数名程度、機能もそこまで求められておらず、どちらかというと早くアノテーションの確認作業に入りたいのです。

一番早くアノテーションツールを作ることが出来る方法はなんだろう、と色々調べて行きついたのが、PythonのDashというライブラリを使う方法でした。

DashはWebアプリを構築し実行するPythonのライブラリで、バックエンドにPythonのFlaskを使用し、フロントエンドにReact.jsやBootstrapを使用します。以前このDashを使って簡単なダッシュボードアプリを作ったことがありました。

techblog.cccmk.co.jp

この時はデータ可視化の用途で利用したのですが、今回はもう少しWebアプリケーション開発寄りの用途でDashを使ってみました。

作ったもの

まずは今回Dashを使って作ったアノテーションツールをご覧ください。

自作アノテーションツール

少し見づらいのですが、下の段にある「<」「>」ボタンをクリックすると前後のデータに移動することが出来ます。別のデータに移動する時には、移動前に表示されていた日本語キャプションをJSONファイルで保存します。そして次のデータを表示する時に保存済みの日本語キャプションJSONファイルがあれば読み込んで表示させています。表示させているデータはCOCO Datasetの2017のものです。

作りが甘いところはたくさんありますが、概ねイメージしていたものを作ることが出来、後は使いながら必要に応じて機能を追加していこうと思っています。今回はDashでこのようなWebアプリを作る際のポイントを、いくつかご紹介します。

必要なコンポーネントを配置する

最初にアプリに必要なコンポーネントを画面上に配置します。他のWebアプリを開発する時もレイアウトやアプリの外観はWebフレームワークBootstrapを使用することで分かりやすく設定できるようになりますが、Dashでもdash-bootstrap-componentsというライブラリを使用することでBootstrapをPythonのコードで使用することが出来ます。

画像ファイルの読み込みは、画像処理ライブラリscikit-imageを使用しました。

from dash import Dash, html, dcc, Input, Output, ctx, State, ALL
import dash_bootstrap_components as dbc
import plotly.express as px
from skimage import io
import json
import os

app = Dash(__name__,external_stylesheets=[dbc.themes.DARKLY])

def get_image(target_num):
    '''
    アノテーション対象の画像番号(target_num)を指定すると
    該当する画像データを読み込んで返す関数
    '''
    image_data = io.imread('data/img/001.jpg')
    fig = px.imshow(image_data)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(
        margin=dict(l=5, r=5, t=5, b=5),
    )
    return fig

def make_caption_area(target_num):
    '''
    英語と対応する日本語入力エリアを
    英語キャプション数分作成し、返す
    '''
    en_caption_list = [
        'a wooden desk is shown with a laptop, and various electronics.',
        'the laptop is on the desk left open.',
        'the is a desk with a laptop mouse and cd',
        'a desk with a laptop, keyboard, and mouse',
        'a laptop sitting on a desk near a lamp and plant.'
    ]
    jp_caption_list = [
        ''  for _ in range(len(en_caption_list))
    ]

    caption_area = [
        dbc.Row([
            dbc.Col(en_caption_list[i],width=3),
            dbc.Col(dcc.Textarea(value=jp_caption_list[i]),width=4)
        ]) for i in range(len(en_caption_list))
    ]

    return caption_area

app.layout = html.Div([
        dbc.Row(
            [
                dbc.Col(html.H1('日本語画像キャプショニングツール'))
            ]
        ),
        dbc.Row(
            [
                dbc.Col(dcc.Graph(figure=get_image(0)),width=3),
                dbc.Col(
                    [
                        dbc.Row(
                            [
                                dbc.Col(html.Em('英語'),width=4),
                                dbc.Col(html.Em('日本語'),width=5)
                            ]
                            + make_caption_area(0)
                        )
                    ]
                ),
            ]
        ),
        dbc.Row(
            children=[
                dbc.Col(dbc.Button('<',style={'size':1},id='prev')),
                dbc.Col(dbc.Button('>',style={'size':1},id='next'))
            ]
        )
    ]
)

if __name__ == '__main__':
    app.run_server(debug=True)

このファイルをapp.py等の名前で保存し、ターミナルでpython app.pyのように実行してアプリを起動すると、ローカル環境のブラウザでアプリを表示することが出来ます。

先のスクリプトを実行した様子

アプリケーションに動作を加える

このままでは「<」「>」ボタンをクリックしても何も変化しません。「<」「>」をクリックしたとき、1つ前や1つ後のデータを表示し、表示する前に日本語のキャプションをJSONファイルに保存する処理が実行されるようにします。ここがこのアプリの一番の肝になる部分です。

現在表示中のデータの番号を記録するdcc.Storeコンポーネント

Dashではブラウザ内にデータを保存する仕組みがあります。それがdcc.Storeコンポーネントです。dcc.Storeコンポーネントを利用することで、JSON形式でデータをブラウザに保存することが出来ます。アプリ内でデータのやり取りをするのに使えますが、あまり大きなサイズのデータの保存には向かないようです。

dcc.Storeは格納したデータを消去するタイミングをstorage_typeパラメータで指定することが出来ます。今回はデフォルトの設定のmemoryというパラメータを指定しました。memoryの場合はデータはメモリ上に保存され、ページがリフレッシュされる度にリセットされます。

以下のように、他のコンポーネントと同じようにレイアウトに追加して使用します。

        ・・・
        dbc.Row(
            [
                dbc.Col(dbc.Button('<',style={'size':1},id='prev')),
                dbc.Col(dbc.Button('>',style={'size':1},id='next'))
            ]
        ),
        dcc.Store(id='now_data_num',storage_type='memory')
    ]
) 

「<」「>」をクリックしたときのCallback関数を定義する

「<」「>」をクリックしたときに呼び出される処理をCallback関数として定義します。 Callback関数は関数の前に@app.callbackデコレータを付けてあげます。

直前にクリックされたボタンが「<」か「>」なのかはdash.callback_context(dash.ctx)を使って判別することが出来ます。

@app.callback(
    Output('data_area','children'),
    Output('now_data','data'),
    Input('prev','n_clicks'),
    Input('next','n_clicks'),
    State('now_data','data'),
    #日本語キャプションのTextAreaの値を取得する
    [State({'type':'jp_tran','index':ALL},'value')]
)
def update_data_area(prev, next, now_data, jp_list):
    '''
    「<」「>」をクリックしたとき、画像とキャプションを1つ前か1つ後のものに差し替える。
    差し替える際には現在の日本語キャプションデータをJSONファイルに保存する。
    '''
    now_data = now_data or {'data_num': 0}
    target_data_num = now_data['data_num']
    image_id = annotation_data[target_data_num]['image_id']

    #日本語キャプションデータを{image_id}.JSONで保存する
    save_jp_caption_list(image_id, jp_list)

    # 直前にクリックされたボタンに応じて移動先を決定する
    if 'prev' == ctx.triggered_id:
        if now_data['data_num'] > 0:
            now_data['data_num'] = now_page['data_num'] - 1
    if 'next' == ctx.triggered_id:
        if now_data['data_num'] + 1 < len(annotation_data):
            now_data['data_num'] = now_data['data_num'] + 1

    return [
                dbc.Col(dcc.Graph(figure=get_image(now_data['data_num'])),width=3),
                dbc.Col(
                    [
                        dbc.Row(
                            [
                                dbc.Col(html.Em('英語'),width=4),
                                dbc.Col(html.Em('日本語'),width=5)
                           ]
                        ),
                    ] + make_caption_area(now_data['data_num'])
                )
    ], now_data

これで「<」「>」ボタンをクリックして表示するデータを変え、アノテーションの内容をJSONファイルに保存することが出来るようになりました。

まとめ

ということで、今回はPythonのWebアプリケーション構築ライブラリDashを使って簡単なアノテーションツールを自作した話をご紹介しました。Dashのコンポーネントは色々と揃っているので、とりあえず動くものを用意したい、といった時にはDashを使うことでなんらかの形で表現することが出来ると思います。

また何かDashでアプリを作ったら、この場でご紹介したいと思います!